diff --git a/libp2p/io/abc.py b/libp2p/io/abc.py new file mode 100644 index 0000000..eea7b72 --- /dev/null +++ b/libp2p/io/abc.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod + + +class Closer(ABC): + async def close(self) -> None: + ... + + +class Reader(ABC): + @abstractmethod + async def read(self, n: int = -1) -> bytes: + ... + + +class Writer(ABC): + @abstractmethod + async def write(self, data: bytes) -> int: + ... + + +class WriteCloser(Writer, Closer): + pass + + +class ReadCloser(Reader, Closer): + pass + + +class ReadWriter(Reader, Writer): + pass + + +class ReadWriteCloser(Reader, Writer, Closer): + pass diff --git a/libp2p/io/exceptions.py b/libp2p/io/exceptions.py index 6e1376f..b8e4e01 100644 --- a/libp2p/io/exceptions.py +++ b/libp2p/io/exceptions.py @@ -1,7 +1,17 @@ from libp2p.exceptions import BaseLibp2pError -class MsgioException(BaseLibp2pError): +class IOException(BaseLibp2pError): + pass + + +class IncompleteReadError(IOException): + """ + Fewer bytes were read than requested. + """ + + +class MsgioException(IOException): pass diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index 65fde68..4c23ef2 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -1,24 +1,92 @@ -from libp2p.network.connection.raw_connection_interface import IRawConnection +""" +``msgio`` is an implementation of `https://github.com/libp2p/go-msgio`. -from .exceptions import MissingLengthException, MissingMessageException +from that repo: "a simple package to r/w length-delimited slices." + +NOTE: currently missing the capability to indicate lengths by "varint" method. +""" +# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501 +from typing import Optional, cast + +from libp2p.io.abc import ( + Closer, + ReadCloser, + Reader, + ReadWriteCloser, + WriteCloser, + Writer, +) +from libp2p.io.utils import read_exactly SIZE_LEN_BYTES = 4 - -# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501 +BYTE_ORDER = "big" -def encode(msg_bytes: bytes) -> bytes: +async def read_length(reader: Reader) -> int: + length_bytes = await read_exactly(reader, SIZE_LEN_BYTES) + return int.from_bytes(length_bytes, byteorder=BYTE_ORDER) + + +def encode_msg_with_length(msg_bytes: bytes) -> bytes: len_prefix = len(msg_bytes).to_bytes(SIZE_LEN_BYTES, "big") return len_prefix + msg_bytes -async def read_next_message(reader: IRawConnection) -> bytes: - len_bytes = await reader.read(SIZE_LEN_BYTES) - if len(len_bytes) != SIZE_LEN_BYTES: - raise MissingLengthException() - len_int = int.from_bytes(len_bytes, "big") - next_msg = await reader.read(len_int) - if len(next_msg) != len_int: - # TODO makes sense to keep reading until this condition is true? - raise MissingMessageException() - return next_msg +class MsgIOWriter(Writer, Closer): + write_closer: WriteCloser + + def __init__(self, write_closer: WriteCloser) -> None: + super().__init__() + self.write_closer = write_closer + + async def write(self, data: bytes) -> int: + await self.write_msg(data) + return len(data) + + async def write_msg(self, msg: bytes) -> None: + data = encode_msg_with_length(msg) + await self.write_closer.write(data) + + async def close(self) -> None: + await self.write_closer.close() + + +class MsgIOReader(Reader, Closer): + read_closer: ReadCloser + next_length: Optional[int] + + def __init__(self, read_closer: ReadCloser) -> None: + # NOTE: the following line is required to satisfy the + # multiple inheritance but `mypy` does not like it... + super().__init__(read_closer) # type: ignore + self.read_closer = read_closer + self.next_length = None + + async def read(self, n: int = -1) -> bytes: + return await self.read_msg() + + async def read_msg(self) -> bytes: + length = await self.next_msg_len() + + data = await read_exactly(self.read_closer, length) + if len(data) < length: + self.next_length = length - len(data) + else: + self.next_length = None + return data + + async def next_msg_len(self) -> int: + if self.next_length is None: + self.next_length = await read_length(self.read_closer) + return self.next_length + + async def close(self) -> None: + await self.read_closer.close() + + +class MsgIOReadWriter(MsgIOReader, MsgIOWriter, Closer): + def __init__(self, read_write_closer: ReadWriteCloser) -> None: + super().__init__(cast(ReadCloser, read_write_closer)) + + async def close(self) -> None: + await self.read_closer.close() diff --git a/libp2p/io/utils.py b/libp2p/io/utils.py new file mode 100644 index 0000000..1a6e0a3 --- /dev/null +++ b/libp2p/io/utils.py @@ -0,0 +1,21 @@ +from libp2p.io.abc import Reader +from libp2p.io.exceptions import IncompleteReadError + +DEFAULT_RETRY_READ_COUNT = 100 + + +async def read_exactly( + reader: Reader, n: int, retry_count: int = DEFAULT_RETRY_READ_COUNT +) -> bytes: + """ + NOTE: relying on exceptions to break out on erroneous conditions, like EOF + """ + data = await reader.read(n) + + for _ in range(retry_count): + if len(data) < n: + remaining = n - len(data) + data += await reader.read(remaining) + else: + return data + raise IncompleteReadError({"requested_count": n, "received_count": len(data)}) diff --git a/libp2p/network/connection/raw_connection_interface.py b/libp2p/network/connection/raw_connection_interface.py index 25b90ae..94951af 100644 --- a/libp2p/network/connection/raw_connection_interface.py +++ b/libp2p/network/connection/raw_connection_interface.py @@ -1,21 +1,9 @@ -from abc import ABC, abstractmethod +from libp2p.io.abc import ReadWriteCloser -class IRawConnection(ABC): +class IRawConnection(ReadWriteCloser): """ A Raw Connection provides a Reader and a Writer """ initiator: bool - - @abstractmethod - async def write(self, data: bytes) -> None: - pass - - @abstractmethod - async def read(self, n: int = -1) -> bytes: - pass - - @abstractmethod - async def close(self) -> None: - pass diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index 2d76198..a7df0e9 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -31,7 +31,7 @@ class BaseSession(ISecureConn): self.remote_permanent_pubkey = None self.conn = conn - self.initiator = self.conn.initiator + self.initiator = peer_id is not None async def write(self, data: bytes) -> None: await self.conn.write(data) diff --git a/tests/security/test_secio.py b/tests/security/test_secio.py index 673cbc5..ca48c23 100644 --- a/tests/security/test_secio.py +++ b/tests/security/test_secio.py @@ -20,11 +20,12 @@ class InMemoryConnection(IRawConnection): self.closed = False - async def write(self, data: bytes) -> None: + async def write(self, data: bytes) -> int: if self.closed: raise Exception("InMemoryConnection is closed for writing") await self.send_queue.put(data) + return len(data) async def read(self, n: int = -1) -> bytes: """