"""
Implementation of the `secio` secure channel (protocol id ``/secio/1.0.0``).

Contains the handshake that negotiates session parameters with a remote peer,
the ``SecureSession`` that encrypts and authenticates traffic once the
handshake succeeds, and the ``Transport`` security upgrader that runs the
handshake for inbound and outbound raw connections.
"""
from dataclasses import dataclass
import io
import itertools
from typing import Optional, Tuple

import multihash

from libp2p.crypto.authenticated_encryption import (
    EncryptionParameters as AuthenticatedEncryptionParameters,
)
from libp2p.crypto.authenticated_encryption import (
    initialize_pair as initialize_pair_for_encryption,
)
from libp2p.crypto.authenticated_encryption import InvalidMACException
from libp2p.crypto.authenticated_encryption import MacAndCipher as Encrypter
from libp2p.crypto.ecc import ECCPublicKey
from libp2p.crypto.exceptions import MissingDeserializerError
from libp2p.crypto.key_exchange import create_ephemeral_key_pair
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.serialization import deserialize_public_key
from libp2p.io.exceptions import DecryptionFailedException, IOException
from libp2p.io.msgio import MsgIOReadWriter
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID as PeerID
from libp2p.security.base_session import BaseSession
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.typing import TProtocol

from .exceptions import (
    IncompatibleChoices,
    InconsistentNonce,
    InvalidSignatureOnExchange,
    PeerMismatchException,
    SecioException,
    SedesException,
    SelfEncryption,
)
from .pb.spipe_pb2 import Exchange, Propose

# Protocol identifier under which this security transport is negotiated.
ID = TProtocol("/secio/1.0.0")

# Length of the random nonce each side contributes to the handshake.
NONCE_SIZE = 16  # bytes

# NOTE: the following defaults are only a subset of the parameters allowed by
# the `secio` specification.
# Each default is a comma-separated preference list; entries are split on ","
# (without trimming whitespace) when negotiating with the remote proposal.
DEFAULT_SUPPORTED_EXCHANGES = "P-256"
DEFAULT_SUPPORTED_CIPHERS = "AES-128"
DEFAULT_SUPPORTED_HASHES = "SHA256"


