rewrite tcp reader/writer interface

This commit is contained in:
Chih Cheng Liang 2019-11-19 14:01:12 +08:00 committed by mhchia
parent d4d345c3c7
commit 41ff884eef
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
9 changed files with 112 additions and 106 deletions

32
libp2p/io/trio.py Normal file
View File

@ -0,0 +1,32 @@
import trio
from trio import SocketStream
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
import logging
logger = logging.getLogger("libp2p.io.trio")
class TrioReadWriteCloser(ReadWriteCloser):
stream: SocketStream
def __init__(self, stream: SocketStream) -> None:
self.stream = stream
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
try:
await self.stream.send_all(data)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException(error)
async def read(self, n: int = -1) -> bytes:
max_bytes = n if n != -1 else None
try:
return await self.stream.receive_some(max_bytes)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException(error)
async def close(self) -> None:
await self.stream.aclose()

View File

@ -1,42 +1,25 @@
import asyncio
import trio
from libp2p.io.exceptions import IOException
from .exceptions import RawConnError
from .raw_connection_interface import IRawConnection
from libp2p.io.abc import ReadWriteCloser
class RawConnection(IRawConnection):
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
read_write_closer: ReadWriteCloser
is_initiator: bool
_drain_lock: asyncio.Lock
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
initiator: bool,
) -> None:
self.reader = reader
self.writer = writer
def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None:
self.read_write_closer = read_write_closer
self.is_initiator = initiator
self._drain_lock = asyncio.Lock()
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
try:
self.writer.write(data)
except ConnectionResetError as error:
await self.read_write_closer.write(data)
except IOException as error:
raise RawConnError(error)
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
# Use a lock to serialize drain() calls. Circumvents this bug:
# https://bugs.python.org/issue29930
async with self._drain_lock:
try:
await self.writer.drain()
except ConnectionResetError as error:
raise RawConnError(error)
async def read(self, n: int = -1) -> bytes:
"""
@ -46,10 +29,9 @@ class RawConnection(IRawConnection):
Raise `RawConnError` if the underlying connection breaks
"""
try:
return await self.reader.read(n)
except ConnectionResetError as error:
return await self.read_write_closer.read(n)
except IOException as error:
raise RawConnError(error)
async def close(self) -> None:
self.writer.close()
await self.writer.wait_closed()
await self.read_write_closer.close()

View File

