Use RawConnection.read

Instead of accessing its reader and writer directly.

TODO: considering add `ReaderWriterCloser` interface and let connection
and stream inherit from it.
This commit is contained in:
mhchia 2019-08-20 18:09:36 +08:00
parent 0b466ddc86
commit ef476e555b
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
9 changed files with 22 additions and 31 deletions

View File

@ -39,9 +39,12 @@ class RawConnection(IRawConnection):
async with self._drain_lock:
await self.writer.drain()
async def read(self) -> bytes:
line = await self.reader.readline()
return line.rstrip(b"\n")
async def read(self, n: int = -1) -> bytes:
"""
Read up to ``n`` bytes from the underlying stream.
This call is delegated directly to the underlying ``self.reader``.
"""
return await self.reader.read(n)
def close(self) -> None:
self.writer.close()

View File

@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
import asyncio
class IRawConnection(ABC):
@ -9,17 +8,12 @@ class IRawConnection(ABC):
initiator: bool
# TODO: reader and writer shouldn't be exposed.
# Need better API for the consumers
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
@abstractmethod
async def write(self, data: bytes) -> None:
pass
@abstractmethod
async def read(self) -> bytes:
async def read(self, n: int = -1) -> bytes:
pass
@abstractmethod

View File

@ -16,7 +16,7 @@ class RawConnectionCommunicator(IMultiselectCommunicator):
await self.conn.write(msg_bytes)
async def read(self) -> str:
data = await read_delim(self.conn.reader)
data = await read_delim(self.conn)
return data.decode()

View File

@ -23,8 +23,6 @@ class BaseSession(ISecureConn):
self.remote_permanent_pubkey = None
self.initiator = self.conn.initiator
self.writer = self.conn.writer
self.reader = self.conn.reader
# TODO clean up how this is passed around?
def next_stream_id(self) -> int:
@ -33,8 +31,8 @@ class BaseSession(ISecureConn):
async def write(self, data: bytes) -> None:
await self.conn.write(data)
async def read(self) -> bytes:
return await self.conn.read()
async def read(self, n: int = -1) -> bytes:
return await self.conn.read(n)
def close(self) -> None:
self.conn.close()

View File

@ -23,7 +23,7 @@ class InsecureSession(BaseSession):
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
await self.write(encoded_msg_bytes)
msg_bytes_other_side = await read_fixedint_prefixed(self.reader)
msg_bytes_other_side = await read_fixedint_prefixed(self.conn)
msg_other_side = plaintext_pb2.Exchange()
msg_other_side.ParseFromString(msg_bytes_other_side)

View File

@ -87,10 +87,6 @@ class SecurityMultistream(ABC):
:param initiator: true if we are the initiator, false otherwise
:return: selected secure transport
"""
# TODO: Is conn acceptable to multiselect/multiselect_client
# instead of stream? In go repo, they pass in a raw conn
# (https://raw.githubusercontent.com/libp2p/go-conn-security-multistream/master/ssms.go)
protocol: TProtocol
communicator = RawConnectionCommunicator(conn)
if initiator:

View File

@ -7,6 +7,7 @@ from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import InsecureSession
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.transport.exceptions import SecurityUpgradeFailure
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
class SimpleSecurityTransport(BaseSecureTransport):
@ -22,8 +23,8 @@ class SimpleSecurityTransport(BaseSecureTransport):
for an inbound connection (i.e. we are not the initiator)
:return: secure connection object (that implements secure_conn_interface)
"""
await conn.write(self.key_phrase.encode())
incoming = (await conn.read()).decode()
await conn.write(encode_fixedint_prefixed(self.key_phrase.encode()))
incoming = (await read_fixedint_prefixed(conn)).decode()
if incoming != self.key_phrase:
raise SecurityUpgradeFailure(
@ -48,8 +49,8 @@ class SimpleSecurityTransport(BaseSecureTransport):
for an inbound connection (i.e. we are the initiator)
:return: secure connection object (that implements secure_conn_interface)
"""
await conn.write(self.key_phrase.encode())
incoming = (await conn.read()).decode()
await conn.write(encode_fixedint_prefixed(self.key_phrase.encode()))
incoming = (await read_fixedint_prefixed(conn)).decode()
# Force context switch, as this security transport is built for testing locally
# in a single event loop

View File

@ -188,11 +188,9 @@ class Mplex(IMuxedConn):
# loop in handle_incoming
timeout = 0.1
try:
header = await decode_uvarint_from_stream(self.conn.reader, timeout)
length = await decode_uvarint_from_stream(self.conn.reader, timeout)
message = await asyncio.wait_for(
self.conn.reader.read(length), timeout=timeout
)
header = await decode_uvarint_from_stream(self.conn, timeout)
length = await decode_uvarint_from_stream(self.conn, timeout)
message = await asyncio.wait_for(self.conn.read(length), timeout=timeout)
except asyncio.TimeoutError:
return None, None, None

View File

@ -1,6 +1,7 @@
import asyncio
from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union
from libp2p.network.connection.raw_connection_interface import IRawConnection
if TYPE_CHECKING:
from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401
from libp2p.stream_muxer.abc import IMuxedStream # noqa: F401
@ -9,4 +10,4 @@ TProtocol = NewType("TProtocol", str)
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
StreamReader = Union["IMuxedStream", asyncio.StreamReader]
StreamReader = Union["IMuxedStream", IRawConnection]