class SecureSession(BaseSession):
    """
    A `secio` session layered over a length-prefixed message connection.

    Outgoing messages are encrypted and tagged with ``local_encrypter``;
    incoming messages are MAC-checked and decrypted with ``remote_encrypter``.
    Decrypted bytes that a caller of ``read`` did not consume are kept in
    ``buf`` between ``low_watermark`` and ``high_watermark`` and handed out on
    subsequent reads before any new message is pulled from ``conn``.
    """

    buf: io.BytesIO
    low_watermark: int  # offset of the first unread byte in ``buf``
    high_watermark: int  # offset one past the last valid byte in ``buf``

    def __init__(
        self,
        local_peer: PeerID,
        local_private_key: PrivateKey,
        local_encryption_parameters: AuthenticatedEncryptionParameters,
        remote_peer: PeerID,
        remote_encryption_parameters: AuthenticatedEncryptionParameters,
        conn: MsgIOReadWriter,
        is_initiator: bool,
    ) -> None:
        super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
        self.conn = conn

        self.local_encryption_parameters = local_encryption_parameters
        self.remote_encryption_parameters = remote_encryption_parameters
        self._initialize_authenticated_encryption_for_local_peer()
        self._initialize_authenticated_encryption_for_remote_peer()

        self._reset_internal_buffer()

    def _initialize_authenticated_encryption_for_local_peer(self) -> None:
        """Build the encrypter used for outgoing traffic."""
        self.local_encrypter = Encrypter(self.local_encryption_parameters)

    def _initialize_authenticated_encryption_for_remote_peer(self) -> None:
        """Build the decrypter/MAC-checker used for incoming traffic."""
        self.remote_encrypter = Encrypter(self.remote_encryption_parameters)

    async def next_msg_len(self) -> int:
        """Return the length of the next frame as reported by ``conn``."""
        return await self.conn.next_msg_len()

    def _reset_internal_buffer(self) -> None:
        """Start over with a fresh, empty plaintext buffer."""
        self.buf = io.BytesIO()
        self.low_watermark = 0
        self.high_watermark = 0

    def _drain(self, n: Optional[int]) -> bytes:
        """
        Take up to ``n`` buffered bytes (all of them when ``n`` is None).

        Returns ``b""`` when nothing is buffered. Once the buffer has been fully
        consumed it is closed and replaced with a fresh one.
        """
        if self.low_watermark == self.high_watermark:
            return bytes()

        data = self.buf.getbuffer()[self.low_watermark : self.high_watermark]

        if n is None:
            n = len(data)
        result = data[:n].tobytes()
        self.low_watermark += len(result)

        if self.low_watermark == self.high_watermark:
            # Release the memoryview first: closing a BytesIO with a live
            # exported buffer would fail.
            del data
            self.buf.close()
            self._reset_internal_buffer()
        return result

    async def _fill(self) -> None:
        """Read and decrypt one message into the (empty) internal buffer."""
        msg = await self.read_msg()
        self.buf.write(msg)
        self.low_watermark = 0
        self.high_watermark = len(msg)

    async def read(self, n: Optional[int] = None) -> bytes:
        """
        Return decrypted application data.

        :param n: maximum number of bytes to return. ``None`` (or any negative
            value, following the usual ``read(-1)`` convention) means "the rest
            of the buffered data, or else the whole next message".
        :return: previously buffered bytes if any remain, otherwise data from
            the next message. ``b""`` when ``n == 0``.
        :raises DecryptionFailedException: if the next message fails its MAC check.
        """
        if n is not None and n < 0:
            # Previously a negative n sliced off the trailing byte(s) in _drain.
            n = None

        if n == 0:
            return bytes()

        data_from_buffer = self._drain(n)
        if len(data_from_buffer) > 0:
            return data_from_buffer

        if n is None:
            # No size limit: hand back the next message as-is. The original
            # code fell through to ``None < next_length`` and raised TypeError.
            return await self.read_msg()

        # NOTE(review): this is the framed (ciphertext + tag) length; since
        # ``write_msg`` appends a tag to the ciphertext it presumably bounds the
        # plaintext size from above, so the comparison is conservative — confirm.
        next_length = await self.next_msg_len()

        if n < next_length:
            await self._fill()
            return self._drain(n)
        return await self.read_msg()

    async def read_msg(self) -> bytes:
        """
        Read one frame from ``conn``, verify it and return the plaintext.

        :raises DecryptionFailedException: if the MAC check fails.
        """
        msg = await self.conn.read_msg()
        try:
            decrypted_msg = self.remote_encrypter.decrypt_if_valid(msg)
        except InvalidMACException as e:
            raise DecryptionFailedException() from e
        return decrypted_msg

    async def write(self, data: bytes) -> int:
        """Send ``data`` as a single encrypted message and return ``len(data)``."""
        await self.write_msg(data)
        return len(data)

    async def write_msg(self, msg: bytes) -> None:
        """Encrypt ``msg``, append its authentication tag and send one frame."""
        encrypted_data = self.local_encrypter.encrypt(msg)
        tag = self.local_encrypter.authenticate(encrypted_data)
        await self.conn.write_msg(encrypted_data + tag)


@dataclass(frozen=True)
class Proposal:
    """
    A ``Proposal`` represents the set of session parameters one peer in a pair
    of peers attempting to negotiate a `secio` channel prefers.
    """

    nonce: bytes
    public_key: PublicKey
    exchanges: str = DEFAULT_SUPPORTED_EXCHANGES  # comma separated list
    ciphers: str = DEFAULT_SUPPORTED_CIPHERS  # comma separated list
    hashes: str = DEFAULT_SUPPORTED_HASHES  # comma separated list

    def serialize(self) -> bytes:
        """Encode this proposal as a ``Propose`` protobuf message."""
        protobuf = Propose(
            rand=self.nonce,
            public_key=self.public_key.serialize(),
            exchanges=self.exchanges,
            ciphers=self.ciphers,
            hashes=self.hashes,
        )
        return protobuf.SerializeToString()

    @classmethod
    def deserialize(cls, protobuf_bytes: bytes) -> "Proposal":
        """
        Decode a ``Propose`` protobuf message into a ``Proposal``.

        :raises SedesException: if the embedded public key type has no
            registered deserializer.
        """
        protobuf = Propose.FromString(protobuf_bytes)

        nonce = protobuf.rand
        public_key_protobuf_bytes = protobuf.public_key

        try:
            public_key = deserialize_public_key(public_key_protobuf_bytes)
        except MissingDeserializerError as error:
            raise SedesException() from error
        exchanges = protobuf.exchanges
        ciphers = protobuf.ciphers
        hashes = protobuf.hashes

        return cls(nonce, public_key, exchanges, ciphers, hashes)

    def calculate_peer_id(self) -> PeerID:
        """Derive the peer ID implied by the proposal's public key."""
        return PeerID.from_pubkey(self.public_key)


