from dataclasses import dataclass
import hmac
from typing import Tuple

from Crypto.Cipher import AES
import Crypto.Util.Counter as Counter


class InvalidMACException(Exception):
    """Signal that a received MAC tag does not match the locally computed one.

    ``MacAndCipher.decrypt_if_valid`` raises this with the expected tag and
    the received tag, in that order, as the exception arguments.
    """


@dataclass(frozen=True)
class EncryptionParameters:
    """Immutable bundle of algorithm names and key material for one
    ``MacAndCipher`` instance."""
    # Name of the requested cipher. ``initialize_pair`` only accepts
    # "AES-128"; ``MacAndCipher`` does not read this field and always
    # builds an AES cipher in CTR mode.
    cipher_type: str
    # Hash name handed to ``hmac.new`` as ``digestmod``.
    hash_type: str
    # Starting counter value for the CTR cipher, interpreted big-endian;
    # its length in bits is also used as the counter width.
    iv: bytes
    # Key for the HMAC authenticator.
    mac_key: bytes
    # Key for the AES cipher.
    cipher_key: bytes


class MacAndCipher:
    """Combine an HMAC authenticator and an AES-CTR cipher built from one
    ``EncryptionParameters``.

    The cipher object is stateful and shared by ``encrypt`` and
    ``decrypt_if_valid``; the keyed HMAC is never updated directly, each tag
    is computed on a copy of it.
    """

    def __init__(self, parameters: EncryptionParameters) -> None:
        """Create the keyed HMAC and the AES-CTR cipher from ``parameters``."""
        # Keyed but otherwise empty; per-message tags are computed on copies.
        self.authenticator = hmac.new(
            parameters.mac_key, digestmod=parameters.hash_type
        )
        # The IV is the big-endian starting value of a counter that is
        # exactly as wide as the IV itself.
        counter = Counter.new(
            8 * len(parameters.iv),
            initial_value=int.from_bytes(parameters.iv, byteorder="big"),
        )
        self.cipher = AES.new(
            parameters.cipher_key, AES.MODE_CTR, counter=counter
        )

    def encrypt(self, data: bytes) -> bytes:
        """Return ``data`` encrypted with this instance's cipher."""
        return self.cipher.encrypt(data)

    def authenticate(self, data: bytes) -> bytes:
        """Return the HMAC tag of ``data`` without touching the stored HMAC."""
        mac = self.authenticator.copy()
        mac.update(data)
        return mac.digest()

    def decrypt_if_valid(self, data_with_tag: bytes) -> bytes:
        """Verify the trailing tag of ``data_with_tag`` and decrypt the rest.

        The last ``digest_size`` bytes are taken as the tag and everything
        before them as the ciphertext. The ciphertext is decrypted only after
        the tag has been checked with a constant-time comparison.

        Raises:
            InvalidMACException: if the tag does not match; the arguments are
                ``(expected_tag, tag)``. An input shorter than the digest
                size also ends up here, because the slice taken as the tag is
                then shorter than a real digest.
        """
        split_at = len(data_with_tag) - self.authenticator.digest_size
        data, tag = data_with_tag[:split_at], data_with_tag[split_at:]
        expected_tag = self.authenticate(data)

        if hmac.compare_digest(tag, expected_tag):
            return self.cipher.decrypt(data)

        # NOTE(review): the exception carries the expected tag; make sure it
        # is never logged or echoed back to the sender of the message.
        raise InvalidMACException(expected_tag, tag)


def _stretch_secret(
    secret: bytes, hash_type: str, seed: bytes, length: int
) -> bytes:
    """Expand ``secret`` into exactly ``length`` bytes using chained HMACs."""

    def keyed_digest(message: bytes) -> bytes:
        return hmac.new(secret, message, digestmod=hash_type).digest()

    # chain_0 = HMAC(seed); block_i = HMAC(chain_i + seed);
    # chain_{i+1} = HMAC(chain_i). The output is the concatenation of the
    # blocks, truncated to ``length``.
    chain = keyed_digest(seed)
    blocks = []
    produced = 0
    while produced < length:
        block = keyed_digest(chain + seed)
        blocks.append(block)
        produced += len(block)
        chain = keyed_digest(chain)
    return b"".join(blocks)[:length]


def initialize_pair(
    cipher_type: str, hash_type: str, secret: bytes
) -> Tuple[EncryptionParameters, EncryptionParameters]:
    """Derive two ``EncryptionParameters`` from ``secret``.

    The secret is stretched with HMAC over the fixed seed ``"key expansion"``
    into two equally sized halves. Each half is laid out as
    ``iv (16 bytes) | cipher key (16 bytes) | MAC key (20 bytes)`` and becomes
    one element of the returned pair; the first half yields the first element.

    Args:
        cipher_type: requested cipher; only ``"AES-128"`` is supported.
        hash_type: requested hash; only ``"SHA256"`` is supported.
        secret: shared secret used as the HMAC key for the expansion.

    Returns:
        A pair of ``EncryptionParameters`` whose ``iv``, ``mac_key`` and
        ``cipher_key`` are immutable ``bytes``.

    Raises:
        NotImplementedError: for any other ``cipher_type`` or ``hash_type``.
    """
    if cipher_type != "AES-128":
        raise NotImplementedError(f"unsupported cipher type: {cipher_type!r}")
    if hash_type != "SHA256":
        raise NotImplementedError(f"unsupported hash type: {hash_type!r}")

    iv_size = 16
    cipher_key_size = 16
    # NOTE(review): 20 bytes is shorter than a SHA-256 digest; presumably
    # kept for interoperability with the peer's derivation -- confirm before
    # changing, since any change alters the derived keys.
    hmac_key_size = 20
    key_end = iv_size + cipher_key_size
    params_size = key_end + hmac_key_size

    material = _stretch_secret(
        secret, hash_type, b"key expansion", 2 * params_size
    )

    def build(chunk: bytes) -> EncryptionParameters:
        # ``chunk`` is a ``bytes`` slice, so the frozen dataclass holds
        # immutable, hashable key material that matches its annotations.
        # Keyword arguments make explicit that the MAC key sits *after* the
        # cipher key in the buffer although the field order is the reverse.
        return EncryptionParameters(
            cipher_type=cipher_type,
            hash_type=hash_type,
            iv=chunk[:iv_size],
            mac_key=chunk[key_end:],
            cipher_key=chunk[iv_size:key_end],
        )

    return build(material[:params_size]), build(material[params_size:])