Refactor MsgIOReadWriter
- Abstract it as `MsgReadWriter` - `MsgIOReadWriter` as a subclass of `MsgReadWriter`
This commit is contained in:
parent
ea645f0bd6
commit
874c6bbca4
|
@ -32,3 +32,23 @@ class ReadWriter(Reader, Writer):
|
||||||
|
|
||||||
class ReadWriteCloser(Reader, Writer, Closer):
|
class ReadWriteCloser(Reader, Writer, Closer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MsgReader(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def read_msg(self) -> bytes:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def next_msg_len(self) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MsgWriter(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def write_msg(self, msg: bytes) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MsgReadWriter(MsgReader, MsgWriter):
|
||||||
|
pass
|
||||||
|
|
|
@ -6,60 +6,39 @@ from that repo: "a simple package to r/w length-delimited slices."
|
||||||
NOTE: currently missing the capability to indicate lengths by "varint" method.
|
NOTE: currently missing the capability to indicate lengths by "varint" method.
|
||||||
"""
|
"""
|
||||||
# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501
|
# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501
|
||||||
from typing import Optional, cast
|
from typing import Optional
|
||||||
|
|
||||||
from libp2p.io.abc import Closer, ReadCloser, Reader, ReadWriteCloser, WriteCloser
|
from libp2p.io.abc import MsgReadWriter, Reader, ReadWriteCloser
|
||||||
from libp2p.io.utils import read_exactly
|
from libp2p.io.utils import read_exactly
|
||||||
|
|
||||||
SIZE_LEN_BYTES = 4
|
SIZE_NOISE_LEN_BYTES = 2
|
||||||
|
SIZE_SECIO_LEN_BYTES = 4
|
||||||
BYTE_ORDER = "big"
|
BYTE_ORDER = "big"
|
||||||
|
|
||||||
|
|
||||||
async def read_length(reader: Reader) -> int:
|
async def read_length(reader: Reader, size_len_bytes: int) -> int:
|
||||||
length_bytes = await read_exactly(reader, SIZE_LEN_BYTES)
|
length_bytes = await read_exactly(reader, size_len_bytes)
|
||||||
return int.from_bytes(length_bytes, byteorder=BYTE_ORDER)
|
return int.from_bytes(length_bytes, byteorder=BYTE_ORDER)
|
||||||
|
|
||||||
|
|
||||||
def encode_msg_with_length(msg_bytes: bytes) -> bytes:
|
def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes:
|
||||||
len_prefix = len(msg_bytes).to_bytes(SIZE_LEN_BYTES, "big")
|
len_prefix = len(msg_bytes).to_bytes(size_len_bytes, byteorder=BYTE_ORDER)
|
||||||
return len_prefix + msg_bytes
|
return len_prefix + msg_bytes
|
||||||
|
|
||||||
|
|
||||||
class MsgIOWriter(WriteCloser):
|
class BaseMsgReadWriter(MsgReadWriter):
|
||||||
write_closer: WriteCloser
|
|
||||||
|
|
||||||
def __init__(self, write_closer: WriteCloser) -> None:
|
|
||||||
self.write_closer = write_closer
|
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
|
||||||
await self.write_msg(data)
|
|
||||||
|
|
||||||
async def write_msg(self, msg: bytes) -> None:
|
|
||||||
data = encode_msg_with_length(msg)
|
|
||||||
await self.write_closer.write(data)
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
await self.write_closer.close()
|
|
||||||
|
|
||||||
|
|
||||||
class MsgIOReader(ReadCloser):
|
|
||||||
read_closer: ReadCloser
|
|
||||||
next_length: Optional[int]
|
next_length: Optional[int]
|
||||||
|
read_write_closer: ReadWriteCloser
|
||||||
|
size_len_bytes: int
|
||||||
|
|
||||||
def __init__(self, read_closer: ReadCloser) -> None:
|
def __init__(self, read_write_closer: ReadWriteCloser) -> None:
|
||||||
# NOTE: the following line is required to satisfy the
|
self.read_write_closer = read_write_closer
|
||||||
# multiple inheritance but `mypy` does not like it...
|
|
||||||
super().__init__(read_closer) # type: ignore
|
|
||||||
self.read_closer = read_closer
|
|
||||||
self.next_length = None
|
self.next_length = None
|
||||||
|
|
||||||
async def read(self, n: int = None) -> bytes:
|
|
||||||
return await self.read_msg()
|
|
||||||
|
|
||||||
async def read_msg(self) -> bytes:
|
async def read_msg(self) -> bytes:
|
||||||
length = await self.next_msg_len()
|
length = await self.next_msg_len()
|
||||||
|
|
||||||
data = await read_exactly(self.read_closer, length)
|
data = await read_exactly(self.read_write_closer, length)
|
||||||
if len(data) < length:
|
if len(data) < length:
|
||||||
self.next_length = length - len(data)
|
self.next_length = length - len(data)
|
||||||
else:
|
else:
|
||||||
|
@ -68,16 +47,18 @@ class MsgIOReader(ReadCloser):
|
||||||
|
|
||||||
async def next_msg_len(self) -> int:
|
async def next_msg_len(self) -> int:
|
||||||
if self.next_length is None:
|
if self.next_length is None:
|
||||||
self.next_length = await read_length(self.read_closer)
|
self.next_length = await read_length(
|
||||||
|
self.read_write_closer, self.size_len_bytes
|
||||||
|
)
|
||||||
return self.next_length
|
return self.next_length
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
await self.read_closer.close()
|
await self.read_write_closer.close()
|
||||||
|
|
||||||
|
async def write_msg(self, msg: bytes) -> None:
|
||||||
|
data = encode_msg_with_length(msg, self.size_len_bytes)
|
||||||
|
await self.read_write_closer.write(data)
|
||||||
|
|
||||||
|
|
||||||
class MsgIOReadWriter(MsgIOReader, MsgIOWriter, Closer):
|
class MsgIOReadWriter(BaseMsgReadWriter):
|
||||||
def __init__(self, read_write_closer: ReadWriteCloser) -> None:
|
size_len_bytes = SIZE_SECIO_LEN_BYTES
|
||||||
super().__init__(cast(ReadCloser, read_write_closer))
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
await self.read_closer.close()
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user