@dataclass
class EncryptionParameters:
    """One side's negotiated keys and algorithm choices (filled in stepwise)."""

    permanent_public_key: PublicKey

    curve_type: str
    cipher_type: str
    hash_type: str

    ephemeral_public_key: PublicKey

    def __init__(self) -> None:
        # Explicit no-arg constructor: the handshake assigns fields one by one.
        # Because the class defines __init__, @dataclass does not generate one.
        pass


@dataclass
class SessionParameters:
    """Everything the handshake negotiated, used to build a ``SecureSession``."""

    local_peer: PeerID
    local_encryption_parameters: EncryptionParameters

    remote_peer: PeerID
    remote_encryption_parameters: EncryptionParameters

    # order is a comparator used to break the symmetry b/t each pair of peers
    order: int
    shared_key: bytes

    def __init__(self) -> None:
        # See EncryptionParameters.__init__: fields are assigned incrementally.
        pass


async def _response_to_msg(read_writer: MsgIOReadWriter, msg: bytes) -> bytes:
    """Send ``msg`` and return the next message received."""
    await read_writer.write_msg(msg)
    return await read_writer.read_msg()


def _mk_multihash_sha256(data: bytes) -> bytes:
    """Return the sha2-256 multihash of ``data``."""
    return multihash.digest(data, "sha2-256")


def _mk_score(public_key: PublicKey, nonce: bytes) -> bytes:
    """Hash a serialized public key with a nonce; used to derive ``order``."""
    return _mk_multihash_sha256(public_key.serialize() + nonce)


def _select_parameter_from_order(
    order: int, supported_parameters: str, available_parameters: str
) -> str:
    """
    Pick the first mutually supported entry from two comma-separated lists.

    When ``order < 0`` the remote list (``available_parameters``) sets the
    preference order; when ``order > 0`` the local list
    (``supported_parameters``) does. When ``order == 0`` the first local choice
    is returned without checking the remote list.

    :raises IncompatibleChoices: if the lists share no entry.
    """
    if order < 0:
        first_choices = available_parameters.split(",")
        second_choices = supported_parameters.split(",")
    elif order > 0:
        first_choices = supported_parameters.split(",")
        second_choices = available_parameters.split(",")
    else:
        return supported_parameters.split(",")[0]

    for first, second in itertools.product(first_choices, second_choices):
        if first == second:
            return first
    raise IncompatibleChoices()


def _select_encryption_parameters(
    local_proposal: Proposal, remote_proposal: Proposal
) -> Tuple[str, str, str, int]:
    """
    Compute ``order`` and the negotiated (exchange, cipher, hash) choices.

    Both peers compute the same two scores with roles swapped, so their
    ``order`` values have opposite signs.

    :raises SelfEncryption: if the scores are equal (e.g. dialing ourselves).
    :raises IncompatibleChoices: if any parameter list has no common entry.
    """
    first_score = _mk_score(remote_proposal.public_key, local_proposal.nonce)
    second_score = _mk_score(local_proposal.public_key, remote_proposal.nonce)

    order = 0
    if first_score < second_score:
        order = -1
    elif second_score < first_score:
        order = 1

    if order == 0:
        raise SelfEncryption()

    return (
        _select_parameter_from_order(
            order, DEFAULT_SUPPORTED_EXCHANGES, remote_proposal.exchanges
        ),
        _select_parameter_from_order(
            order, DEFAULT_SUPPORTED_CIPHERS, remote_proposal.ciphers
        ),
        _select_parameter_from_order(
            order, DEFAULT_SUPPORTED_HASHES, remote_proposal.hashes
        ),
        order,
    )


