diff --git a/libp2p/security/secio/exceptions.py b/libp2p/security/secio/exceptions.py new file mode 100644 index 0000000..a5f7464 --- /dev/null +++ b/libp2p/security/secio/exceptions.py @@ -0,0 +1,14 @@ +class SecioException(Exception): + pass + + +class PeerMismatchException(SecioException): + pass + + +class InvalidSignatureOnExchange(SecioException): + pass + + +class HandshakeFailed(SecioException): + pass diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index c0c7745..8ea2798 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -1,6 +1,18 @@ from dataclasses import dataclass +import hashlib from typing import Optional, Tuple +import multihash + +from libp2p.crypto.authenticated_encryption import ( + EncryptionParameters as AuthenticatedEncryptionParameters, +) +from libp2p.crypto.authenticated_encryption import ( + initialize_pair as initialize_pair_for_encryption, +) +from libp2p.crypto.authenticated_encryption import MacAndCipher as Encrypter +from libp2p.crypto.ecc import ECCPublicKey +from libp2p.crypto.key_exchange import create_ephemeral_key_pair from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.io.msgio import encode as encode_message from libp2p.io.msgio import read_next_message @@ -10,6 +22,12 @@ from libp2p.security.base_session import BaseSession from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.secure_conn_interface import ISecureConn +from .exceptions import ( + HandshakeFailed, + InvalidSignatureOnExchange, + PeerMismatchException, + SecioException, +) from .pb.spipe_pb2 import Exchange, Propose ID = "/secio/1.0.0" @@ -23,11 +41,45 @@ DEFAULT_SUPPORTED_CIPHERS = "AES-128" DEFAULT_SUPPORTED_HASHES = "SHA256" +@dataclass class SecureSession(BaseSession): local_peer: PeerID + local_encryption_parameters: AuthenticatedEncryptionParameters + remote_peer: PeerID - # specialize read and write - pass + remote_encryption_parameters: AuthenticatedEncryptionParameters + + conn: IRawConnection + + def __post_init__(self): + self._initialize_authenticated_encryption_for_local_peer() + self._initialize_authenticated_encryption_for_remote_peer() + + def _initialize_authenticated_encryption_for_local_peer(self) -> None: + self.local_encrypter = Encrypter(self.local_encryption_parameters) + + def _initialize_authenticated_encryption_for_remote_peer(self) -> None: + self.remote_encrypter = Encrypter(self.remote_encryption_parameters) + + async def read(self) -> bytes: + return await self._read_msg() + + async def _read_msg(self) -> bytes: + # TODO do we need to serialize reads? + msg = await read_next_message(self.conn) + return self.remote_encrypter.decrypt_if_valid(msg) + + async def write(self, data: bytes) -> None: + await self._write_msg(data) + + async def _write_msg(self, data: bytes) -> None: + # TODO do we need to serialize writes? + encrypted_data = self.local_encrypter.encrypt(data) + tag = self.local_encrypter.authenticate(encrypted_data) + msg = encode_message(encrypted_data + tag) + # TODO clean up how we write messages + self.conn.writer.write(msg) + await self.conn.writer.drain() @dataclass(frozen=True) @@ -81,9 +133,19 @@ class EncryptionParameters: hash_type: str ephemeral_public_key: PublicKey - keys: ... - cipher: ... - mac: ... + + +@dataclass +class SessionParameters: + local_peer: PeerID + local_encryption_parameters: EncryptionParameters + + remote_peer: PeerID + remote_encryption_parameters: EncryptionParameters + + # order is a comparator used to break the symmetry b/t each pair of peers + order: int + shared_key: bytes async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes: @@ -95,16 +157,9 @@ async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes: return await read_next_message(conn.reader) -@dataclass -class SessionParameters: - local_peer: PeerID - local_encryption_parameters: EncryptionParameters - remote_peer: PeerID - remote_encryption_parameters: EncryptionParameters - - def _mk_multihash_sha256(data: bytes) -> bytes: - pass + digest = hashlib.sha256(data).digest() + return multihash.encode(digest, "sha2-256") def _mk_score(public_key: PublicKey, nonce: bytes) -> bytes: @@ -130,7 +185,7 @@ def _select_parameter_from_order( def _select_encryption_parameters( local_proposal: Proposal, remote_proposal: Proposal -) -> Tuple[str, str, str]: +) -> Tuple[str, str, str, int]: first_score = _mk_score(remote_proposal.public_key, local_proposal.nonce) second_score = _mk_score(local_proposal.public_key, remote_proposal.nonce) @@ -148,12 +203,13 @@ def _select_encryption_parameters( _select_parameter_from_order( order, DEFAULT_SUPPORTED_EXCHANGES, remote_proposal.exchanges ), - _select_encryption_parameters( + _select_parameter_from_order( order, DEFAULT_SUPPORTED_CIPHERS, remote_proposal.ciphers ), - _select_encryption_parameters( + _select_parameter_from_order( order, DEFAULT_SUPPORTED_HASHES, remote_proposal.hashes ), + order, ) @@ -163,7 +219,8 @@ async def _establish_session_parameters( remote_peer: Optional[PeerID], conn: IRawConnection, nonce: bytes, -) -> SessionParameters: +) -> Tuple[SessionParameters, bytes]: + # establish shared encryption parameters session_parameters = SessionParameters() session_parameters.local_peer = local_peer @@ -189,7 +246,7 @@ async def _establish_session_parameters( raise PeerMismatchException() session_parameters.remote_peer = remote_peer - curve_param, cipher_param, hash_param = _select_encryption_parameters( + curve_param, cipher_param, hash_param, order = _select_encryption_parameters( local_proposal, remote_proposal ) local_encryption_parameters.curve_type = curve_param @@ -198,45 +255,77 @@ async def _establish_session_parameters( remote_encryption_parameters.curve_type = curve_param remote_encryption_parameters.cipher_type = cipher_param remote_encryption_parameters.hash_type = hash_param + session_parameters.order = order # exchange ephemeral pub keys - local_ephemeral_key_pair, shared_key_generator = create_elliptic_key_pair( - encryption_parameters + local_ephemeral_public_key, shared_key_generator = create_ephemeral_key_pair( + curve_param ) - local_selection = _mk_serialized_selection( - local_proposal, remote_proposal, local_ephemeral_key_pair.public_key + local_encryption_parameters.ephemeral_public_key = local_ephemeral_public_key + local_selection = ( + serialized_local_proposal + + serialized_remote_proposal + + local_ephemeral_public_key.to_bytes() ) - serialized_local_selection = _mk_serialized_selection(local_selection) - - local_exchange = _mk_exchange( - local_ephemeral_key_pair.public_key, serialized_local_selection + exchange_signature = local_private_key.sign(local_selection) + local_exchange = Exchange( + ephemeral_public_key=local_ephemeral_public_key.to_bytes(), + signature=exchange_signature, ) - serialized_local_exchange = _mk_serialized_exchange_msg(local_exchange) - serialized_remote_exchange = await _response_to_msg(serialized_local_exchange) - remote_exchange = _parse_exchange(serialized_remote_exchange) + serialized_local_exchange = local_exchange.SerializeToString() + serialized_remote_exchange = await _response_to_msg(conn, serialized_local_exchange) - remote_selection = _mk_remote_selection( - remote_exchange, local_proposal, remote_proposal + remote_exchange = Exchange() + remote_exchange.ParseFromString(serialized_remote_exchange) + + remote_ephemeral_public_key_bytes = remote_exchange.ephemeral_public_key + remote_ephemeral_public_key = ECCPublicKey.from_bytes( + remote_ephemeral_public_key_bytes ) - verify_exchange(remote_exchange, remote_selection, remote_proposal) + remote_encryption_parameters.ephemeral_public_key = remote_ephemeral_public_key + remote_selection = ( + serialized_remote_proposal + + serialized_local_proposal + + remote_ephemeral_public_key_bytes + ) + valid_signature = remote_encryption_parameters.permanent_public_key.verify( + remote_selection, remote_exchange.signature + ) + if not valid_signature: + raise InvalidSignatureOnExchange() - # return all the data we need + shared_key = shared_key_generator(remote_ephemeral_public_key_bytes) + session_parameters.shared_key = shared_key + + return session_parameters, remote_proposal.nonce -def _mk_session_from(session_parameters): - # use ephemeral pubkey to make a shared key - # stretch shared key to get two keys - # decide which side has which key - # set up mac and cipher, based on shared key, for each side - # make new rdr/wtr pairs using each mac/cipher gadget - pass +def _mk_session_from( + session_parameters: SessionParameters, conn: IRawConnection +) -> SecureSession: + key_set1, key_set2 = initialize_pair_for_encryption( + session_parameters.local_encryption_parameters.cipher_type, + session_parameters.local_encryption_parameters.hash_type, + session_parameters.shared_key, + ) + + if session_parameters.order < 0: + key_set1, key_set2 = key_set2, key_set1 + + session = SecureSession( + session_parameters.local_peer, + key_set1, + session_parameters.remote_peer, + key_set2, + conn, + ) + return session -async def _close_handshake(session): - # send nonce over encrypted channel - # verify we get our nonce back - pass +async def _finish_handshake(session: ISecureConn, remote_nonce: bytes) -> bytes: + await session.write(remote_nonce) + return await session.read() async def create_secure_session( @@ -247,21 +336,24 @@ async def create_secure_session( If successful, return an object that provides secure communication to the ``remote_peer``. """ - nonce = transport.get_nonce() + local_nonce = transport.get_nonce() local_peer = transport.local_peer local_private_key = transport.local_private_key try: - session_parameters = await _establish_session_parameters( - local_peer, local_private_key, remote_peer, conn, nonce + session_parameters, remote_nonce = await _establish_session_parameters( + local_peer, local_private_key, remote_peer, conn, local_nonce ) - except PeerMismatchException as e: + except SecioException as e: conn.close() raise e - session = _mk_session_from(session_parameters) + session = _mk_session_from(session_parameters, conn) - await _close_handshake(session) + received_nonce = await _finish_handshake(session, remote_nonce) + if received_nonce != local_nonce: + conn.close() + raise HandshakeFailed() return session