From 6016ea731b9830bce82539e0eead96b282807527 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 20 Feb 2020 21:48:03 +0800 Subject: [PATCH] `BaseMsgReadWriter` - Change `BaseMsgReadWriter` to encode/decode messages with abstract method, which can be implemented by the subclasses. This allows us to create subclasses `FixedSizeLenMsgReadWriter` and `VarIntLenMsgReadWriter`. --- libp2p/io/exceptions.py | 4 ++ libp2p/io/msgio.py | 61 +++++++++++++++++++-------- libp2p/security/insecure/transport.py | 4 +- libp2p/security/noise/io.py | 4 +- libp2p/security/secio/transport.py | 16 +++---- 5 files changed, 59 insertions(+), 30 deletions(-) diff --git a/libp2p/io/exceptions.py b/libp2p/io/exceptions.py index 0f2230f..2c237c9 100644 --- a/libp2p/io/exceptions.py +++ b/libp2p/io/exceptions.py @@ -23,3 +23,7 @@ class MissingMessageException(MsgioException): class DecryptionFailedException(MsgioException): pass + + +class MessageTooLarge(MsgioException): + pass diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index 837f642..457f055 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -5,11 +5,13 @@ from that repo: "a simple package to r/w length-delimited slices." NOTE: currently missing the capability to indicate lengths by "varint" method. """ - -from typing import Optional +from abc import abstractmethod from libp2p.io.abc import MsgReadWriteCloser, Reader, ReadWriteCloser from libp2p.io.utils import read_exactly +from libp2p.utils import decode_uvarint_from_stream, encode_varint_prefixed + +from .exceptions import MessageTooLarge BYTE_ORDER = "big" @@ -31,34 +33,57 @@ def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes: class BaseMsgReadWriter(MsgReadWriteCloser): - next_length: Optional[int] read_write_closer: ReadWriteCloser size_len_bytes: int def __init__(self, read_write_closer: ReadWriteCloser) -> None: self.read_write_closer = read_write_closer - self.next_length = None async def read_msg(self) -> bytes: length = await self.next_msg_len() + return await read_exactly(self.read_write_closer, length) - data = await read_exactly(self.read_write_closer, length) - if len(data) < length: - self.next_length = length - len(data) - else: - self.next_length = None - return data - + @abstractmethod async def next_msg_len(self) -> int: - if self.next_length is None: - self.next_length = await read_length( - self.read_write_closer, self.size_len_bytes - ) - return self.next_length + ... + + @abstractmethod + def encode_msg(self, msg: bytes) -> bytes: + ... async def close(self) -> None: await self.read_write_closer.close() async def write_msg(self, msg: bytes) -> None: - data = encode_msg_with_length(msg, self.size_len_bytes) - await self.read_write_closer.write(data) + encoded_msg = self.encode_msg(msg) + await self.read_write_closer.write(encoded_msg) + + +class FixedSizeLenMsgReadWriter(BaseMsgReadWriter): + size_len_bytes: int + + async def next_msg_len(self) -> int: + return await read_length(self.read_write_closer, self.size_len_bytes) + + def encode_msg(self, msg: bytes) -> bytes: + return encode_msg_with_length(msg, self.size_len_bytes) + + +class VarIntLengthMsgReadWriter(BaseMsgReadWriter): + max_msg_size: int + + async def next_msg_len(self) -> int: + msg_len = await decode_uvarint_from_stream(self.read_write_closer) + if msg_len > self.max_msg_size: + raise MessageTooLarge( + f"msg_len={msg_len} > max_msg_size={self.max_msg_size}" + ) + return msg_len + + def encode_msg(self, msg: bytes) -> bytes: + msg_len = len(msg) + if msg_len > self.max_msg_size: + raise MessageTooLarge( + f"msg_len={msg_len} > max_msg_size={self.max_msg_size}" + ) + return encode_varint_prefixed(msg) diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 861ca71..42b7bf8 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -3,7 +3,7 @@ from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.pb import crypto_pb2 from libp2p.crypto.serialization import deserialize_public_key from libp2p.io.abc import ReadWriteCloser -from libp2p.io.msgio import BaseMsgReadWriter +from libp2p.io.msgio import FixedSizeLenMsgReadWriter from libp2p.network.connection.exceptions import RawConnError from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID @@ -23,7 +23,7 @@ PLAINTEXT_PROTOCOL_ID = TProtocol("/plaintext/2.0.0") SIZE_PLAINTEXT_LEN_BYTES = 4 -class PlaintextHandshakeReadWriter(BaseMsgReadWriter): +class PlaintextHandshakeReadWriter(FixedSizeLenMsgReadWriter): size_len_bytes = SIZE_PLAINTEXT_LEN_BYTES diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index 5ffeafe..499bbeb 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -3,7 +3,7 @@ from typing import cast from noise.connection import NoiseConnection as NoiseState from libp2p.io.abc import EncryptedMsgReadWriter, MsgReadWriteCloser, ReadWriteCloser -from libp2p.io.msgio import BaseMsgReadWriter, encode_msg_with_length +from libp2p.io.msgio import FixedSizeLenMsgReadWriter, encode_msg_with_length from libp2p.network.connection.raw_connection_interface import IRawConnection SIZE_NOISE_MESSAGE_LEN = 2 @@ -19,7 +19,7 @@ BYTE_ORDER = "big" # <-2 bytes-><- max=65533 bytes -> -class NoisePacketReadWriter(BaseMsgReadWriter): +class NoisePacketReadWriter(FixedSizeLenMsgReadWriter): size_len_bytes = SIZE_NOISE_MESSAGE_LEN diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index 4759ccc..46f5c2d 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -19,7 +19,7 @@ from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.serialization import deserialize_public_key from libp2p.io.abc import EncryptedMsgReadWriter from libp2p.io.exceptions import DecryptionFailedException, IOException -from libp2p.io.msgio import BaseMsgReadWriter +from libp2p.io.msgio import FixedSizeLenMsgReadWriter from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID as PeerID from libp2p.security.base_transport import BaseSecureTransport @@ -50,18 +50,18 @@ DEFAULT_SUPPORTED_CIPHERS = "AES-128" DEFAULT_SUPPORTED_HASHES = "SHA256" -class MsgIOReadWriter(BaseMsgReadWriter): +class SecioPacketReadWriter(FixedSizeLenMsgReadWriter): size_len_bytes = SIZE_SECIO_LEN_BYTES class SecioMsgReadWriter(EncryptedMsgReadWriter): - read_writer: MsgIOReadWriter + read_writer: SecioPacketReadWriter def __init__( self, local_encryption_parameters: AuthenticatedEncryptionParameters, remote_encryption_parameters: AuthenticatedEncryptionParameters, - read_writer: MsgIOReadWriter, + read_writer: SecioPacketReadWriter, ) -> None: self.local_encryption_parameters = local_encryption_parameters self.remote_encryption_parameters = remote_encryption_parameters @@ -170,7 +170,7 @@ class SessionParameters: pass -async def _response_to_msg(read_writer: MsgIOReadWriter, msg: bytes) -> bytes: +async def _response_to_msg(read_writer: SecioPacketReadWriter, msg: bytes) -> bytes: await read_writer.write_msg(msg) return await read_writer.read_msg() @@ -234,7 +234,7 @@ async def _establish_session_parameters( local_peer: PeerID, local_private_key: PrivateKey, remote_peer: Optional[PeerID], - conn: MsgIOReadWriter, + conn: SecioPacketReadWriter, nonce: bytes, ) -> Tuple[SessionParameters, bytes]: # establish shared encryption parameters @@ -326,7 +326,7 @@ async def _establish_session_parameters( def _mk_session_from( local_private_key: PrivateKey, session_parameters: SessionParameters, - conn: MsgIOReadWriter, + conn: SecioPacketReadWriter, is_initiator: bool, ) -> SecureSession: key_set1, key_set2 = initialize_pair_for_encryption( @@ -371,7 +371,7 @@ async def create_secure_session( to the ``remote_peer``. Raise `SecioException` when `conn` closed. Raise `InconsistentNonce` when handshake failed """ - msg_io = MsgIOReadWriter(conn) + msg_io = SecioPacketReadWriter(conn) try: session_parameters, remote_nonce = await _establish_session_parameters( local_peer, local_private_key, remote_peer, msg_io, local_nonce