async def _establish_session_parameters(
    local_peer: PeerID,
    local_private_key: PrivateKey,
    remote_peer: Optional[PeerID],
    conn: MsgIOReadWriter,
    nonce: bytes,
) -> Tuple[SessionParameters, bytes]:
    """
    Run the proposal and key-exchange rounds of the `secio` handshake.

    :param remote_peer: expected remote peer, or None to accept whichever peer
        the remote public key identifies.
    :return: the negotiated session parameters and the remote peer's nonce.
    :raises PeerMismatchException: if the remote key does not match
        ``remote_peer``.
    :raises SedesException: if the remote public key cannot be deserialized.
    :raises SelfEncryption, IncompatibleChoices: from parameter selection.
    :raises InvalidSignatureOnExchange: if the remote exchange signature fails.
    """
    # establish shared encryption parameters
    session_parameters = SessionParameters()
    session_parameters.local_peer = local_peer

    local_encryption_parameters = EncryptionParameters()
    session_parameters.local_encryption_parameters = local_encryption_parameters

    local_public_key = local_private_key.get_public_key()
    local_encryption_parameters.permanent_public_key = local_public_key

    local_proposal = Proposal(nonce, local_public_key)
    serialized_local_proposal = local_proposal.serialize()
    serialized_remote_proposal = await _response_to_msg(conn, serialized_local_proposal)

    remote_encryption_parameters = EncryptionParameters()
    session_parameters.remote_encryption_parameters = remote_encryption_parameters
    remote_proposal = Proposal.deserialize(serialized_remote_proposal)
    remote_encryption_parameters.permanent_public_key = remote_proposal.public_key

    remote_peer_from_proposal = remote_proposal.calculate_peer_id()
    if remote_peer is None:
        remote_peer = remote_peer_from_proposal
    elif remote_peer != remote_peer_from_proposal:
        raise PeerMismatchException(
            {
                "expected_remote_peer": remote_peer,
                "received_remote_peer": remote_peer_from_proposal,
            }
        )
    session_parameters.remote_peer = remote_peer

    curve_param, cipher_param, hash_param, order = _select_encryption_parameters(
        local_proposal, remote_proposal
    )
    local_encryption_parameters.curve_type = curve_param
    local_encryption_parameters.cipher_type = cipher_param
    local_encryption_parameters.hash_type = hash_param
    remote_encryption_parameters.curve_type = curve_param
    remote_encryption_parameters.cipher_type = cipher_param
    remote_encryption_parameters.hash_type = hash_param
    session_parameters.order = order

    # exchange ephemeral pub keys
    local_ephemeral_public_key, shared_key_generator = create_ephemeral_key_pair(
        curve_param
    )
    local_encryption_parameters.ephemeral_public_key = local_ephemeral_public_key
    # Sign both proposals plus our ephemeral key so the remote can check that
    # the exchange belongs to this handshake.
    local_selection = (
        serialized_local_proposal
        + serialized_remote_proposal
        + local_ephemeral_public_key.to_bytes()
    )
    exchange_signature = local_private_key.sign(local_selection)
    local_exchange = Exchange(
        ephemeral_public_key=local_ephemeral_public_key.to_bytes(),
        signature=exchange_signature,
    )

    serialized_local_exchange = local_exchange.SerializeToString()
    serialized_remote_exchange = await _response_to_msg(conn, serialized_local_exchange)

    remote_exchange = Exchange()
    remote_exchange.ParseFromString(serialized_remote_exchange)

    remote_ephemeral_public_key_bytes = remote_exchange.ephemeral_public_key
    remote_ephemeral_public_key = ECCPublicKey.from_bytes(
        remote_ephemeral_public_key_bytes, curve_param
    )
    remote_encryption_parameters.ephemeral_public_key = remote_ephemeral_public_key
    # The remote signed (its proposal, our proposal, its ephemeral key), i.e.
    # the mirror image of ``local_selection``.
    remote_selection = (
        serialized_remote_proposal
        + serialized_local_proposal
        + remote_ephemeral_public_key_bytes
    )
    valid_signature = remote_encryption_parameters.permanent_public_key.verify(
        remote_selection, remote_exchange.signature
    )
    if not valid_signature:
        raise InvalidSignatureOnExchange()

    shared_key = shared_key_generator(remote_ephemeral_public_key_bytes)
    session_parameters.shared_key = shared_key

    return session_parameters, remote_proposal.nonce