@ -4,6 +4,7 @@ from typing import Dict, List, Optional
from multiaddr import Multiaddr
from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStoreError
@ -149,7 +150,7 @@ class Swarm(INetwork):
logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream
async def listen(self, *multiaddrs: Multiaddr) -> bool:
async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool:
"""
:param multiaddrs: one or many multiaddrs to start listening on
:return: true if at least one success
@ -167,15 +168,8 @@ class Swarm(INetwork):
if str(maddr) in self.listeners:
return True
async def conn_handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
connection_info = writer.get_extra_info("peername")
# TODO make a proper multiaddr
peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}"
logger.debug("inbound connection at %s", peer_addr)
# logger.debug("inbound connection request", peer_id)
raw_conn = RawConnection(reader, writer, False)
async def conn_handler(read_write_closer: ReadWriteCloser) -> None:
raw_conn = RawConnection(read_write_closer, False)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
# the conn and then mux the conn
@ -185,14 +179,10 @@ class Swarm(INetwork):
raw_conn, ID(b""), False
)
except SecurityUpgradeFailure as error:
error_msg = "fail to upgrade security for peer at %s"
logger.debug(error_msg, peer_addr)
await raw_conn.close()
raise SwarmException(error_msg % peer_addr) from error
raise SwarmException() from error
peer_id = secured_conn.get_remote_peer()
logger.debug("upgraded security for peer at %s", peer_addr)
logger.debug("identified peer at %s as %s", peer_addr, peer_id)
try:
muxed_conn = await self.upgrader.upgrade_connection(
@ -213,7 +203,7 @@ class Swarm(INetwork):
# Success
listener = self.transport.create_listener(conn_handler)
self.listeners[str(maddr)] = listener
await listener.listen(maddr)
await listener.listen(maddr, nursery)
# Call notifiers since event occurred
self.notify_listen(maddr)

View File

@ -123,27 +123,10 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream
async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_closed, task_wait_shutting_down],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_wait_closed in done:
raise MplexUnavailable("Mplex is closed")
if task_wait_shutting_down in done:
raise MplexUnavailable("Mplex is shutting down")
return task_coro.result()
async def accept_stream(self) -> IMuxedStream:
"""accepts a muxed stream opened by the other end."""
return await self._wait_until_shutting_down_or_closed(
self.new_stream_queue.get()
)
return await self.new_stream_queue.get()
async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -163,9 +146,7 @@ class Mplex(IMuxedConn):
_bytes = header + encode_varint_prefixed(data)
return await self._wait_until_shutting_down_or_closed(
self.write_to_stream(_bytes)
)
return await self.write_to_stream(_bytes)
async def write_to_stream(self, _bytes: bytes) -> int:
"""
@ -226,9 +207,7 @@ class Mplex(IMuxedConn):
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
"""
channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
self.read_message()
)
channel_id, flag, message = await self.read_message()
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
if flag == HeaderTags.NewStream.value:
@ -258,9 +237,7 @@ class Mplex(IMuxedConn):
f"received NewStream message for existing stream: {stream_id}"
)
mplex_stream = await self._initialize_stream(stream_id, message.decode())
await self._wait_until_shutting_down_or_closed(
self.new_stream_queue.put(mplex_stream)
)
await self.new_stream_queue.put(mplex_stream)
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
@ -274,9 +251,7 @@ class Mplex(IMuxedConn):
if stream.event_remote_closed.is_set():
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
return
await self._wait_until_shutting_down_or_closed(
stream.incoming_data.put(message)
)
await stream.incoming_data.put(message)
async def _handle_close(self, stream_id: StreamID) -> None:
async with self.streams_lock:

View File

@ -1,3 +1,4 @@
import trio
import asyncio
from typing import TYPE_CHECKING
@ -22,14 +23,14 @@ class MplexStream(IMuxedStream):
read_deadline: int
write_deadline: int
close_lock: asyncio.Lock
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: "asyncio.Queue[bytes]"
event_local_closed: asyncio.Event
event_remote_closed: asyncio.Event
event_reset: asyncio.Event
event_local_closed: trio.Event
event_remote_closed: trio.Event
event_reset: trio.Event
_buf: bytearray
@ -45,10 +46,10 @@ class MplexStream(IMuxedStream):
self.muxed_conn = muxed_conn
self.read_deadline = None
self.write_deadline = None
self.event_local_closed = asyncio.Event()
self.event_remote_closed = asyncio.Event()
self.event_reset = asyncio.Event()
self.close_lock = asyncio.Lock()
self.event_local_closed = trio.Event()
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.incoming_data = asyncio.Queue()
self._buf = bytearray()
@ -199,10 +200,11 @@ class MplexStream(IMuxedStream):
if self.is_initiator
else HeaderTags.ResetReceiver
)
asyncio.ensure_future(
self.muxed_conn.send_message(flag, None, self.stream_id)
)
await asyncio.sleep(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(
self.muxed_conn.send_message, flag, None, self.stream_id
)
await trio.sleep(0)
self.event_local_closed.set()
self.event_remote_closed.set()

View File

@ -1,3 +1,4 @@
import trio
from typing import List, Sequence, Tuple
import multiaddr
@ -37,12 +38,12 @@ async def connect(node1: IHost, node2: IHost) -> None:
async def set_up_nodes_by_transport_opt(
transport_opt_list: Sequence[Sequence[str]]
transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery
) -> Tuple[BasicHost, ...]:
nodes_list = []
for transport_opt in transport_opt_list:
node = await new_node(transport_opt=transport_opt)
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]))
node = new_node(transport_opt=transport_opt)
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery)
nodes_list.append(node)
return tuple(nodes_list)

View File

@ -1,4 +1,5 @@
import asyncio
import trio
from socket import socket
from typing import List
@ -10,6 +11,10 @@ from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler
from libp2p.io.trio import TrioReadWriteCloser
import logging
logger = logging.getLogger("libp2p.transport.tcp")
class TCPListener(IListener):
@ -21,20 +26,38 @@ class TCPListener(IListener):
self.server = None
self.handler = handler_function
async def listen(self, maddr: Multiaddr) -> bool:
async def listen(self, maddr: Multiaddr, nursery) -> bool:
"""
put listener in listening mode and wait for incoming connections.
:param maddr: maddr of peer
:return: return True if successful
"""
self.server = await asyncio.start_server(
self.handler,
maddr.value_for_protocol("ip4"),
maddr.value_for_protocol("tcp"),
async def serve_tcp(handler, port, host, task_status=None):
logger.debug("serve_tcp %s %s", host, port)
await trio.serve_tcp(handler, port, host=host, task_status=task_status)
async def handler(stream):
read_write_closer = TrioReadWriteCloser(stream)
await self.handler(read_write_closer)
listeners = await nursery.start(
serve_tcp,
*(
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
),
)
socket = self.server.sockets[0]
# self.server = await asyncio.start_server(
# self.handler,
# maddr.value_for_protocol("ip4"),
# maddr.value_for_protocol("tcp"),
# )
socket = listeners[0].socket
self.multiaddrs.append(_multiaddr_from_socket(socket))
logger.debug("Multiaddrs %s", self.multiaddrs)
return True
@ -69,12 +92,10 @@ class TCP(ITransport):
self.host = maddr.value_for_protocol("ip4")
self.port = int(maddr.value_for_protocol("tcp"))
try:
reader, writer = await asyncio.open_connection(self.host, self.port)
except (ConnectionAbortedError, ConnectionRefusedError) as error:
raise OpenConnectionError(error)
stream = await trio.open_tcp_stream(self.host, self.port)
read_write_closer = TrioReadWriteCloser(stream)
return RawConnection(reader, writer, True)
return RawConnection(read_write_closer, True)
def create_listener(self, handler_function: THandler) -> TCPListener:
"""

View File

@ -1,11 +1,12 @@
from asyncio import StreamReader, StreamWriter
from typing import Awaitable, Callable, Mapping, Type
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.typing import TProtocol
from libp2p.io.abc import ReadWriteCloser
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = Type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass]

View File

@ -1,3 +1,4 @@
import trio
import multiaddr
import pytest
@ -6,10 +7,10 @@ from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.utils import set_up_nodes_by_transport_opt
@pytest.mark.asyncio
async def test_simple_messages():
@pytest.mark.trio
async def test_simple_messages(nursery):
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list, nursery)
async def stream_handler(stream):
while True:
@ -23,6 +24,7 @@ async def test_simple_messages():
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(10)]