# py-libp2p/libp2p/crypto/authenticated_encryption.py

from dataclasses import dataclass
import hmac
from typing import Tuple
from Crypto.Cipher import AES
# 2019-08-24 19:57:56 +02:00
import Crypto.Util.Counter as Counter
class InvalidMACException(Exception):
    """Raised when a received authentication tag does not match the computed one."""
@dataclass(frozen=True)
class EncryptionParameters:
    """
    Immutable bundle of the settings and key material used by one
    direction of an authenticated-encryption channel.
    """

    # Name of the symmetric cipher, e.g. "AES-128".
    cipher_type: str
    # Name of the hash used for the HMAC, e.g. "SHA256".
    hash_type: str
    # Initialization vector; used as the starting counter value of the cipher.
    iv: bytes
    # Key for the HMAC authenticator.
    mac_key: bytes
    # Key for the symmetric cipher.
    cipher_key: bytes
class MacAndCipher:
    """
    Pair an HMAC authenticator with an AES cipher in CTR mode, both
    configured from a single ``EncryptionParameters`` instance.
    """

    def __init__(self, parameters: EncryptionParameters) -> None:
        """
        Set up the keyed HMAC and the AES-CTR cipher described by ``parameters``.
        """
        # Keyed HMAC with no message data yet; every tag computation below
        # works on a copy so this base object is never advanced.
        self.authenticator = hmac.new(
            parameters.mac_key, digestmod=parameters.hash_type
        )

        # The counter is as wide as the IV and starts at the IV interpreted
        # as a big-endian integer.
        nonce = parameters.iv
        counter = Counter.new(
            len(nonce) * 8,
            initial_value=int.from_bytes(nonce, byteorder="big"),
        )
        self.cipher = AES.new(parameters.cipher_key, AES.MODE_CTR, counter=counter)

    def encrypt(self, data: bytes) -> bytes:
        """Return ``data`` encrypted with this instance's cipher."""
        return self.cipher.encrypt(data)

    def authenticate(self, data: bytes) -> bytes:
        """Return the HMAC tag of ``data`` under the configured MAC key."""
        mac = self.authenticator.copy()
        mac.update(data)
        return mac.digest()

    def decrypt_if_valid(self, data_with_tag: bytes) -> bytes:
        """
        Split ``data_with_tag`` into payload and trailing tag, verify the tag
        and return the decrypted payload.

        Raises ``InvalidMACException`` when the trailing tag does not match
        the tag computed over the payload.
        """
        # The tag occupies the last ``digest_size`` bytes of the input.
        split_at = len(data_with_tag) - self.authenticator.digest_size
        ciphertext = data_with_tag[:split_at]
        received_tag = data_with_tag[split_at:]

        expected_tag = self.authenticate(ciphertext)
        # Constant-time comparison of the received and computed tags.
        if hmac.compare_digest(received_tag, expected_tag):
            return self.cipher.decrypt(ciphertext)

        # NOTE(review): the exception carries the expected tag; make sure it
        # is never logged verbatim or echoed back to the remote peer.
        raise InvalidMACException(expected_tag, received_tag)
def initialize_pair(
    cipher_type: str, hash_type: str, secret: bytes
) -> Tuple[EncryptionParameters, EncryptionParameters]:
    """
    Return a pair of ``EncryptionParameters`` for use in securing a
    communications channel with authenticated encryption, derived from
    the ``secret`` and using the requested ``cipher_type`` and ``hash_type``.

    The secret is stretched with repeated HMAC invocations into
    ``2 * (iv + cipher key + mac key)`` bytes. The first half of that
    output backs the first element of the returned pair and the second
    half backs the second element. Within each half the layout is
    ``iv | cipher_key | mac_key``.

    All key material is returned as immutable ``bytes`` (matching the
    field annotations of ``EncryptionParameters``), so the frozen
    parameters objects stay hashable and cannot be mutated through a
    shared buffer.

    Raises ``NotImplementedError`` for any ``cipher_type`` other than
    ``"AES-128"`` or any ``hash_type`` other than ``"SHA256"``.
    """
    if cipher_type != "AES-128":
        raise NotImplementedError(f"unsupported cipher type: {cipher_type!r}")
    if hash_type != "SHA256":
        raise NotImplementedError(f"unsupported hash type: {hash_type!r}")

    iv_size = 16
    cipher_key_size = 16
    hmac_key_size = 20
    seed = "key expansion".encode()

    params_size = iv_size + cipher_key_size + hmac_key_size
    total_size = 2 * params_size

    def keyed_digest(*chunks: bytes) -> bytes:
        # HMAC over the concatenation of ``chunks``, keyed with ``secret``.
        mac = hmac.new(secret, digestmod=hash_type)
        for chunk in chunks:
            mac.update(chunk)
        return mac.digest()

    # tag_0 = HMAC(seed); each round emits HMAC(tag || seed) and then
    # advances tag = HMAC(tag) until enough output has been produced.
    tag = keyed_digest(seed)
    stretched = bytearray()
    while len(stretched) < total_size:
        stretched += keyed_digest(tag, seed)
        tag = keyed_digest(tag)

    # Drop the surplus of the final round and freeze the buffer so every
    # slice handed out below is an independent, immutable ``bytes`` object.
    result = bytes(stretched[:total_size])

    def build_parameters(half: bytes) -> EncryptionParameters:
        # Carve one half of the stretched output into iv / cipher key / mac key.
        key_end = iv_size + cipher_key_size
        return EncryptionParameters(
            cipher_type,
            hash_type,
            half[:iv_size],
            half[key_end:],
            half[iv_size:key_end],
        )

    return (
        build_parameters(result[:params_size]),
        build_parameters(result[params_size:]),
    )