diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 3277901..0d20de5 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -39,9 +39,12 @@ class RawConnection(IRawConnection): async with self._drain_lock: await self.writer.drain() - async def read(self) -> bytes: - line = await self.reader.readline() - return line.rstrip(b"\n") + async def read(self, n: int = -1) -> bytes: + """ + Read up to ``n`` bytes from the underlying stream. + This call is delegated directly to the underlying ``self.reader``. + """ + return await self.reader.read(n) def close(self) -> None: self.writer.close() diff --git a/libp2p/network/connection/raw_connection_interface.py b/libp2p/network/connection/raw_connection_interface.py index 1810f58..fd1b469 100644 --- a/libp2p/network/connection/raw_connection_interface.py +++ b/libp2p/network/connection/raw_connection_interface.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -import asyncio class IRawConnection(ABC): @@ -9,17 +8,12 @@ class IRawConnection(ABC): initiator: bool - # TODO: reader and writer shouldn't be exposed. - # Need better API for the consumers - reader: asyncio.StreamReader - writer: asyncio.StreamWriter - @abstractmethod async def write(self, data: bytes) -> None: pass @abstractmethod - async def read(self) -> bytes: + async def read(self, n: int = -1) -> bytes: pass @abstractmethod diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index e01e9cc..e252304 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -16,7 +16,7 @@ class RawConnectionCommunicator(IMultiselectCommunicator): await self.conn.write(msg_bytes) async def read(self) -> str: - data = await read_delim(self.conn.reader) + data = await read_delim(self.conn) return data.decode() diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index ba14037..a41c52b 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -23,8 +23,6 @@ class BaseSession(ISecureConn): self.remote_permanent_pubkey = None self.initiator = self.conn.initiator - self.writer = self.conn.writer - self.reader = self.conn.reader # TODO clean up how this is passed around? def next_stream_id(self) -> int: @@ -33,8 +31,8 @@ class BaseSession(ISecureConn): async def write(self, data: bytes) -> None: await self.conn.write(data) - async def read(self) -> bytes: - return await self.conn.read() + async def read(self, n: int = -1) -> bytes: + return await self.conn.read(n) def close(self) -> None: self.conn.close() diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index fa4a1a8..c16713b 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -23,7 +23,7 @@ class InsecureSession(BaseSession): encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) await self.write(encoded_msg_bytes) - msg_bytes_other_side = await read_fixedint_prefixed(self.reader) + msg_bytes_other_side = await read_fixedint_prefixed(self.conn) msg_other_side = plaintext_pb2.Exchange() msg_other_side.ParseFromString(msg_bytes_other_side) diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index ec24b46..6e69d7a 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -87,10 +87,6 @@ class SecurityMultistream(ABC): :param initiator: true if we are the initiator, false otherwise :return: selected secure transport """ - # TODO: Is conn acceptable to multiselect/multiselect_client - # instead of stream? In go repo, they pass in a raw conn - # (https://raw.githubusercontent.com/libp2p/go-conn-security-multistream/master/ssms.go) - protocol: TProtocol communicator = RawConnectionCommunicator(conn) if initiator: diff --git a/libp2p/security/simple/transport.py b/libp2p/security/simple/transport.py index e63e651..e70edcc 100644 --- a/libp2p/security/simple/transport.py +++ b/libp2p/security/simple/transport.py @@ -7,6 +7,7 @@ from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import InsecureSession from libp2p.security.secure_conn_interface import ISecureConn from libp2p.transport.exceptions import SecurityUpgradeFailure +from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed class SimpleSecurityTransport(BaseSecureTransport): @@ -22,8 +23,8 @@ class SimpleSecurityTransport(BaseSecureTransport): for an inbound connection (i.e. we are not the initiator) :return: secure connection object (that implements secure_conn_interface) """ - await conn.write(self.key_phrase.encode()) - incoming = (await conn.read()).decode() + await conn.write(encode_fixedint_prefixed(self.key_phrase.encode())) + incoming = (await read_fixedint_prefixed(conn)).decode() if incoming != self.key_phrase: raise SecurityUpgradeFailure( @@ -48,8 +49,8 @@ class SimpleSecurityTransport(BaseSecureTransport): for an inbound connection (i.e. we are the initiator) :return: secure connection object (that implements secure_conn_interface) """ - await conn.write(self.key_phrase.encode()) - incoming = (await conn.read()).decode() + await conn.write(encode_fixedint_prefixed(self.key_phrase.encode())) + incoming = (await read_fixedint_prefixed(conn)).decode() # Force context switch, as this security transport is built for testing locally # in a single event loop diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 765dd56..8f95124 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -188,11 +188,9 @@ class Mplex(IMuxedConn): # loop in handle_incoming timeout = 0.1 try: - header = await decode_uvarint_from_stream(self.conn.reader, timeout) - length = await decode_uvarint_from_stream(self.conn.reader, timeout) - message = await asyncio.wait_for( - self.conn.reader.read(length), timeout=timeout - ) + header = await decode_uvarint_from_stream(self.conn, timeout) + length = await decode_uvarint_from_stream(self.conn, timeout) + message = await asyncio.wait_for(self.conn.read(length), timeout=timeout) except asyncio.TimeoutError: return None, None, None diff --git a/libp2p/typing.py b/libp2p/typing.py index 9810746..f36d8ab 100644 --- a/libp2p/typing.py +++ b/libp2p/typing.py @@ -1,6 +1,7 @@ -import asyncio from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union +from libp2p.network.connection.raw_connection_interface import IRawConnection + if TYPE_CHECKING: from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401 from libp2p.stream_muxer.abc import IMuxedStream # noqa: F401 @@ -9,4 +10,4 @@ TProtocol = NewType("TProtocol", str) StreamHandlerFn = Callable[["INetStream"], Awaitable[None]] -StreamReader = Union["IMuxedStream", asyncio.StreamReader] +StreamReader = Union["IMuxedStream", IRawConnection]