From 1a359770dd59f67bd5d4b962e91f786bf57842ac Mon Sep 17 00:00:00 2001 From: Alex Stokes Date: Tue, 3 Sep 2019 22:08:09 -0700 Subject: [PATCH] Use `msgio` IO and proper buffering in `secio` implementation --- libp2p/security/secio/transport.py | 96 ++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 25 deletions(-) diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index e2259e8..10a5763 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import io import itertools from typing import Optional, Tuple @@ -15,8 +16,7 @@ from libp2p.crypto.ecc import ECCPublicKey from libp2p.crypto.key_exchange import create_ephemeral_key_pair from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.serialization import deserialize_public_key -from libp2p.io.msgio import encode as encode_message -from libp2p.io.msgio import read_next_message +from libp2p.io.msgio import MsgIOReadWriter from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID as PeerID from libp2p.security.base_session import BaseSession @@ -45,6 +45,10 @@ DEFAULT_SUPPORTED_HASHES = "SHA256" class SecureSession(BaseSession): + buf: io.BytesIO + low_watermark: int + high_watermark: int + def __init__( self, local_peer: PeerID, @@ -52,7 +56,7 @@ class SecureSession(BaseSession): local_encryption_parameters: AuthenticatedEncryptionParameters, remote_peer: PeerID, remote_encryption_parameters: AuthenticatedEncryptionParameters, - conn: IRawConnection, + conn: MsgIOReadWriter, ) -> None: super().__init__(local_peer, local_private_key, conn, remote_peer) @@ -61,29 +65,70 @@ class SecureSession(BaseSession): self._initialize_authenticated_encryption_for_local_peer() self._initialize_authenticated_encryption_for_remote_peer() + self._reset_internal_buffer() + def _initialize_authenticated_encryption_for_local_peer(self) -> None: self.local_encrypter = Encrypter(self.local_encryption_parameters) def _initialize_authenticated_encryption_for_remote_peer(self) -> None: self.remote_encrypter = Encrypter(self.remote_encryption_parameters) - async def read(self, n: int = -1) -> bytes: - return await self._read_msg() + async def next_msg_len(self) -> int: + return await self.conn.next_msg_len() - async def _read_msg(self) -> bytes: - # TODO do we need to serialize reads? - msg = await read_next_message(self.conn) + def _reset_internal_buffer(self) -> None: + self.buf = io.BytesIO() + self.low_watermark = 0 + self.high_watermark = 0 + + def _drain(self, n: int) -> bytes: + if self.low_watermark == self.high_watermark: + return bytes() + + data = self.buf.getbuffer()[self.low_watermark : self.high_watermark] + + if n < 0: + n = len(data) + result = data[:n].tobytes() + self.low_watermark += len(result) + + if self.low_watermark == self.high_watermark: + del data # free the memoryview so we can free the underlying BytesIO + self.buf.close() + self._reset_internal_buffer() + return result + + async def _fill(self) -> None: + msg = await self.read_msg() + self.buf.write(msg) + self.low_watermark = 0 + self.high_watermark = len(msg) + + async def read(self, n: int = -1) -> bytes: + data_from_buffer = self._drain(n) + if len(data_from_buffer) > 0: + return data_from_buffer + + next_length = await self.next_msg_len() + + if n < next_length: + await self._fill() + return self._drain(n) + else: + return await self.read_msg() + + async def read_msg(self) -> bytes: + msg = await self.conn.read_msg() return self.remote_encrypter.decrypt_if_valid(msg) - async def write(self, data: bytes) -> None: - await self._write_msg(data) + async def write(self, data: bytes) -> int: + await self.write_msg(data) + return len(data) - async def _write_msg(self, data: bytes) -> None: - # TODO do we need to serialize writes? - encrypted_data = self.local_encrypter.encrypt(data) + async def write_msg(self, msg: bytes) -> None: + encrypted_data = self.local_encrypter.encrypt(msg) tag = self.local_encrypter.authenticate(encrypted_data) - msg = encode_message(encrypted_data + tag) - await self.conn.write(msg) + await self.conn.write_msg(encrypted_data + tag) @dataclass(frozen=True) @@ -156,9 +201,9 @@ class SessionParameters: pass -async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes: - await conn.write(encode_message(msg)) - return await read_next_message(conn) +async def _response_to_msg(read_writer: MsgIOReadWriter, msg: bytes) -> bytes: + await read_writer.write_msg(msg) + return await read_writer.read_msg() def _mk_multihash_sha256(data: bytes) -> bytes: @@ -220,7 +265,7 @@ async def _establish_session_parameters( local_peer: PeerID, local_private_key: PrivateKey, remote_peer: Optional[PeerID], - conn: IRawConnection, + conn: MsgIOReadWriter, nonce: bytes, ) -> Tuple[SessionParameters, bytes]: # establish shared encryption parameters @@ -312,7 +357,7 @@ async def _establish_session_parameters( def _mk_session_from( local_private_key: PrivateKey, session_parameters: SessionParameters, - conn: IRawConnection, + conn: MsgIOReadWriter, ) -> SecureSession: key_set1, key_set2 = initialize_pair_for_encryption( session_parameters.local_encryption_parameters.cipher_type, @@ -334,9 +379,9 @@ def _mk_session_from( return session -async def _finish_handshake(session: ISecureConn, remote_nonce: bytes) -> bytes: - await session.write(remote_nonce) - return await session.read() +async def _finish_handshake(session: SecureSession, remote_nonce: bytes) -> bytes: + await session.write_msg(remote_nonce) + return await session.read_msg() async def create_secure_session( @@ -351,15 +396,16 @@ async def create_secure_session( If successful, return an object that provides secure communication to the ``remote_peer``. """ + msg_io = MsgIOReadWriter(conn) try: session_parameters, remote_nonce = await _establish_session_parameters( - local_peer, local_private_key, remote_peer, conn, local_nonce + local_peer, local_private_key, remote_peer, msg_io, local_nonce ) except SecioException as e: await conn.close() raise e - session = _mk_session_from(local_private_key, session_parameters, conn) + session = _mk_session_from(local_private_key, session_parameters, msg_io) received_nonce = await _finish_handshake(session, remote_nonce) if received_nonce != local_nonce: