126 lines
3.5 KiB
Python
126 lines
3.5 KiB
Python
from dataclasses import dataclass
|
|
import hmac
|
|
from typing import Tuple
|
|
|
|
from Crypto.Cipher import AES
|
|
import Crypto.Util.Counter as Counter
|
|
|
|
|
|
class InvalidMACException(Exception):
    """Raised when the MAC tag attached to a message does not match the tag
    computed locally over that message."""
|
|
|
|
|
|
@dataclass(frozen=True)
class EncryptionParameters:
    """Immutable bundle of algorithm names and key material for one direction
    of an authenticated-encryption channel.

    Field order matters: instances are also built positionally.
    """

    # Requested cipher name, e.g. "AES-128".
    cipher_type: str
    # Hash name handed to ``hmac.new`` as ``digestmod``, e.g. "SHA256".
    hash_type: str
    # Initial counter block for CTR mode; interpreted as a big-endian integer.
    iv: bytes
    # Key for the HMAC that authenticates messages.
    mac_key: bytes
    # Key for the symmetric cipher.
    cipher_key: bytes
|
|
|
|
|
|
class MacAndCipher:
    """Pair a keyed HMAC with an AES-CTR cipher built from one
    ``EncryptionParameters`` instance.

    The HMAC object created in ``__init__`` is kept as a pristine template:
    every tag is computed on a copy, so the template is never updated.
    The cipher object, by contrast, is shared across calls.
    """

    def __init__(self, parameters: EncryptionParameters) -> None:
        """Create the HMAC template and the AES-CTR cipher.

        ``parameters.iv`` is read as a big-endian integer and used as the
        initial counter value; the counter is exactly as wide as the IV.
        """
        self.authenticator = hmac.new(
            parameters.mac_key, digestmod=parameters.hash_type
        )

        counter_bits = 8 * len(parameters.iv)
        start_value = int.from_bytes(parameters.iv, byteorder="big")
        self.cipher = AES.new(
            parameters.cipher_key,
            AES.MODE_CTR,
            counter=Counter.new(counter_bits, initial_value=start_value),
        )

    def encrypt(self, data: bytes) -> bytes:
        """Return ``data`` encrypted with this instance's cipher."""
        return self.cipher.encrypt(data)

    def authenticate(self, data: bytes) -> bytes:
        """Return the HMAC tag of ``data`` under this instance's MAC key."""
        mac = self.authenticator.copy()
        mac.update(data)
        return mac.digest()

    def decrypt_if_valid(self, data_with_tag: bytes) -> bytes:
        """Verify the trailing MAC tag and return the decrypted payload.

        The last ``digest_size`` bytes of ``data_with_tag`` are the tag; the
        bytes before them are the ciphertext the tag was computed over.

        Raises:
            InvalidMACException: if the tag does not match. The exception
                args are ``(expected_tag, received_tag)``.
        """
        split_at = len(data_with_tag) - self.authenticator.digest_size
        payload, received_tag = data_with_tag[:split_at], data_with_tag[split_at:]

        expected_tag = self.authenticate(payload)

        # Constant-time comparison so timing does not reveal matching bytes.
        if hmac.compare_digest(received_tag, expected_tag):
            return self.cipher.decrypt(payload)

        # NOTE(review): the exception carries the locally computed (valid) tag
        # for the received payload -- confirm callers never log it or send it
        # back to the peer.
        raise InvalidMACException(expected_tag, received_tag)
|
|
|
|
|
|
def initialize_pair(
    cipher_type: str, hash_type: str, secret: bytes
) -> Tuple[EncryptionParameters, EncryptionParameters]:
    """Derive two ``EncryptionParameters`` from ``secret`` for securing a
    communications channel with authenticated encryption.

    ``secret`` is stretched with repeated HMACs (keyed by ``secret``, seeded
    with ``b"key expansion"``) into ``2 * 52`` bytes. Each 52-byte half is
    laid out as ``iv (16) | cipher_key (16) | mac_key (20)`` and becomes one
    element of the returned pair.

    Args:
        cipher_type: requested cipher; only ``"AES-128"`` is supported.
        hash_type: requested hash for the HMAC; only ``"SHA256"`` is supported.
        secret: shared secret the key material is derived from.

    Returns:
        ``(first, second)`` parameters built from the first and second half
        of the stretched key material. All key material is returned as
        immutable ``bytes``, so the frozen dataclass instances are hashable
        and match their declared field types.

    Raises:
        NotImplementedError: for any other ``cipher_type`` or ``hash_type``.
    """
    if cipher_type != "AES-128":
        raise NotImplementedError(f"unsupported cipher_type: {cipher_type!r}")
    if hash_type != "SHA256":
        raise NotImplementedError(f"unsupported hash_type: {hash_type!r}")

    # NOTE(review): these sizes and the seed presumably have to match the
    # remote peer's derivation byte-for-byte -- confirm before changing any.
    iv_size = 16
    cipher_key_size = 16
    hmac_key_size = 20
    seed = "key expansion".encode()

    params_size = iv_size + cipher_key_size + hmac_key_size
    total_size = 2 * params_size

    def keyed_digest(*parts: bytes) -> bytes:
        """HMAC the concatenation of ``parts`` under ``secret``."""
        mac = hmac.new(secret, digestmod=hash_type)
        for part in parts:
            mac.update(part)
        return mac.digest()

    # chain_0 = HMAC(seed); each round emits HMAC(chain_i || seed) and then
    # advances chain_{i+1} = HMAC(chain_i).
    chain = keyed_digest(seed)
    stretched = bytearray()
    while len(stretched) < total_size:
        stretched += keyed_digest(chain, seed)
        chain = keyed_digest(chain)

    # The last digest may overshoot; keep exactly total_size bytes and freeze
    # them so every slice below is an immutable ``bytes`` object.
    material = bytes(stretched[:total_size])

    def build(half: bytes) -> EncryptionParameters:
        """Split one ``iv | cipher_key | mac_key`` half into parameters."""
        key_end = iv_size + cipher_key_size
        return EncryptionParameters(
            cipher_type,
            hash_type,
            iv=half[:iv_size],
            mac_key=half[key_end:],
            cipher_key=half[iv_size:key_end],
        )

    return build(material[:params_size]), build(material[params_size:])
|