rewrite tcp reader/writer interface
This commit is contained in:
parent
d4d345c3c7
commit
41ff884eef
32
libp2p/io/trio.py
Normal file
32
libp2p/io/trio.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import trio
|
||||
from trio import SocketStream
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.io.trio")
|
||||
|
||||
|
||||
class TrioReadWriteCloser(ReadWriteCloser):
|
||||
stream: SocketStream
|
||||
|
||||
def __init__(self, stream: SocketStream) -> None:
|
||||
self.stream = stream
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Raise `RawConnError` if the underlying connection breaks."""
|
||||
try:
|
||||
await self.stream.send_all(data)
|
||||
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
|
||||
raise IOException(error)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
max_bytes = n if n != -1 else None
|
||||
try:
|
||||
return await self.stream.receive_some(max_bytes)
|
||||
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
|
||||
raise IOException(error)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.stream.aclose()
|
|
@ -1,42 +1,25 @@
|
|||
import asyncio
|
||||
import trio
|
||||
|
||||
from libp2p.io.exceptions import IOException
|
||||
from .exceptions import RawConnError
|
||||
from .raw_connection_interface import IRawConnection
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
|
||||
|
||||
class RawConnection(IRawConnection):
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
read_write_closer: ReadWriteCloser
|
||||
is_initiator: bool
|
||||
|
||||
_drain_lock: asyncio.Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
initiator: bool,
|
||||
) -> None:
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None:
|
||||
self.read_write_closer = read_write_closer
|
||||
self.is_initiator = initiator
|
||||
|
||||
self._drain_lock = asyncio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Raise `RawConnError` if the underlying connection breaks."""
|
||||
try:
|
||||
self.writer.write(data)
|
||||
except ConnectionResetError as error:
|
||||
await self.read_write_closer.write(data)
|
||||
except IOException as error:
|
||||
raise RawConnError(error)
|
||||
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
|
||||
# Use a lock to serialize drain() calls. Circumvents this bug:
|
||||
# https://bugs.python.org/issue29930
|
||||
async with self._drain_lock:
|
||||
try:
|
||||
await self.writer.drain()
|
||||
except ConnectionResetError as error:
|
||||
raise RawConnError(error)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
|
@ -46,10 +29,9 @@ class RawConnection(IRawConnection):
|
|||
Raise `RawConnError` if the underlying connection breaks
|
||||
"""
|
||||
try:
|
||||
return await self.reader.read(n)
|
||||
except ConnectionResetError as error:
|
||||
return await self.read_write_closer.read(n)
|
||||
except IOException as error:
|
||||
raise RawConnError(error)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
await self.read_write_closer.close()
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
|||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.network.connection.net_connection_interface import INetConn
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStoreError
|
||||
|
@ -149,7 +150,7 @@ class Swarm(INetwork):
|
|||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
|
||||
async def listen(self, *multiaddrs: Multiaddr) -> bool:
|
||||
async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool:
|
||||
"""
|
||||
:param multiaddrs: one or many multiaddrs to start listening on
|
||||
:return: true if at least one success
|
||||
|
@ -167,15 +168,8 @@ class Swarm(INetwork):
|
|||
if str(maddr) in self.listeners:
|
||||
return True
|
||||
|
||||
async def conn_handler(
|
||||
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
connection_info = writer.get_extra_info("peername")
|
||||
# TODO make a proper multiaddr
|
||||
peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}"
|
||||
logger.debug("inbound connection at %s", peer_addr)
|
||||
# logger.debug("inbound connection request", peer_id)
|
||||
raw_conn = RawConnection(reader, writer, False)
|
||||
async def conn_handler(read_write_closer: ReadWriteCloser) -> None:
|
||||
raw_conn = RawConnection(read_write_closer, False)
|
||||
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||
# the conn and then mux the conn
|
||||
|
@ -185,14 +179,10 @@ class Swarm(INetwork):
|
|||
raw_conn, ID(b""), False
|
||||
)
|
||||
except SecurityUpgradeFailure as error:
|
||||
error_msg = "fail to upgrade security for peer at %s"
|
||||
logger.debug(error_msg, peer_addr)
|
||||
await raw_conn.close()
|
||||
raise SwarmException(error_msg % peer_addr) from error
|
||||
raise SwarmException() from error
|
||||
peer_id = secured_conn.get_remote_peer()
|
||||
|
||||
logger.debug("upgraded security for peer at %s", peer_addr)
|
||||
logger.debug("identified peer at %s as %s", peer_addr, peer_id)
|
||||
|
||||
try:
|
||||
muxed_conn = await self.upgrader.upgrade_connection(
|
||||
|
@ -213,7 +203,7 @@ class Swarm(INetwork):
|
|||
# Success
|
||||
listener = self.transport.create_listener(conn_handler)
|
||||
self.listeners[str(maddr)] = listener
|
||||
await listener.listen(maddr)
|
||||
await listener.listen(maddr, nursery)
|
||||
|
||||
# Call notifiers since event occurred
|
||||
self.notify_listen(maddr)
|
||||
|
|
|
@ -123,27 +123,10 @@ class Mplex(IMuxedConn):
|
|||
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
|
||||
return stream
|
||||
|
||||
async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
|
||||
task_coro = asyncio.ensure_future(coro)
|
||||
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
|
||||
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
|
||||
done, pending = await asyncio.wait(
|
||||
[task_coro, task_wait_closed, task_wait_shutting_down],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for fut in pending:
|
||||
fut.cancel()
|
||||
if task_wait_closed in done:
|
||||
raise MplexUnavailable("Mplex is closed")
|
||||
if task_wait_shutting_down in done:
|
||||
raise MplexUnavailable("Mplex is shutting down")
|
||||
return task_coro.result()
|
||||
|
||||
async def accept_stream(self) -> IMuxedStream:
|
||||
"""accepts a muxed stream opened by the other end."""
|
||||
return await self._wait_until_shutting_down_or_closed(
|
||||
self.new_stream_queue.get()
|
||||
)
|
||||
return await self.new_stream_queue.get()
|
||||
|
||||
async def send_message(
|
||||
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
||||
|
@ -163,9 +146,7 @@ class Mplex(IMuxedConn):
|
|||
|
||||
_bytes = header + encode_varint_prefixed(data)
|
||||
|
||||
return await self._wait_until_shutting_down_or_closed(
|
||||
self.write_to_stream(_bytes)
|
||||
)
|
||||
return await self.write_to_stream(_bytes)
|
||||
|
||||
async def write_to_stream(self, _bytes: bytes) -> int:
|
||||
"""
|
||||
|
@ -226,9 +207,7 @@ class Mplex(IMuxedConn):
|
|||
|
||||
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
|
||||
"""
|
||||
channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
|
||||
self.read_message()
|
||||
)
|
||||
channel_id, flag, message = await self.read_message()
|
||||
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
||||
|
||||
if flag == HeaderTags.NewStream.value:
|
||||
|
@ -258,9 +237,7 @@ class Mplex(IMuxedConn):
|
|||
f"received NewStream message for existing stream: {stream_id}"
|
||||
)
|
||||
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
||||
await self._wait_until_shutting_down_or_closed(
|
||||
self.new_stream_queue.put(mplex_stream)
|
||||
)
|
||||
await self.new_stream_queue.put(mplex_stream)
|
||||
|
||||
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
||||
async with self.streams_lock:
|
||||
|
@ -274,9 +251,7 @@ class Mplex(IMuxedConn):
|
|||
if stream.event_remote_closed.is_set():
|
||||
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
||||
return
|
||||
await self._wait_until_shutting_down_or_closed(
|
||||
stream.incoming_data.put(message)
|
||||
)
|
||||
await stream.incoming_data.put(message)
|
||||
|
||||
async def _handle_close(self, stream_id: StreamID) -> None:
|
||||
async with self.streams_lock:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import trio
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
@ -22,14 +23,14 @@ class MplexStream(IMuxedStream):
|
|||
read_deadline: int
|
||||
write_deadline: int
|
||||
|
||||
close_lock: asyncio.Lock
|
||||
close_lock: trio.Lock
|
||||
|
||||
# NOTE: `dataIn` is size of 8 in Go implementation.
|
||||
incoming_data: "asyncio.Queue[bytes]"
|
||||
|
||||
event_local_closed: asyncio.Event
|
||||
event_remote_closed: asyncio.Event
|
||||
event_reset: asyncio.Event
|
||||
event_local_closed: trio.Event
|
||||
event_remote_closed: trio.Event
|
||||
event_reset: trio.Event
|
||||
|
||||
_buf: bytearray
|
||||
|
||||
|
@ -45,10 +46,10 @@ class MplexStream(IMuxedStream):
|
|||
self.muxed_conn = muxed_conn
|
||||
self.read_deadline = None
|
||||
self.write_deadline = None
|
||||
self.event_local_closed = asyncio.Event()
|
||||
self.event_remote_closed = asyncio.Event()
|
||||
self.event_reset = asyncio.Event()
|
||||
self.close_lock = asyncio.Lock()
|
||||
self.event_local_closed = trio.Event()
|
||||
self.event_remote_closed = trio.Event()
|
||||
self.event_reset = trio.Event()
|
||||
self.close_lock = trio.Lock()
|
||||
self.incoming_data = asyncio.Queue()
|
||||
self._buf = bytearray()
|
||||
|
||||
|
@ -199,10 +200,11 @@ class MplexStream(IMuxedStream):
|
|||
if self.is_initiator
|
||||
else HeaderTags.ResetReceiver
|
||||
)
|
||||
asyncio.ensure_future(
|
||||
self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
self.muxed_conn.send_message, flag, None, self.stream_id
|
||||
)
|
||||
await trio.sleep(0)
|
||||
|
||||
self.event_local_closed.set()
|
||||
self.event_remote_closed.set()
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import trio
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
import multiaddr
|
||||
|
@ -37,12 +38,12 @@ async def connect(node1: IHost, node2: IHost) -> None:
|
|||
|
||||
|
||||
async def set_up_nodes_by_transport_opt(
|
||||
transport_opt_list: Sequence[Sequence[str]]
|
||||
transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery
|
||||
) -> Tuple[BasicHost, ...]:
|
||||
nodes_list = []
|
||||
for transport_opt in transport_opt_list:
|
||||
node = await new_node(transport_opt=transport_opt)
|
||||
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]))
|
||||
node = new_node(transport_opt=transport_opt)
|
||||
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery)
|
||||
nodes_list.append(node)
|
||||
return tuple(nodes_list)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import trio
|
||||
from socket import socket
|
||||
from typing import List
|
||||
|
||||
|
@ -10,6 +11,10 @@ from libp2p.transport.exceptions import OpenConnectionError
|
|||
from libp2p.transport.listener_interface import IListener
|
||||
from libp2p.transport.transport_interface import ITransport
|
||||
from libp2p.transport.typing import THandler
|
||||
from libp2p.io.trio import TrioReadWriteCloser
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.tcp")
|
||||
|
||||
|
||||
class TCPListener(IListener):
|
||||
|
@ -21,20 +26,38 @@ class TCPListener(IListener):
|
|||
self.server = None
|
||||
self.handler = handler_function
|
||||
|
||||
async def listen(self, maddr: Multiaddr) -> bool:
|
||||
async def listen(self, maddr: Multiaddr, nursery) -> bool:
|
||||
"""
|
||||
put listener in listening mode and wait for incoming connections.
|
||||
|
||||
:param maddr: maddr of peer
|
||||
:return: return True if successful
|
||||
"""
|
||||
self.server = await asyncio.start_server(
|
||||
self.handler,
|
||||
maddr.value_for_protocol("ip4"),
|
||||
maddr.value_for_protocol("tcp"),
|
||||
|
||||
async def serve_tcp(handler, port, host, task_status=None):
|
||||
logger.debug("serve_tcp %s %s", host, port)
|
||||
await trio.serve_tcp(handler, port, host=host, task_status=task_status)
|
||||
|
||||
async def handler(stream):
|
||||
read_write_closer = TrioReadWriteCloser(stream)
|
||||
await self.handler(read_write_closer)
|
||||
|
||||
listeners = await nursery.start(
|
||||
serve_tcp,
|
||||
*(
|
||||
handler,
|
||||
int(maddr.value_for_protocol("tcp")),
|
||||
maddr.value_for_protocol("ip4"),
|
||||
),
|
||||
)
|
||||
socket = self.server.sockets[0]
|
||||
# self.server = await asyncio.start_server(
|
||||
# self.handler,
|
||||
# maddr.value_for_protocol("ip4"),
|
||||
# maddr.value_for_protocol("tcp"),
|
||||
# )
|
||||
socket = listeners[0].socket
|
||||
self.multiaddrs.append(_multiaddr_from_socket(socket))
|
||||
logger.debug("Multiaddrs %s", self.multiaddrs)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -69,12 +92,10 @@ class TCP(ITransport):
|
|||
self.host = maddr.value_for_protocol("ip4")
|
||||
self.port = int(maddr.value_for_protocol("tcp"))
|
||||
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(self.host, self.port)
|
||||
except (ConnectionAbortedError, ConnectionRefusedError) as error:
|
||||
raise OpenConnectionError(error)
|
||||
stream = await trio.open_tcp_stream(self.host, self.port)
|
||||
read_write_closer = TrioReadWriteCloser(stream)
|
||||
|
||||
return RawConnection(reader, writer, True)
|
||||
return RawConnection(read_write_closer, True)
|
||||
|
||||
def create_listener(self, handler_function: THandler) -> TCPListener:
|
||||
"""
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
from asyncio import StreamReader, StreamWriter
|
||||
|
||||
from typing import Awaitable, Callable, Mapping, Type
|
||||
|
||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||
from libp2p.stream_muxer.abc import IMuxedConn
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
|
||||
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
||||
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
|
||||
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
|
||||
TMuxerClass = Type[IMuxedConn]
|
||||
TMuxerOptions = Mapping[TProtocol, TMuxerClass]
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import trio
|
||||
import multiaddr
|
||||
import pytest
|
||||
|
||||
|
@ -6,10 +7,10 @@ from libp2p.tools.constants import MAX_READ_LEN
|
|||
from libp2p.tools.utils import set_up_nodes_by_transport_opt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_messages():
|
||||
@pytest.mark.trio
|
||||
async def test_simple_messages(nursery):
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list, nursery)
|
||||
|
||||
async def stream_handler(stream):
|
||||
while True:
|
||||
|
@ -23,6 +24,7 @@ async def test_simple_messages():
|
|||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
|
||||
|
||||
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
|
|
Loading…
Reference in New Issue
Block a user