Fix bugs in secio implementation

This commit is contained in:
Alex Stokes 2019-08-24 19:57:56 +02:00
parent 228032805a
commit b8c0ef9ebb
No known key found for this signature in database
GPG Key ID: 51CE1721B245C086
3 changed files with 29 additions and 23 deletions

View File

@ -3,6 +3,7 @@ import hmac
from typing import Tuple from typing import Tuple
from Crypto.Cipher import AES from Crypto.Cipher import AES
import Crypto.Util.Counter as Counter
class InvalidMACException(Exception): class InvalidMACException(Exception):
@ -23,8 +24,14 @@ class MacAndCipher:
self.authenticator = hmac.new( self.authenticator = hmac.new(
parameters.mac_key, digestmod=parameters.hash_type parameters.mac_key, digestmod=parameters.hash_type
) )
iv_bit_size = 8 * len(parameters.iv)
cipher = AES.new( cipher = AES.new(
parameters.cipher_key, AES.MODE_CTR, initial_value=parameters.iv parameters.cipher_key,
AES.MODE_CTR,
counter=Counter.new(
iv_bit_size,
initial_value=int.from_bytes(parameters.iv, byteorder="big"),
),
) )
self.cipher = cipher self.cipher = cipher
@ -70,14 +77,16 @@ def initialize_pair(
hmac_key_size = 20 hmac_key_size = 20
seed = "key expansion".encode() seed = "key expansion".encode()
result = bytearray(2 * (iv_size + cipher_key_size + hmac_key_size)) params_size = iv_size + cipher_key_size + hmac_key_size
result = bytearray(2 * params_size)
authenticator = hmac.new(secret, digestmod=hash_type) authenticator = hmac.new(secret, digestmod=hash_type)
authenticator.update(seed) authenticator.update(seed)
tag = authenticator.digest() tag = authenticator.digest()
i = 0 i = 0
while i < len(result): len_result = 2 * params_size
while i < len_result:
authenticator = hmac.new(secret, digestmod=hash_type) authenticator = hmac.new(secret, digestmod=hash_type)
authenticator.update(tag) authenticator.update(tag)
@ -87,10 +96,10 @@ def initialize_pair(
remaining_bytes = len(another_tag) remaining_bytes = len(another_tag)
if i + remaining_bytes > len(result): if i + remaining_bytes > len_result:
remaining_bytes = len(result) - i remaining_bytes = len_result - i
result[i : i + remaining_bytes] = another_tag result[i : i + remaining_bytes] = another_tag[0:remaining_bytes]
i += remaining_bytes i += remaining_bytes
@ -98,23 +107,22 @@ def initialize_pair(
authenticator.update(tag) authenticator.update(tag)
tag = authenticator.digest() tag = authenticator.digest()
half = int(len(result) / 2) first_half = result[:params_size]
first_half = result[:half] second_half = result[params_size:]
second_half = result[half:]
return ( return (
EncryptionParameters( EncryptionParameters(
cipher_type, cipher_type,
hash_type, hash_type,
first_half[0:iv_size], first_half[0:iv_size],
first_half[iv_size : iv_size + cipher_key_size],
first_half[iv_size + cipher_key_size :], first_half[iv_size + cipher_key_size :],
first_half[iv_size : iv_size + cipher_key_size],
), ),
EncryptionParameters( EncryptionParameters(
cipher_type, cipher_type,
hash_type, hash_type,
second_half[0:iv_size], second_half[0:iv_size],
second_half[iv_size : iv_size + cipher_key_size],
second_half[iv_size + cipher_key_size :], second_half[iv_size + cipher_key_size :],
second_half[iv_size : iv_size + cipher_key_size],
), ),
) )

View File

@ -23,6 +23,6 @@ def create_ephemeral_key_pair(curve_type: str) -> Tuple[PublicKey, SharedKeyGene
private_key = cast(ECCPrivateKey, key_pair.private_key) private_key = cast(ECCPrivateKey, key_pair.private_key)
secret_point = curve_point * private_key.impl.d secret_point = curve_point * private_key.impl.d
byte_size = secret_point.size_in_bytes() byte_size = secret_point.size_in_bytes()
return secret_point.x.to_bytes(byte_size, byteorder="big") return secret_point.x.to_bytes(byte_size)
return key_pair.public_key, _key_exchange return key_pair.public_key, _key_exchange

View File

@ -85,7 +85,7 @@ class SecureSession(BaseSession):
tag = self.local_encrypter.authenticate(encrypted_data) tag = self.local_encrypter.authenticate(encrypted_data)
msg = encode_message(encrypted_data + tag) msg = encode_message(encrypted_data + tag)
# TODO clean up how we write messages # TODO clean up how we write messages
self.conn.writer.write(msg) await self.conn.writer.write(msg)
await self.conn.writer.drain() await self.conn.writer.drain()
@ -104,18 +104,17 @@ class Proposal:
def serialize(self) -> bytes: def serialize(self) -> bytes:
protobuf = Propose( protobuf = Propose(
self.nonce, rand=self.nonce,
self.public_key.serialize(), public_key=self.public_key.serialize(),
self.exchanges, exchanges=self.exchanges,
self.ciphers, ciphers=self.ciphers,
self.hashes, hashes=self.hashes,
) )
return protobuf.SerializeToString() return protobuf.SerializeToString()
@classmethod @classmethod
def deserialize(cls, protobuf_bytes: bytes) -> "Proposal": def deserialize(cls, protobuf_bytes: bytes) -> "Proposal":
protobuf = Propose() protobuf = Propose.FromString(protobuf_bytes)
protobuf.ParseFromString(protobuf_bytes)
nonce = protobuf.rand nonce = protobuf.rand
public_key_protobuf_bytes = protobuf.public_key public_key_protobuf_bytes = protobuf.public_key
@ -163,15 +162,14 @@ class SessionParameters:
async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes: async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes:
# TODO clean up ``IRawConnection`` so that we don't have to break # TODO clean up ``IRawConnection`` so that we don't have to break
# the abstraction # the abstraction
conn.writer.write(encode_message(msg)) await conn.writer.write(encode_message(msg))
await conn.writer.drain() await conn.writer.drain()
return await read_next_message(conn.reader) return await read_next_message(conn.reader)
def _mk_multihash_sha256(data: bytes) -> bytes: def _mk_multihash_sha256(data: bytes) -> bytes:
digest = hashlib.sha256(data).digest() return multihash.digest(data, "sha2-256")
return multihash.encode(digest, "sha2-256")
def _mk_score(public_key: PublicKey, nonce: bytes) -> bytes: def _mk_score(public_key: PublicKey, nonce: bytes) -> bytes: