Use RawConnection.read
Instead of accessing its reader and writer directly. TODO: considering add `ReaderWriterCloser` interface and let connection and stream inherit from it.
This commit is contained in:
parent
0b466ddc86
commit
ef476e555b
@ -39,9 +39,12 @@ class RawConnection(IRawConnection):
|
||||
async with self._drain_lock:
|
||||
await self.writer.drain()
|
||||
|
||||
async def read(self) -> bytes:
|
||||
line = await self.reader.readline()
|
||||
return line.rstrip(b"\n")
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
Read up to ``n`` bytes from the underlying stream.
|
||||
This call is delegated directly to the underlying ``self.reader``.
|
||||
"""
|
||||
return await self.reader.read(n)
|
||||
|
||||
def close(self) -> None:
|
||||
self.writer.close()
|
||||
|
@ -1,5 +1,4 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
|
||||
|
||||
class IRawConnection(ABC):
|
||||
@ -9,17 +8,12 @@ class IRawConnection(ABC):
|
||||
|
||||
initiator: bool
|
||||
|
||||
# TODO: reader and writer shouldn't be exposed.
|
||||
# Need better API for the consumers
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
|
||||
@abstractmethod
|
||||
async def write(self, data: bytes) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def read(self) -> bytes:
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@ -16,7 +16,7 @@ class RawConnectionCommunicator(IMultiselectCommunicator):
|
||||
await self.conn.write(msg_bytes)
|
||||
|
||||
async def read(self) -> str:
|
||||
data = await read_delim(self.conn.reader)
|
||||
data = await read_delim(self.conn)
|
||||
return data.decode()
|
||||
|
||||
|
||||
|
@ -23,8 +23,6 @@ class BaseSession(ISecureConn):
|
||||
self.remote_permanent_pubkey = None
|
||||
|
||||
self.initiator = self.conn.initiator
|
||||
self.writer = self.conn.writer
|
||||
self.reader = self.conn.reader
|
||||
|
||||
# TODO clean up how this is passed around?
|
||||
def next_stream_id(self) -> int:
|
||||
@ -33,8 +31,8 @@ class BaseSession(ISecureConn):
|
||||
async def write(self, data: bytes) -> None:
|
||||
await self.conn.write(data)
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return await self.conn.read()
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
return await self.conn.read(n)
|
||||
|
||||
def close(self) -> None:
|
||||
self.conn.close()
|
||||
|
@ -23,7 +23,7 @@ class InsecureSession(BaseSession):
|
||||
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
|
||||
await self.write(encoded_msg_bytes)
|
||||
|
||||
msg_bytes_other_side = await read_fixedint_prefixed(self.reader)
|
||||
msg_bytes_other_side = await read_fixedint_prefixed(self.conn)
|
||||
msg_other_side = plaintext_pb2.Exchange()
|
||||
msg_other_side.ParseFromString(msg_bytes_other_side)
|
||||
|
||||
|
@ -87,10 +87,6 @@ class SecurityMultistream(ABC):
|
||||
:param initiator: true if we are the initiator, false otherwise
|
||||
:return: selected secure transport
|
||||
"""
|
||||
# TODO: Is conn acceptable to multiselect/multiselect_client
|
||||
# instead of stream? In go repo, they pass in a raw conn
|
||||
# (https://raw.githubusercontent.com/libp2p/go-conn-security-multistream/master/ssms.go)
|
||||
|
||||
protocol: TProtocol
|
||||
communicator = RawConnectionCommunicator(conn)
|
||||
if initiator:
|
||||
|
@ -7,6 +7,7 @@ from libp2p.security.base_transport import BaseSecureTransport
|
||||
from libp2p.security.insecure.transport import InsecureSession
|
||||
from libp2p.security.secure_conn_interface import ISecureConn
|
||||
from libp2p.transport.exceptions import SecurityUpgradeFailure
|
||||
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
|
||||
|
||||
|
||||
class SimpleSecurityTransport(BaseSecureTransport):
|
||||
@ -22,8 +23,8 @@ class SimpleSecurityTransport(BaseSecureTransport):
|
||||
for an inbound connection (i.e. we are not the initiator)
|
||||
:return: secure connection object (that implements secure_conn_interface)
|
||||
"""
|
||||
await conn.write(self.key_phrase.encode())
|
||||
incoming = (await conn.read()).decode()
|
||||
await conn.write(encode_fixedint_prefixed(self.key_phrase.encode()))
|
||||
incoming = (await read_fixedint_prefixed(conn)).decode()
|
||||
|
||||
if incoming != self.key_phrase:
|
||||
raise SecurityUpgradeFailure(
|
||||
@ -48,8 +49,8 @@ class SimpleSecurityTransport(BaseSecureTransport):
|
||||
for an inbound connection (i.e. we are the initiator)
|
||||
:return: secure connection object (that implements secure_conn_interface)
|
||||
"""
|
||||
await conn.write(self.key_phrase.encode())
|
||||
incoming = (await conn.read()).decode()
|
||||
await conn.write(encode_fixedint_prefixed(self.key_phrase.encode()))
|
||||
incoming = (await read_fixedint_prefixed(conn)).decode()
|
||||
|
||||
# Force context switch, as this security transport is built for testing locally
|
||||
# in a single event loop
|
||||
|
@ -188,11 +188,9 @@ class Mplex(IMuxedConn):
|
||||
# loop in handle_incoming
|
||||
timeout = 0.1
|
||||
try:
|
||||
header = await decode_uvarint_from_stream(self.conn.reader, timeout)
|
||||
length = await decode_uvarint_from_stream(self.conn.reader, timeout)
|
||||
message = await asyncio.wait_for(
|
||||
self.conn.reader.read(length), timeout=timeout
|
||||
)
|
||||
header = await decode_uvarint_from_stream(self.conn, timeout)
|
||||
length = await decode_uvarint_from_stream(self.conn, timeout)
|
||||
message = await asyncio.wait_for(self.conn.read(length), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return None, None, None
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union
|
||||
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401
|
||||
from libp2p.stream_muxer.abc import IMuxedStream # noqa: F401
|
||||
@ -9,4 +10,4 @@ TProtocol = NewType("TProtocol", str)
|
||||
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
|
||||
|
||||
|
||||
StreamReader = Union["IMuxedStream", asyncio.StreamReader]
|
||||
StreamReader = Union["IMuxedStream", IRawConnection]
|
||||
|
Loading…
x
Reference in New Issue
Block a user