From 3c2e835725af48ad497baeb40d0cdf28b4ebb7b1 Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 17 Feb 2020 23:33:45 +0800 Subject: [PATCH] Security: `SecureSession` Make security sessions(secio, noise) share the same implementation `BaseSession` to avoid duplicate implementation of buffered read. --- libp2p/io/abc.py | 11 +- libp2p/io/msgio.py | 7 +- libp2p/security/insecure/io.py | 19 +++ libp2p/security/insecure/transport.py | 116 +++++++++--------- libp2p/security/noise/io.py | 17 ++- libp2p/security/noise/patterns.py | 30 ++--- libp2p/security/secio/transport.py | 104 ++++------------ .../connection.py => secure_session.py} | 42 +++---- 8 files changed, 150 insertions(+), 196 deletions(-) create mode 100644 libp2p/security/insecure/io.py rename libp2p/security/{noise/connection.py => secure_session.py} (63%) diff --git a/libp2p/io/abc.py b/libp2p/io/abc.py index f2e5020..b7be31f 100644 --- a/libp2p/io/abc.py +++ b/libp2p/io/abc.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod class Closer(ABC): + @abstractmethod async def close(self) -> None: ... @@ -39,10 +40,6 @@ class MsgReader(ABC): async def read_msg(self) -> bytes: ... - # @abstractmethod - # async def next_msg_len(self) -> int: - # ... - class MsgWriter(ABC): @abstractmethod @@ -50,7 +47,7 @@ class MsgWriter(ABC): ... -class MsgReadWriter(MsgReader, MsgWriter): +class MsgReadWriteCloser(MsgReader, MsgWriter, Closer): pass @@ -64,5 +61,5 @@ class Encrypter(ABC): ... -class EncryptedMsgReadWriter(MsgReadWriter, Encrypter): - pass +class EncryptedMsgReadWriter(MsgReadWriteCloser, Encrypter): + """Read/write message with encryption/decryption.""" diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index a9610a4..38710ff 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -8,10 +8,9 @@ NOTE: currently missing the capability to indicate lengths by "varint" method. # TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501 from typing import Optional -from libp2p.io.abc import MsgReadWriter, Reader, ReadWriteCloser +from libp2p.io.abc import MsgReadWriteCloser, Reader, ReadWriteCloser from libp2p.io.utils import read_exactly - BYTE_ORDER = "big" @@ -26,12 +25,12 @@ def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes: except OverflowError: raise ValueError( "msg_bytes is too large for `size_len_bytes` bytes length: " - f"msg_bytes={msg_bytes}, size_len_bytes={size_len_bytes}" + f"msg_bytes={msg_bytes!r}, size_len_bytes={size_len_bytes}" ) return len_prefix + msg_bytes -class BaseMsgReadWriter(MsgReadWriter): +class BaseMsgReadWriter(MsgReadWriteCloser): next_length: Optional[int] read_write_closer: ReadWriteCloser size_len_bytes: int diff --git a/libp2p/security/insecure/io.py b/libp2p/security/insecure/io.py new file mode 100644 index 0000000..1cbff36 --- /dev/null +++ b/libp2p/security/insecure/io.py @@ -0,0 +1,19 @@ +from libp2p.io.abc import MsgReadWriteCloser, ReadWriteCloser +from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed + + +class PlaintextHandshakeReadWriter(MsgReadWriteCloser): + conn: ReadWriteCloser + + def __init__(self, conn: ReadWriteCloser) -> None: + self.conn = conn + + async def read_msg(self) -> bytes: + return await read_fixedint_prefixed(self.conn) + + async def write_msg(self, msg: bytes) -> None: + encoded_msg_bytes = encode_fixedint_prefixed(msg) + await self.conn.write(encoded_msg_bytes) + + async def close(self) -> None: + await self.conn.close() diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index f452e53..28e7641 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -13,8 +13,8 @@ from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.exceptions import HandshakeFailure from libp2p.security.secure_conn_interface import ISecureConn from libp2p.typing import TProtocol -from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed +from .io import PlaintextHandshakeReadWriter from .pb import plaintext_pb2 # Reference: https://github.com/libp2p/go-libp2p-core/blob/master/sec/insecure/insecure.go @@ -44,60 +44,66 @@ class InsecureSession(BaseSession): async def close(self) -> None: await self.conn.close() - async def run_handshake(self) -> None: - """Raise `HandshakeFailure` when handshake failed.""" - msg = make_exchange_message(self.local_private_key.get_public_key()) - msg_bytes = msg.SerializeToString() - encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) - try: - await self.write(encoded_msg_bytes) - except RawConnError as e: - raise HandshakeFailure("connection closed") from e - try: - remote_msg_bytes = await read_fixedint_prefixed(self.conn) - except RawConnError as e: - raise HandshakeFailure("connection closed") from e - remote_msg = plaintext_pb2.Exchange() - remote_msg.ParseFromString(remote_msg_bytes) - received_peer_id = ID(remote_msg.id) +async def run_handshake( + local_peer: ID, + local_private_key: PrivateKey, + conn: IRawConnection, + is_initiator: bool, + remote_peer_id: ID, +) -> ISecureConn: + """Raise `HandshakeFailure` when handshake failed.""" + msg = make_exchange_message(local_private_key.get_public_key()) + msg_bytes = msg.SerializeToString() + read_writer = PlaintextHandshakeReadWriter(conn) + try: + await read_writer.write_msg(msg_bytes) + except RawConnError as e: + raise HandshakeFailure("connection closed") from e - # Verify if the receive `ID` matches the one we originally initialize the session. - # We only need to check it when we are the initiator, because only in that condition - # we possibly knows the `ID` of the remote. - if self.is_initiator and self.remote_peer_id != received_peer_id: - raise HandshakeFailure( - "remote peer sent unexpected peer ID. " - f"expected={self.remote_peer_id} received={received_peer_id}" - ) + try: + remote_msg_bytes = await read_writer.read_msg() + except RawConnError as e: + raise HandshakeFailure("connection closed") from e + remote_msg = plaintext_pb2.Exchange() + remote_msg.ParseFromString(remote_msg_bytes) + received_peer_id = ID(remote_msg.id) - # Verify if the given `pubkey` matches the given `peer_id` - try: - received_pubkey = deserialize_public_key( - remote_msg.pubkey.SerializeToString() - ) - except ValueError as e: - raise HandshakeFailure( - f"unknown `key_type` of remote_msg.pubkey={remote_msg.pubkey}" - ) from e - except MissingDeserializerError as error: - raise HandshakeFailure() from error - peer_id_from_received_pubkey = ID.from_pubkey(received_pubkey) - if peer_id_from_received_pubkey != received_peer_id: - raise HandshakeFailure( - "peer id and pubkey from the remote mismatch: " - f"received_peer_id={received_peer_id}, remote_pubkey={received_pubkey}, " - f"peer_id_from_received_pubkey={peer_id_from_received_pubkey}" - ) + # Verify if the receive `ID` matches the one we originally initialize the session. + # We only need to check it when we are the initiator, because only in that condition + # we possibly knows the `ID` of the remote. + if is_initiator and remote_peer_id != received_peer_id: + raise HandshakeFailure( + "remote peer sent unexpected peer ID. " + f"expected={remote_peer_id} received={received_peer_id}" + ) - # Nothing is wrong. Store the `pubkey` and `peer_id` in the session. - self.remote_permanent_pubkey = received_pubkey - # Only need to set peer's id when we don't know it before, - # i.e. we are not the connection initiator. - if not self.is_initiator: - self.remote_peer_id = received_peer_id + # Verify if the given `pubkey` matches the given `peer_id` + try: + received_pubkey = deserialize_public_key(remote_msg.pubkey.SerializeToString()) + except ValueError as e: + raise HandshakeFailure( + f"unknown `key_type` of remote_msg.pubkey={remote_msg.pubkey}" + ) from e + except MissingDeserializerError as error: + raise HandshakeFailure() from error + peer_id_from_received_pubkey = ID.from_pubkey(received_pubkey) + if peer_id_from_received_pubkey != received_peer_id: + raise HandshakeFailure( + "peer id and pubkey from the remote mismatch: " + f"received_peer_id={received_peer_id}, remote_pubkey={received_pubkey}, " + f"peer_id_from_received_pubkey={peer_id_from_received_pubkey}" + ) - # TODO: Store `pubkey` and `peer_id` to `PeerStore` + secure_conn = InsecureSession( + local_peer, local_private_key, conn, is_initiator, received_peer_id + ) + + # Nothing is wrong. Store the `pubkey` and `peer_id` in the session. + secure_conn.remote_permanent_pubkey = received_pubkey + # TODO: Store `pubkey` and `peer_id` to `PeerStore` + + return secure_conn class InsecureTransport(BaseSecureTransport): @@ -113,9 +119,9 @@ class InsecureTransport(BaseSecureTransport): :return: secure connection object (that implements secure_conn_interface) """ - session = InsecureSession(self.local_peer, self.local_private_key, conn, False) - await session.run_handshake() - return session + return await run_handshake( + self.local_peer, self.local_private_key, conn, False, None + ) async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn: """ @@ -124,11 +130,9 @@ class InsecureTransport(BaseSecureTransport): :return: secure connection object (that implements secure_conn_interface) """ - session = InsecureSession( + return await run_handshake( self.local_peer, self.local_private_key, conn, True, peer_id ) - await session.run_handshake() - return session def make_exchange_message(pubkey: PublicKey) -> plaintext_pb2.Exchange: diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index 4b01f9d..5ffeafe 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -1,14 +1,11 @@ -from abc import ABC, abstractmethod from typing import cast from noise.connection import NoiseConnection as NoiseState -from libp2p.io.abc import ReadWriteCloser, MsgReadWriter, EncryptedMsgReadWriter +from libp2p.io.abc import EncryptedMsgReadWriter, MsgReadWriteCloser, ReadWriteCloser from libp2p.io.msgio import BaseMsgReadWriter, encode_msg_with_length -from libp2p.io.utils import read_exactly from libp2p.network.connection.raw_connection_interface import IRawConnection - SIZE_NOISE_MESSAGE_LEN = 2 MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 SIZE_NOISE_MESSAGE_BODY_LEN = 2 @@ -50,7 +47,14 @@ def decode_msg_body(noise_msg: bytes) -> bytes: class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): - read_writer: MsgReadWriter + """ + The base implementation of noise message reader/writer. + + `encrypt` and `decrypt` are not implemented here, which should be + implemented by the subclasses. + """ + + read_writer: MsgReadWriteCloser noise_state: NoiseState def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: @@ -67,6 +71,9 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): noise_msg = self.decrypt(noise_msg_encrypted) return decode_msg_body(noise_msg) + async def close(self) -> None: + await self.read_writer.close() + class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter): def encrypt(self, data: bytes) -> bytes: diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index 213b157..64d4965 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -8,15 +8,15 @@ from libp2p.crypto.keys import PrivateKey from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn +from libp2p.security.secure_session import SecureSession -from .connection import NoiseConnection from .exceptions import ( HandshakeHasNotFinished, InvalidSignature, NoiseStateError, PeerIDMismatchesPubkey, ) -from .io import encode_msg_body, decode_msg_body, NoiseHandshakeReadWriter +from .io import NoiseHandshakeReadWriter, NoiseTransportReadWriter from .messages import ( NoiseHandshakePayload, make_handshake_payload_sig, @@ -56,16 +56,6 @@ class BasePattern(IPattern): ) return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature) - async def write_msg(self, conn: IRawConnection, data: bytes) -> None: - noise_msg = encode_msg_body(data) - data_encrypted = self.noise_state.write_message(noise_msg) - await self.read_writer.write_msg(data_encrypted) - - async def read_msg(self) -> bytes: - noise_msg_encrypted = await self.read_writer.read_msg() - noise_msg = self.noise_state.read_message(noise_msg_encrypted) - return decode_msg_body(noise_msg) - class PatternXX(BasePattern): def __init__( @@ -116,14 +106,13 @@ class PatternXX(BasePattern): raise HandshakeHasNotFinished( "handshake is done but it is not marked as finished in `noise_state`" ) - - return NoiseConnection( + transport_read_writer = NoiseTransportReadWriter(conn, noise_state) + return SecureSession( self.local_peer, self.libp2p_privkey, remote_peer_id_from_pubkey, - conn, + transport_read_writer, False, - noise_state, ) async def handshake_outbound( @@ -171,7 +160,12 @@ class PatternXX(BasePattern): raise HandshakeHasNotFinished( "handshake is done but it is not marked as finished in `noise_state`" ) + transport_read_writer = NoiseTransportReadWriter(conn, noise_state) - return NoiseConnection( - self.local_peer, self.libp2p_privkey, remote_peer, conn, False, noise_state + return SecureSession( + self.local_peer, + self.libp2p_privkey, + remote_peer, + transport_read_writer, + False, ) diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index 109f488..ba774d9 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -import io import itertools from typing import Optional, Tuple @@ -18,13 +17,14 @@ from libp2p.crypto.exceptions import MissingDeserializerError from libp2p.crypto.key_exchange import create_ephemeral_key_pair from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.serialization import deserialize_public_key +from libp2p.io.abc import EncryptedMsgReadWriter from libp2p.io.exceptions import DecryptionFailedException, IOException from libp2p.io.msgio import BaseMsgReadWriter from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID as PeerID -from libp2p.security.base_session import BaseSession from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.secure_conn_interface import ISecureConn +from libp2p.security.secure_session import SecureSession from libp2p.typing import TProtocol from .exceptions import ( @@ -54,30 +54,20 @@ class MsgIOReadWriter(BaseMsgReadWriter): size_len_bytes = SIZE_SECIO_LEN_BYTES -class SecureSession(BaseSession): - buf: io.BytesIO - low_watermark: int - high_watermark: int +class SecioMsgReadWriter(EncryptedMsgReadWriter): + read_writer: MsgIOReadWriter def __init__( self, - local_peer: PeerID, - local_private_key: PrivateKey, local_encryption_parameters: AuthenticatedEncryptionParameters, - remote_peer: PeerID, remote_encryption_parameters: AuthenticatedEncryptionParameters, - conn: MsgIOReadWriter, - is_initiator: bool, + read_writer: MsgIOReadWriter, ) -> None: - super().__init__(local_peer, local_private_key, is_initiator, remote_peer) - self.conn = conn - self.local_encryption_parameters = local_encryption_parameters self.remote_encryption_parameters = remote_encryption_parameters self._initialize_authenticated_encryption_for_local_peer() self._initialize_authenticated_encryption_for_remote_peer() - - self._reset_internal_buffer() + self.read_writer = read_writer def _initialize_authenticated_encryption_for_local_peer(self) -> None: self.local_encrypter = Encrypter(self.local_encryption_parameters) @@ -85,68 +75,28 @@ class SecureSession(BaseSession): def _initialize_authenticated_encryption_for_remote_peer(self) -> None: self.remote_encrypter = Encrypter(self.remote_encryption_parameters) - async def next_msg_len(self) -> int: - return await self.conn.next_msg_len() + def encrypt(self, data: bytes) -> bytes: + encrypted_data = self.local_encrypter.encrypt(data) + tag = self.local_encrypter.authenticate(encrypted_data) + return encrypted_data + tag - def _reset_internal_buffer(self) -> None: - self.buf = io.BytesIO() - self.low_watermark = 0 - self.high_watermark = 0 - - def _drain(self, n: int) -> bytes: - if self.low_watermark == self.high_watermark: - return bytes() - - data = self.buf.getbuffer()[self.low_watermark : self.high_watermark] - - if n is None: - n = len(data) - result = data[:n].tobytes() - self.low_watermark += len(result) - - if self.low_watermark == self.high_watermark: - del data # free the memoryview so we can free the underlying BytesIO - self.buf.close() - self._reset_internal_buffer() - return result - - async def _fill(self) -> None: - msg = await self.read_msg() - self.buf.write(msg) - self.low_watermark = 0 - self.high_watermark = len(msg) - - async def read(self, n: int = None) -> bytes: - if n == 0: - return bytes() - - data_from_buffer = self._drain(n) - if len(data_from_buffer) > 0: - return data_from_buffer - - next_length = await self.next_msg_len() - - if n < next_length: - await self._fill() - return self._drain(n) - else: - return await self.read_msg() - - async def read_msg(self) -> bytes: - msg = await self.conn.read_msg() + def decrypt(self, data: bytes) -> bytes: try: - decrypted_msg = self.remote_encrypter.decrypt_if_valid(msg) + decrypted_data = self.remote_encrypter.decrypt_if_valid(data) except InvalidMACException as e: raise DecryptionFailedException() from e - return decrypted_msg - - async def write(self, data: bytes) -> None: - await self.write_msg(data) + return decrypted_data async def write_msg(self, msg: bytes) -> None: - encrypted_data = self.local_encrypter.encrypt(msg) - tag = self.local_encrypter.authenticate(encrypted_data) - await self.conn.write_msg(encrypted_data + tag) + data_encrypted = self.encrypt(msg) + await self.read_writer.write_msg(data_encrypted) + + async def read_msg(self) -> bytes: + msg_encrypted = await self.read_writer.read_msg() + return self.decrypt(msg_encrypted) + + async def close(self) -> None: + await self.read_writer.close() @dataclass(frozen=True) @@ -387,22 +337,20 @@ def _mk_session_from( if session_parameters.order < 0: key_set1, key_set2 = key_set2, key_set1 - + secio_read_writer = SecioMsgReadWriter(key_set1, key_set2, conn) session = SecureSession( session_parameters.local_peer, local_private_key, - key_set1, session_parameters.remote_peer, - key_set2, - conn, + secio_read_writer, is_initiator, ) return session async def _finish_handshake(session: SecureSession, remote_nonce: bytes) -> bytes: - await session.write_msg(remote_nonce) - return await session.read_msg() + await session.conn.write_msg(remote_nonce) + return await session.conn.read_msg() async def create_secure_session( diff --git a/libp2p/security/noise/connection.py b/libp2p/security/secure_session.py similarity index 63% rename from libp2p/security/noise/connection.py rename to libp2p/security/secure_session.py index 29bbc8b..9bbc00a 100644 --- a/libp2p/security/noise/connection.py +++ b/libp2p/security/secure_session.py @@ -1,43 +1,29 @@ import io -from noise.connection import NoiseConnection as NoiseState - from libp2p.crypto.keys import PrivateKey -from libp2p.network.connection.raw_connection_interface import IRawConnection +from libp2p.io.abc import EncryptedMsgReadWriter from libp2p.peer.id import ID from libp2p.security.base_session import BaseSession -from libp2p.security.noise.io import MsgReadWriter, NoiseTransportReadWriter -class NoiseConnection(BaseSession): +class SecureSession(BaseSession): buf: io.BytesIO low_watermark: int high_watermark: int - read_writer: IRawConnection - noise_state: NoiseState - def __init__( self, local_peer: ID, local_private_key: PrivateKey, remote_peer: ID, - conn: IRawConnection, + conn: EncryptedMsgReadWriter, is_initiator: bool, - noise_state: NoiseState, - # remote_permanent_pubkey ) -> None: super().__init__(local_peer, local_private_key, is_initiator, remote_peer) self.conn = conn - self.noise_state = noise_state + self._reset_internal_buffer() - def get_msg_read_writer(self) -> MsgReadWriter: - return NoiseTransportReadWriter(self.conn, self.noise_state) - - async def close(self) -> None: - await self.conn.close() - def _reset_internal_buffer(self) -> None: self.buf = io.BytesIO() self.low_watermark = 0 @@ -60,6 +46,11 @@ class NoiseConnection(BaseSession): self._reset_internal_buffer() return result + def _fill(self, msg: bytes) -> None: + self.buf.write(msg) + self.low_watermark = 0 + self.high_watermark = len(msg) + async def read(self, n: int = None) -> bytes: if n == 0: return bytes() @@ -68,21 +59,16 @@ class NoiseConnection(BaseSession): if len(data_from_buffer) > 0: return data_from_buffer - msg = await self.read_msg() + msg = await self.conn.read_msg() if n < len(msg): - self.buf.write(msg) - self.low_watermark = 0 - self.high_watermark = len(msg) + self._fill(msg) return self._drain(n) else: return msg - async def read_msg(self) -> bytes: - return await self.get_msg_read_writer().read_msg() - async def write(self, data: bytes) -> None: - await self.write_msg(data) + await self.conn.write_msg(data) - async def write_msg(self, msg: bytes) -> None: - await self.get_msg_read_writer().write_msg(msg) + async def close(self) -> None: + await self.conn.close()