def _mk_session_from(
    local_private_key: PrivateKey,
    session_parameters: SessionParameters,
    conn: MsgIOReadWriter,
    is_initiator: bool,
) -> SecureSession:
    """
    Stretch the shared secret into two key sets and build a ``SecureSession``.

    The peer with ``order < 0`` swaps the pair. Because the two peers' ``order``
    values have opposite signs, one peer's local keys are the other's remote
    keys.
    """
    key_set1, key_set2 = initialize_pair_for_encryption(
        session_parameters.local_encryption_parameters.cipher_type,
        session_parameters.local_encryption_parameters.hash_type,
        session_parameters.shared_key,
    )

    if session_parameters.order < 0:
        key_set1, key_set2 = key_set2, key_set1

    session = SecureSession(
        session_parameters.local_peer,
        local_private_key,
        key_set1,
        session_parameters.remote_peer,
        key_set2,
        conn,
        is_initiator,
    )
    return session


async def _finish_handshake(session: SecureSession, remote_nonce: bytes) -> bytes:
    """
    Echo the remote nonce over the encrypted session and return the reply.

    The caller compares the reply with its own nonce to confirm both sides
    derived matching keys.
    """
    await session.write_msg(remote_nonce)
    return await session.read_msg()
async def create_secure_session(
    local_nonce: bytes,
    local_peer: PeerID,
    local_private_key: PrivateKey,
    conn: IRawConnection,
    remote_peer: Optional[PeerID] = None,
) -> ISecureConn:
    """
    Attempt the initial `secio` handshake with the remote peer.

    If successful, return an object that provides secure communication to the
    ``remote_peer``.

    :param remote_peer: the peer we expect to talk to; passing one marks us as
        the initiator. When None, the remote peer is taken from its proposal.
    :raises SecioException: when the handshake fails, including when ``conn``
        raises an ``IOException`` (reported as "connection closed").
    :raises InconsistentNonce: when the nonce echoed back by the remote does not
        match ``local_nonce``.
    """
    msg_io = MsgIOReadWriter(conn)
    try:
        session_parameters, remote_nonce = await _establish_session_parameters(
            local_peer, local_private_key, remote_peer, msg_io, local_nonce
        )
    except SecioException:
        await conn.close()
        raise
    # `IOException` includes errors raised while read from/write to raw connection
    except IOException as e:
        raise SecioException("connection closed") from e

    is_initiator = remote_peer is not None
    session = _mk_session_from(
        local_private_key, session_parameters, msg_io, is_initiator
    )

    try:
        received_nonce = await _finish_handshake(session, remote_nonce)
    # `IOException` includes errors raised while read from/write to raw connection
    except IOException as e:
        raise SecioException("connection closed") from e

    if received_nonce != local_nonce:
        await conn.close()
        raise InconsistentNonce()

    return session


class Transport(BaseSecureTransport):
    """
    ``Transport`` provides a security upgrader for a ``IRawConnection``,
    following the `secio` protocol defined in the libp2p specs.
    """

    def get_nonce(self) -> bytes:
        """Return ``NONCE_SIZE`` bytes from ``self.secure_bytes_provider``."""
        return self.secure_bytes_provider(NONCE_SIZE)

    async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
        """
        Secure the connection, either locally or by communicating with opposing
        node via conn, for an inbound connection (i.e. we are not the initiator)

        :return: secure connection object (that implements secure_conn_interface)
        """
        local_nonce = self.get_nonce()
        local_peer = self.local_peer
        local_private_key = self.local_private_key

        return await create_secure_session(
            local_nonce, local_peer, local_private_key, conn
        )

    async def secure_outbound(
        self, conn: IRawConnection, peer_id: PeerID
    ) -> ISecureConn:
        """
        Secure the connection, either locally or by communicating with opposing
        node via conn, for an outbound connection (i.e. we are the initiator)

        :param peer_id: the peer we expect on the other end; the handshake fails
            with ``SecioException`` if the remote key identifies another peer.
        :return: secure connection object (that implements secure_conn_interface)
        """
        local_nonce = self.get_nonce()
        local_peer = self.local_peer
        local_private_key = self.local_private_key

        return await create_secure_session(
            local_nonce, local_peer, local_private_key, conn, peer_id
        )