from cbc import CBCPaddingAttack

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import os
from util import pad_message, strip_padding

class CBCContext:
  """AES-CBC encryption context (TLS 1.0-style padding) exposing a padding oracle.

  Attributes:
    key: the secret AES key supplied to the constructor.
    oracle_calls: number of times padding_oracle has been queried so far.
  """

  # Initializes a CBC encryption context with secret key k
  def __init__(self, k):
    self.key = k
    # Incremented on every padding_oracle query so the attack's cost can be reported.
    self.oracle_calls = 0

  def _cipher(self, iv):
    """Return an AES-CBC Cipher object for this context's key and the given IV."""
    return Cipher(algorithms.AES(self.key), modes.CBC(iv))

  def encrypt(self, m):
    """Encrypt m with AES-CBC under self.key using a freshly sampled 16-byte IV.

    The message is padded with pad_message (TLS 1.0 padding) before encryption.

    Returns:
      The 16-byte IV followed by the ciphertext.
    """
    # A fresh random IV for every call makes encryption randomized.
    iv = os.urandom(16)
    encryptor = self._cipher(iv).encryptor()
    return iv + encryptor.update(pad_message(m)) + encryptor.finalize()

  def decrypt(self, ct):
    """Decrypt an IV-prefixed ciphertext (as produced by encrypt) and strip its padding.

    The first 16 bytes of ct are the IV; the rest is the CBC payload.

    Returns:
      Whatever strip_padding returns for the decrypted, still-padded message
      (presumably None when the padding is invalid -- confirm against util).
    """
    iv, payload = ct[:16], ct[16:]
    decryptor = self._cipher(iv).decryptor()
    padded_msg = decryptor.update(payload) + decryptor.finalize()
    return strip_padding(padded_msg)

  def padding_oracle(self, ct):
    """Report whether ct decrypts to a message with valid TLS 1.0 padding.

    TLS 1.0 padding is valid when the last k bytes of the decrypted message all
    have value k - 1. Every call, successful or not, increments oracle_calls.

    Returns:
      True if decrypt returns a non-None value, False if it returns None or
      raises any Exception.
    """
    self.oracle_calls += 1
    try:
      return self.decrypt(ct) is not None
    except Exception:
      # Malformed ciphertexts make the cipher library raise (e.g. a ValueError
      # for a wrong IV size or a length that is not a block multiple); the
      # oracle reports those as invalid padding. Catching Exception rather than
      # using a bare except lets KeyboardInterrupt/SystemExit propagate.
      return False

# Demo: encrypt a short message under a freshly sampled random 128-bit AES key,
# recover it through the padding-oracle attack, and report the query cost.
key = os.urandom(16)
msg = b"CS 346"

context = CBCContext(key)
ct = context.encrypt(msg)

# The attack is handed only the ciphertext and the oracle callback.
decrypted_msg = CBCPaddingAttack().decrypt(ct, context.padding_oracle)

for label, value in (
  ('Plaintext:', msg),
  ('Decrypted output:', decrypted_msg),
  ('Successful decryption?', msg == decrypted_msg),
  ('Number of padding oracle queries:', context.oracle_calls),
):
  print(label, value)

