From 1adef05e941644e1baa2489c957af6917afd4022 Mon Sep 17 00:00:00 2001 From: Alex Stokes Date: Fri, 23 Aug 2019 23:43:36 +0200 Subject: [PATCH] Typing and linting fixes --- libp2p/crypto/authenticated_encryption.py | 2 +- libp2p/crypto/ecc.py | 8 +++-- libp2p/crypto/key_exchange.py | 7 ++-- libp2p/crypto/keys.py | 14 ++++---- libp2p/io/msgio.py | 2 +- libp2p/security/base_session.py | 8 ++--- libp2p/security/secio/exceptions.py | 4 +++ libp2p/security/secio/transport.py | 41 ++++++++++++++++------- 8 files changed, 53 insertions(+), 33 deletions(-) diff --git a/libp2p/crypto/authenticated_encryption.py b/libp2p/crypto/authenticated_encryption.py index f84ecb7..733d4b5 100644 --- a/libp2p/crypto/authenticated_encryption.py +++ b/libp2p/crypto/authenticated_encryption.py @@ -98,7 +98,7 @@ def initialize_pair( authenticator.update(tag) tag = authenticator.digest() - half = len(result) / 2 + half = int(len(result) / 2) first_half = result[:half] second_half = result[half:] diff --git a/libp2p/crypto/ecc.py b/libp2p/crypto/ecc.py index 7cfb433..f4d4c54 100644 --- a/libp2p/crypto/ecc.py +++ b/libp2p/crypto/ecc.py @@ -1,3 +1,5 @@ +from typing import cast + from Crypto.PublicKey import ECC from Crypto.PublicKey.ECC import EccKey @@ -9,7 +11,7 @@ class ECCPublicKey(PublicKey): self.impl = impl def to_bytes(self) -> bytes: - return self.impl.export_key("DER") + return cast(bytes, self.impl.export_key(format="DER")) @classmethod def from_bytes(cls, data: bytes) -> "ECCPublicKey": @@ -33,7 +35,7 @@ class ECCPrivateKey(PrivateKey): return cls(private_key_impl) def to_bytes(self) -> bytes: - return self.impl.export_key("DER") + return cast(bytes, self.impl.export_key(format="DER")) def get_type(self) -> KeyType: return KeyType.ECC_P256 @@ -42,7 +44,7 @@ class ECCPrivateKey(PrivateKey): raise NotImplementedError def get_public_key(self) -> PublicKey: - return ECCPublicKey(self.impl.publickey()) + return ECCPublicKey(self.impl.public_key()) def create_new_key_pair(curve: str) -> KeyPair: diff --git a/libp2p/crypto/key_exchange.py b/libp2p/crypto/key_exchange.py index a204de3..5da467f 100644 --- a/libp2p/crypto/key_exchange.py +++ b/libp2p/crypto/key_exchange.py @@ -1,8 +1,8 @@ -from typing import Callable, Tuple +from typing import Callable, Tuple, cast import Crypto.PublicKey.ECC as ECC -from libp2p.crypto.ecc import create_new_key_pair +from libp2p.crypto.ecc import ECCPrivateKey, create_new_key_pair from libp2p.crypto.keys import PublicKey SharedKeyGenerator = Callable[[bytes], bytes] @@ -20,7 +20,8 @@ def create_ephemeral_key_pair(curve_type: str) -> Tuple[PublicKey, SharedKeyGene def _key_exchange(serialized_remote_public_key: bytes) -> bytes: remote_public_key = ECC.import_key(serialized_remote_public_key) curve_point = remote_public_key.pointQ - secret_point = curve_point * key_pair.private_key.impl.d + private_key = cast(ECCPrivateKey, key_pair.private_key) + secret_point = curve_point * private_key.impl.d byte_size = secret_point.size_in_bytes() return secret_point.x.to_bytes(byte_size, byteorder="big") diff --git a/libp2p/crypto/keys.py b/libp2p/crypto/keys.py index 0647a4b..0cbdfd0 100644 --- a/libp2p/crypto/keys.py +++ b/libp2p/crypto/keys.py @@ -33,8 +33,10 @@ class Key(ABC): """ ... - def __eq__(self, other: "Key") -> bool: - return self.impl == other.impl + def __eq__(self, other: object) -> bool: + if not isinstance(other, Key): + return NotImplemented + return self.to_bytes() == other.to_bytes() class PublicKey(Key): @@ -66,9 +68,7 @@ class PublicKey(Key): @classmethod def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PublicKey: - protobuf_key = protobuf.PublicKey() - protobuf_key.ParseFromString(protobuf_data) - return protobuf_key + return protobuf.PublicKey.FromString(protobuf_data) class PrivateKey(Key): @@ -110,9 +110,7 @@ class PrivateKey(Key): @classmethod def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PrivateKey: - protobuf_key = protobuf.PrivateKey() - protobuf_key.ParseFromString(protobuf_data) - return protobuf_key + return protobuf.PrivateKey.FromString(protobuf_data) @dataclass(frozen=True) diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index 6ad11bc..f745c18 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -2,7 +2,7 @@ import asyncio SIZE_LEN_BYTES = 4 -# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 +# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501 def encode(msg_bytes: bytes) -> bytes: diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index b3c8814..6f62f9a 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -3,7 +3,6 @@ from typing import Optional from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID -from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.secure_conn_interface import ISecureConn @@ -21,12 +20,13 @@ class BaseSession(ISecureConn): def __init__( self, - transport: BaseSecureTransport, + local_peer: ID, + local_private_key: PrivateKey, conn: IRawConnection, peer_id: Optional[ID] = None, ) -> None: - self.local_peer = transport.local_peer - self.local_private_key = transport.local_private_key + self.local_peer = local_peer + self.local_private_key = local_private_key self.remote_peer_id = peer_id self.remote_permanent_pubkey = None diff --git a/libp2p/security/secio/exceptions.py b/libp2p/security/secio/exceptions.py index 1461be9..f9ea8cf 100644 --- a/libp2p/security/secio/exceptions.py +++ b/libp2p/security/secio/exceptions.py @@ -21,3 +21,7 @@ class InvalidSignatureOnExchange(SecioException): class HandshakeFailed(SecioException): pass + + +class IncompatibleChoices(SecioException): + pass diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index f77f353..03119fe 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -25,6 +25,7 @@ from libp2p.security.secure_conn_interface import ISecureConn from .exceptions import ( HandshakeFailed, + IncompatibleChoices, InvalidSignatureOnExchange, PeerMismatchException, SecioException, @@ -43,17 +44,20 @@ DEFAULT_SUPPORTED_CIPHERS = "AES-128" DEFAULT_SUPPORTED_HASHES = "SHA256" -@dataclass class SecureSession(BaseSession): - local_peer: PeerID - local_encryption_parameters: AuthenticatedEncryptionParameters + def __init__( + self, + local_peer: PeerID, + local_private_key: PrivateKey, + local_encryption_parameters: AuthenticatedEncryptionParameters, + remote_peer: PeerID, + remote_encryption_parameters: AuthenticatedEncryptionParameters, + conn: IRawConnection, + ) -> None: + super().__init__(local_peer, local_private_key, conn, remote_peer) - remote_peer: PeerID - remote_encryption_parameters: AuthenticatedEncryptionParameters - - conn: IRawConnection - - def __post_init__(self): + self.local_encryption_parameters = local_encryption_parameters + self.remote_encryption_parameters = remote_encryption_parameters self._initialize_authenticated_encryption_for_local_peer() self._initialize_authenticated_encryption_for_remote_peer() @@ -68,7 +72,8 @@ class SecureSession(BaseSession): async def _read_msg(self) -> bytes: # TODO do we need to serialize reads? - msg = await read_next_message(self.conn) + # TODO do not expose reader + msg = await read_next_message(self.conn.reader) return self.remote_encrypter.decrypt_if_valid(msg) async def write(self, data: bytes) -> None: @@ -135,6 +140,9 @@ class EncryptionParameters: ephemeral_public_key: PublicKey + def __init__(self) -> None: + pass + @dataclass class SessionParameters: @@ -148,6 +156,9 @@ class SessionParameters: order: int shared_key: bytes + def __init__(self) -> None: + pass + async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes: # TODO clean up ``IRawConnection`` so that we don't have to break @@ -182,6 +193,7 @@ def _select_parameter_from_order( for first, second in zip(first_choices, second_choices): if first == second: return first + raise IncompatibleChoices() def _select_encryption_parameters( @@ -302,7 +314,9 @@ async def _establish_session_parameters( def _mk_session_from( - session_parameters: SessionParameters, conn: IRawConnection + local_private_key: PrivateKey, + session_parameters: SessionParameters, + conn: IRawConnection, ) -> SecureSession: key_set1, key_set2 = initialize_pair_for_encryption( session_parameters.local_encryption_parameters.cipher_type, @@ -315,6 +329,7 @@ def _mk_session_from( session = SecureSession( session_parameters.local_peer, + local_private_key, key_set1, session_parameters.remote_peer, key_set2, @@ -329,7 +344,7 @@ async def _finish_handshake(session: ISecureConn, remote_nonce: bytes) -> bytes: async def create_secure_session( - transport: BaseSecureTransport, conn: IRawConnection, remote_peer: PeerID = None + transport: "SecIOTransport", conn: IRawConnection, remote_peer: PeerID = None ) -> ISecureConn: """ Attempt the initial `secio` handshake with the remote peer. @@ -348,7 +363,7 @@ async def create_secure_session( conn.close() raise e - session = _mk_session_from(session_parameters, conn) + session = _mk_session_from(local_private_key, session_parameters, conn) received_nonce = await _finish_handshake(session, remote_nonce) if received_nonce != local_nonce: