Use RawConnection.read
Instead of accessing its reader and writer directly. TODO: considering add `ReaderWriterCloser` interface and let connection and stream inherit from it.
This commit is contained in:
parent
0b466ddc86
commit
ef476e555b
@ -39,9 +39,12 @@ class RawConnection(IRawConnection):
|
|||||||
async with self._drain_lock:
|
async with self._drain_lock:
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
|
|
||||||
async def read(self) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
line = await self.reader.readline()
|
"""
|
||||||
return line.rstrip(b"\n")
|
Read up to ``n`` bytes from the underlying stream.
|
||||||
|
This call is delegated directly to the underlying ``self.reader``.
|
||||||
|
"""
|
||||||
|
return await self.reader.read(n)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
class IRawConnection(ABC):
|
class IRawConnection(ABC):
|
||||||
@ -9,17 +8,12 @@ class IRawConnection(ABC):
|
|||||||
|
|
||||||
initiator: bool
|
initiator: bool
|
||||||
|
|
||||||
# TODO: reader and writer shouldn't be exposed.
|
|
||||||
# Need better API for the consumers
|
|
||||||
reader: asyncio.StreamReader
|
|
||||||
writer: asyncio.StreamWriter
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def read(self) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -16,7 +16,7 @@ class RawConnectionCommunicator(IMultiselectCommunicator):
|
|||||||
await self.conn.write(msg_bytes)
|
await self.conn.write(msg_bytes)
|
||||||
|
|
||||||
async def read(self) -> str:
|
async def read(self) -> str:
|
||||||
data = await read_delim(self.conn.reader)
|
data = await read_delim(self.conn)
|
||||||
return data.decode()
|
return data.decode()
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,8 +23,6 @@ class BaseSession(ISecureConn):
|
|||||||
self.remote_permanent_pubkey = None
|
self.remote_permanent_pubkey = None
|
||||||
|
|
||||||
self.initiator = self.conn.initiator
|
self.initiator = self.conn.initiator
|
||||||
self.writer = self.conn.writer
|
|
||||||
self.reader = self.conn.reader
|
|
||||||
|
|
||||||
# TODO clean up how this is passed around?
|
# TODO clean up how this is passed around?
|
||||||
def next_stream_id(self) -> int:
|
def next_stream_id(self) -> int:
|
||||||
@ -33,8 +31,8 @@ class BaseSession(ISecureConn):
|
|||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
await self.conn.write(data)
|
await self.conn.write(data)
|
||||||
|
|
||||||
async def read(self) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
return await self.conn.read()
|
return await self.conn.read(n)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
@ -23,7 +23,7 @@ class InsecureSession(BaseSession):
|
|||||||
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
|
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
|
||||||
await self.write(encoded_msg_bytes)
|
await self.write(encoded_msg_bytes)
|
||||||
|
|
||||||
msg_bytes_other_side = await read_fixedint_prefixed(self.reader)
|
msg_bytes_other_side = await read_fixedint_prefixed(self.conn)
|
||||||
msg_other_side = plaintext_pb2.Exchange()
|
msg_other_side = plaintext_pb2.Exchange()
|
||||||
msg_other_side.ParseFromString(msg_bytes_other_side)
|
msg_other_side.ParseFromString(msg_bytes_other_side)
|
||||||
|
|
||||||
|
@ -87,10 +87,6 @@ class SecurityMultistream(ABC):
|
|||||||
:param initiator: true if we are the initiator, false otherwise
|
:param initiator: true if we are the initiator, false otherwise
|
||||||
:return: selected secure transport
|
:return: selected secure transport
|
||||||
"""
|
"""
|
||||||
# TODO: Is conn acceptable to multiselect/multiselect_client
|
|
||||||
# instead of stream? In go repo, they pass in a raw conn
|
|
||||||
# (https://raw.githubusercontent.com/libp2p/go-conn-security-multistream/master/ssms.go)
|
|
||||||
|
|
||||||
protocol: TProtocol
|
protocol: TProtocol
|
||||||
communicator = RawConnectionCommunicator(conn)
|
communicator = RawConnectionCommunicator(conn)
|
||||||
if initiator:
|
if initiator:
|
||||||
|
@ -7,6 +7,7 @@ from libp2p.security.base_transport import BaseSecureTransport
|
|||||||
from libp2p.security.insecure.transport import InsecureSession
|
from libp2p.security.insecure.transport import InsecureSession
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
from libp2p.transport.exceptions import SecurityUpgradeFailure
|
from libp2p.transport.exceptions import SecurityUpgradeFailure
|
||||||
|
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
|
||||||
|
|
||||||
|
|
||||||
class SimpleSecurityTransport(BaseSecureTransport):
|
class SimpleSecurityTransport(BaseSecureTransport):
|
||||||
@ -22,8 +23,8 @@ class SimpleSecurityTransport(BaseSecureTransport):
|
|||||||
for an inbound connection (i.e. we are not the initiator)
|
for an inbound connection (i.e. we are not the initiator)
|
||||||
:return: secure connection object (that implements secure_conn_interface)
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
"""
|
"""
|
||||||
await conn.write(self.key_phrase.encode())
|
await conn.write(encode_fixedint_prefixed(self.key_phrase.encode()))
|
||||||
incoming = (await conn.read()).decode()
|
incoming = (await read_fixedint_prefixed(conn)).decode()
|
||||||
|
|
||||||
if incoming != self.key_phrase:
|
if incoming != self.key_phrase:
|
||||||
raise SecurityUpgradeFailure(
|
raise SecurityUpgradeFailure(
|
||||||
@ -48,8 +49,8 @@ class SimpleSecurityTransport(BaseSecureTransport):
|
|||||||
for an inbound connection (i.e. we are the initiator)
|
for an inbound connection (i.e. we are the initiator)
|
||||||
:return: secure connection object (that implements secure_conn_interface)
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
"""
|
"""
|
||||||
await conn.write(self.key_phrase.encode())
|
await conn.write(encode_fixedint_prefixed(self.key_phrase.encode()))
|
||||||
incoming = (await conn.read()).decode()
|
incoming = (await read_fixedint_prefixed(conn)).decode()
|
||||||
|
|
||||||
# Force context switch, as this security transport is built for testing locally
|
# Force context switch, as this security transport is built for testing locally
|
||||||
# in a single event loop
|
# in a single event loop
|
||||||
|
@ -188,11 +188,9 @@ class Mplex(IMuxedConn):
|
|||||||
# loop in handle_incoming
|
# loop in handle_incoming
|
||||||
timeout = 0.1
|
timeout = 0.1
|
||||||
try:
|
try:
|
||||||
header = await decode_uvarint_from_stream(self.conn.reader, timeout)
|
header = await decode_uvarint_from_stream(self.conn, timeout)
|
||||||
length = await decode_uvarint_from_stream(self.conn.reader, timeout)
|
length = await decode_uvarint_from_stream(self.conn, timeout)
|
||||||
message = await asyncio.wait_for(
|
message = await asyncio.wait_for(self.conn.read(length), timeout=timeout)
|
||||||
self.conn.reader.read(length), timeout=timeout
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union
|
from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union
|
||||||
|
|
||||||
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401
|
from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401
|
||||||
from libp2p.stream_muxer.abc import IMuxedStream # noqa: F401
|
from libp2p.stream_muxer.abc import IMuxedStream # noqa: F401
|
||||||
@ -9,4 +10,4 @@ TProtocol = NewType("TProtocol", str)
|
|||||||
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
|
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
StreamReader = Union["IMuxedStream", asyncio.StreamReader]
|
StreamReader = Union["IMuxedStream", IRawConnection]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user