rewrite tcp reader/writer interface

Chih Cheng Liang 2019-11-19 14:01:12 +08:00 committed by mhchia
parent d4d345c3c7
commit 41ff884eef
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
9 changed files with 112 additions and 106 deletions

libp2p/io/trio.py (new file)

@@ -0,0 +1,32 @@
+import trio
+from trio import SocketStream
+
+from libp2p.io.abc import ReadWriteCloser
+from libp2p.io.exceptions import IOException
+
+import logging
+logger = logging.getLogger("libp2p.io.trio")
+
+
+class TrioReadWriteCloser(ReadWriteCloser):
+    stream: SocketStream
+
+    def __init__(self, stream: SocketStream) -> None:
+        self.stream = stream
+
+    async def write(self, data: bytes) -> None:
+        """Raise `IOException` if the underlying connection breaks."""
+        try:
+            await self.stream.send_all(data)
+        except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
+            raise IOException(error)
+
+    async def read(self, n: int = -1) -> bytes:
+        max_bytes = n if n != -1 else None
+        try:
+            return await self.stream.receive_some(max_bytes)
+        except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
+            raise IOException(error)
+
+    async def close(self) -> None:
+        await self.stream.aclose()
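
For illustration, a minimal sketch of driving the new wrapper directly. It assumes a TCP service is already listening on 127.0.0.1:8000; the address, port, and the `main` helper are hypothetical, while `TrioReadWriteCloser` and `trio.open_tcp_stream` are the pieces shown above.

import trio

from libp2p.io.trio import TrioReadWriteCloser


async def main() -> None:
    # Hypothetical endpoint; anything speaking TCP on 127.0.0.1:8000 will do.
    stream = await trio.open_tcp_stream("127.0.0.1", 8000)
    rwc = TrioReadWriteCloser(stream)
    await rwc.write(b"ping")
    print(await rwc.read(1024))  # read at most 1024 bytes
    await rwc.close()


trio.run(main)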


@@ -1,41 +1,24 @@
-import asyncio
+import trio
+
+from libp2p.io.exceptions import IOException
 
 from .exceptions import RawConnError
 from .raw_connection_interface import IRawConnection
+from libp2p.io.abc import ReadWriteCloser
 
 
 class RawConnection(IRawConnection):
-    reader: asyncio.StreamReader
-    writer: asyncio.StreamWriter
+    read_write_closer: ReadWriteCloser
     is_initiator: bool
-    _drain_lock: asyncio.Lock
 
-    def __init__(
-        self,
-        reader: asyncio.StreamReader,
-        writer: asyncio.StreamWriter,
-        initiator: bool,
-    ) -> None:
-        self.reader = reader
-        self.writer = writer
+    def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None:
+        self.read_write_closer = read_write_closer
         self.is_initiator = initiator
-        self._drain_lock = asyncio.Lock()
 
     async def write(self, data: bytes) -> None:
         """Raise `RawConnError` if the underlying connection breaks."""
         try:
-            self.writer.write(data)
-        except ConnectionResetError as error:
-            raise RawConnError(error)
-        # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102  # noqa: E501
-        # Use a lock to serialize drain() calls. Circumvents this bug:
-        # https://bugs.python.org/issue29930
-        async with self._drain_lock:
-            try:
-                await self.writer.drain()
-            except ConnectionResetError as error:
-                raise RawConnError(error)
+            await self.read_write_closer.write(data)
+        except IOException as error:
+            raise RawConnError(error)
 
     async def read(self, n: int = -1) -> bytes:
@@ -46,10 +29,9 @@ class RawConnection(IRawConnection):
         Raise `RawConnError` if the underlying connection breaks
         """
         try:
-            return await self.reader.read(n)
-        except ConnectionResetError as error:
+            return await self.read_write_closer.read(n)
+        except IOException as error:
             raise RawConnError(error)
 
     async def close(self) -> None:
-        self.writer.close()
-        await self.writer.wait_closed()
+        await self.read_write_closer.close()
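
To illustrate the narrowed dependency, a hedged sketch: `RawConnection` now only needs an object satisfying `ReadWriteCloser`, so an in-memory stub is enough to construct one. `LoopbackReadWriteCloser` is hypothetical, and the import path for `RawConnection` is assumed from the package layout.

from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.raw_connection import RawConnection  # assumed path


class LoopbackReadWriteCloser(ReadWriteCloser):
    """Hypothetical in-memory stand-in: write() buffers, read() drains."""

    def __init__(self) -> None:
        self._buf = bytearray()

    async def write(self, data: bytes) -> None:
        self._buf.extend(data)

    async def read(self, n: int = -1) -> bytes:
        data = bytes(self._buf if n == -1 else self._buf[:n])
        del self._buf[: len(data)]
        return data

    async def close(self) -> None:
        self._buf.clear()


conn = RawConnection(LoopbackReadWriteCloser(), True)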


@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
 
 from multiaddr import Multiaddr
 
+from libp2p.io.abc import ReadWriteCloser
 from libp2p.network.connection.net_connection_interface import INetConn
 from libp2p.peer.id import ID
 from libp2p.peer.peerstore import PeerStoreError
@@ -149,7 +150,7 @@ class Swarm(INetwork):
         logger.debug("successfully opened a stream to peer %s", peer_id)
         return net_stream
 
-    async def listen(self, *multiaddrs: Multiaddr) -> bool:
+    async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool:
         """
         :param multiaddrs: one or many multiaddrs to start listening on
         :return: true if at least one success
@@ -167,15 +168,8 @@ class Swarm(INetwork):
             if str(maddr) in self.listeners:
                 return True
 
-            async def conn_handler(
-                reader: asyncio.StreamReader, writer: asyncio.StreamWriter
-            ) -> None:
-                connection_info = writer.get_extra_info("peername")
-                # TODO make a proper multiaddr
-                peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}"
-                logger.debug("inbound connection at %s", peer_addr)
-                # logger.debug("inbound connection request", peer_id)
-                raw_conn = RawConnection(reader, writer, False)
+            async def conn_handler(read_write_closer: ReadWriteCloser) -> None:
+                raw_conn = RawConnection(read_write_closer, False)
 
                 # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
                 # the conn and then mux the conn
@@ -185,14 +179,10 @@ class Swarm(INetwork):
                        raw_conn, ID(b""), False
                    )
                except SecurityUpgradeFailure as error:
-                    error_msg = "fail to upgrade security for peer at %s"
-                    logger.debug(error_msg, peer_addr)
                    await raw_conn.close()
-                    raise SwarmException(error_msg % peer_addr) from error
+                    raise SwarmException() from error
 
                peer_id = secured_conn.get_remote_peer()
-                logger.debug("upgraded security for peer at %s", peer_addr)
-                logger.debug("identified peer at %s as %s", peer_addr, peer_id)
 
                try:
                    muxed_conn = await self.upgrader.upgrade_connection(
@@ -213,7 +203,7 @@ class Swarm(INetwork):
             # Success
             listener = self.transport.create_listener(conn_handler)
             self.listeners[str(maddr)] = listener
-            await listener.listen(maddr)
+            await listener.listen(maddr, nursery)
 
             # Call notifiers since event occurred
             self.notify_listen(maddr)
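
A hedged sketch of the new calling convention for `Swarm.listen`: the caller supplies the trio nursery that will own the listener tasks. Only the nursery-passing pattern comes from this diff; `start_listening`, the address, and the concrete `swarm` instance are placeholders.

import multiaddr
import trio


async def start_listening(swarm) -> None:
    # `swarm` is a placeholder for an already-constructed Swarm/INetwork instance.
    async with trio.open_nursery() as nursery:
        await swarm.listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"), nursery=nursery)
        await trio.sleep_forever()  # keep the nursery, and thus the listeners, alive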


@@ -123,27 +123,10 @@ class Mplex(IMuxedConn):
         await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
         return stream
 
-    async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
-        task_coro = asyncio.ensure_future(coro)
-        task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
-        task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
-        done, pending = await asyncio.wait(
-            [task_coro, task_wait_closed, task_wait_shutting_down],
-            return_when=asyncio.FIRST_COMPLETED,
-        )
-        for fut in pending:
-            fut.cancel()
-        if task_wait_closed in done:
-            raise MplexUnavailable("Mplex is closed")
-        if task_wait_shutting_down in done:
-            raise MplexUnavailable("Mplex is shutting down")
-        return task_coro.result()
-
     async def accept_stream(self) -> IMuxedStream:
         """accepts a muxed stream opened by the other end."""
-        return await self._wait_until_shutting_down_or_closed(
-            self.new_stream_queue.get()
-        )
+        return await self.new_stream_queue.get()
 
     async def send_message(
         self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@@ -163,9 +146,7 @@ class Mplex(IMuxedConn):
         _bytes = header + encode_varint_prefixed(data)
 
-        return await self._wait_until_shutting_down_or_closed(
-            self.write_to_stream(_bytes)
-        )
+        return await self.write_to_stream(_bytes)
 
     async def write_to_stream(self, _bytes: bytes) -> int:
         """
@@ -226,9 +207,7 @@ class Mplex(IMuxedConn):
         :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
         """
-        channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
-            self.read_message()
-        )
+        channel_id, flag, message = await self.read_message()
         stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
 
         if flag == HeaderTags.NewStream.value:
@@ -258,9 +237,7 @@ class Mplex(IMuxedConn):
                    f"received NewStream message for existing stream: {stream_id}"
                )
        mplex_stream = await self._initialize_stream(stream_id, message.decode())
-        await self._wait_until_shutting_down_or_closed(
-            self.new_stream_queue.put(mplex_stream)
-        )
+        await self.new_stream_queue.put(mplex_stream)
 
     async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
         async with self.streams_lock:
@@ -274,9 +251,7 @@ class Mplex(IMuxedConn):
            if stream.event_remote_closed.is_set():
                # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)"  # noqa: E501
                return
-            await self._wait_until_shutting_down_or_closed(
-                stream.incoming_data.put(message)
-            )
+            await stream.incoming_data.put(message)
 
     async def _handle_close(self, stream_id: StreamID) -> None:
         async with self.streams_lock:


@@ -1,3 +1,4 @@
+import trio
 import asyncio
 from typing import TYPE_CHECKING
 
@@ -22,14 +23,14 @@ class MplexStream(IMuxedStream):
     read_deadline: int
     write_deadline: int
 
-    close_lock: asyncio.Lock
+    close_lock: trio.Lock
 
     # NOTE: `dataIn` is size of 8 in Go implementation.
     incoming_data: "asyncio.Queue[bytes]"
 
-    event_local_closed: asyncio.Event
-    event_remote_closed: asyncio.Event
-    event_reset: asyncio.Event
+    event_local_closed: trio.Event
+    event_remote_closed: trio.Event
+    event_reset: trio.Event
 
     _buf: bytearray
 
@@ -45,10 +46,10 @@ class MplexStream(IMuxedStream):
         self.muxed_conn = muxed_conn
         self.read_deadline = None
         self.write_deadline = None
-        self.event_local_closed = asyncio.Event()
-        self.event_remote_closed = asyncio.Event()
-        self.event_reset = asyncio.Event()
-        self.close_lock = asyncio.Lock()
+        self.event_local_closed = trio.Event()
+        self.event_remote_closed = trio.Event()
+        self.event_reset = trio.Event()
+        self.close_lock = trio.Lock()
         self.incoming_data = asyncio.Queue()
         self._buf = bytearray()
 
@@ -199,10 +200,11 @@ class MplexStream(IMuxedStream):
                    if self.is_initiator
                    else HeaderTags.ResetReceiver
                )
-                asyncio.ensure_future(
-                    self.muxed_conn.send_message(flag, None, self.stream_id)
-                )
-                await asyncio.sleep(0)
+                async with trio.open_nursery() as nursery:
+                    nursery.start_soon(
+                        self.muxed_conn.send_message, flag, None, self.stream_id
+                    )
+                await trio.sleep(0)
 
            self.event_local_closed.set()
            self.event_remote_closed.set()


@@ -1,3 +1,4 @@
+import trio
 from typing import List, Sequence, Tuple
 
 import multiaddr
 
@@ -37,12 +38,12 @@ async def connect(node1: IHost, node2: IHost) -> None:
 
 async def set_up_nodes_by_transport_opt(
-    transport_opt_list: Sequence[Sequence[str]]
+    transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery
 ) -> Tuple[BasicHost, ...]:
     nodes_list = []
     for transport_opt in transport_opt_list:
-        node = await new_node(transport_opt=transport_opt)
-        await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]))
+        node = new_node(transport_opt=transport_opt)
+        await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery)
         nodes_list.append(node)
     return tuple(nodes_list)


@@ -1,4 +1,5 @@
 import asyncio
+import trio
 from socket import socket
 from typing import List
 
@@ -10,6 +11,10 @@ from libp2p.transport.exceptions import OpenConnectionError
 from libp2p.transport.listener_interface import IListener
 from libp2p.transport.transport_interface import ITransport
 from libp2p.transport.typing import THandler
+from libp2p.io.trio import TrioReadWriteCloser
+
+import logging
+logger = logging.getLogger("libp2p.transport.tcp")
 
 
 class TCPListener(IListener):
@@ -21,20 +26,38 @@ class TCPListener(IListener):
         self.server = None
         self.handler = handler_function
 
-    async def listen(self, maddr: Multiaddr) -> bool:
+    async def listen(self, maddr: Multiaddr, nursery) -> bool:
         """
         put listener in listening mode and wait for incoming connections.
 
         :param maddr: maddr of peer
         :return: return True if successful
         """
-        self.server = await asyncio.start_server(
-            self.handler,
-            maddr.value_for_protocol("ip4"),
-            maddr.value_for_protocol("tcp"),
-        )
-        socket = self.server.sockets[0]
+
+        async def serve_tcp(handler, port, host, task_status=None):
+            logger.debug("serve_tcp %s %s", host, port)
+            await trio.serve_tcp(handler, port, host=host, task_status=task_status)
+
+        async def handler(stream):
+            read_write_closer = TrioReadWriteCloser(stream)
+            await self.handler(read_write_closer)
+
+        listeners = await nursery.start(
+            serve_tcp,
+            *(
+                handler,
+                int(maddr.value_for_protocol("tcp")),
+                maddr.value_for_protocol("ip4"),
+            ),
+        )
+        # self.server = await asyncio.start_server(
+        #     self.handler,
+        #     maddr.value_for_protocol("ip4"),
+        #     maddr.value_for_protocol("tcp"),
+        # )
+        socket = listeners[0].socket
         self.multiaddrs.append(_multiaddr_from_socket(socket))
+        logger.debug("Multiaddrs %s", self.multiaddrs)
 
         return True
 
@@ -69,12 +92,10 @@ class TCP(ITransport):
         self.host = maddr.value_for_protocol("ip4")
         self.port = int(maddr.value_for_protocol("tcp"))
 
-        try:
-            reader, writer = await asyncio.open_connection(self.host, self.port)
-        except (ConnectionAbortedError, ConnectionRefusedError) as error:
-            raise OpenConnectionError(error)
+        stream = await trio.open_tcp_stream(self.host, self.port)
+        read_write_closer = TrioReadWriteCloser(stream)
 
-        return RawConnection(reader, writer, True)
+        return RawConnection(read_write_closer, True)
 
     def create_listener(self, handler_function: THandler) -> TCPListener:
         """


@@ -1,11 +1,12 @@
-from asyncio import StreamReader, StreamWriter
 from typing import Awaitable, Callable, Mapping, Type
 
 from libp2p.security.secure_transport_interface import ISecureTransport
 from libp2p.stream_muxer.abc import IMuxedConn
 from libp2p.typing import TProtocol
+from libp2p.io.abc import ReadWriteCloser
 
-THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
+THandler = Callable[[ReadWriteCloser], Awaitable[None]]
 TSecurityOptions = Mapping[TProtocol, ISecureTransport]
 TMuxerClass = Type[IMuxedConn]
 TMuxerOptions = Mapping[TProtocol, TMuxerClass]
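
With the new alias, a connection handler receives a single `ReadWriteCloser` instead of an asyncio reader/writer pair. A short, hypothetical handler that matches the `THandler` shape:

from libp2p.io.abc import ReadWriteCloser
from libp2p.transport.typing import THandler


async def log_first_bytes(conn: ReadWriteCloser) -> None:
    # Hypothetical handler: peek at the first bytes of an inbound connection, then close.
    data = await conn.read(64)
    print("received:", data)
    await conn.close()


handler: THandler = log_first_bytes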


@@ -1,3 +1,4 @@
+import trio
 import multiaddr
 import pytest
 
@@ -6,10 +7,10 @@ from libp2p.tools.constants import MAX_READ_LEN
 from libp2p.tools.utils import set_up_nodes_by_transport_opt
 
 
-@pytest.mark.asyncio
-async def test_simple_messages():
+@pytest.mark.trio
+async def test_simple_messages(nursery):
     transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
-    (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
+    (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list, nursery)
 
     async def stream_handler(stream):
         while True:
@@ -23,6 +24,7 @@ async def test_simple_messages():
     # Associate the peer with local ip address (see default parameters of Libp2p())
     node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
 
     stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
     messages = ["hello" + str(x) for x in range(10)]