makes test_mplex_stream.py::test_mplex_stream_read_write work

This commit is contained in:
Chih Cheng Liang 2019-11-19 18:04:48 +08:00 committed by mhchia
parent c55ea0e5bb
commit a397ccdc04
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
13 changed files with 70 additions and 122 deletions

View File

@ -1,11 +1,11 @@
import argparse
import asyncio
import trio_asyncio
import trio
import sys
import urllib.request
import multiaddr
import trio
import trio_asyncio
from libp2p import new_node
from libp2p.network.stream.net_stream_interface import INetStream
@ -42,7 +42,9 @@ async def run(port: int, destination: str, localhost: bool) -> None:
transport_opt = f"/ip4/{ip}/tcp/{port}"
host = new_node(transport_opt=[transport_opt])
await trio_asyncio.run_asyncio(host.get_network().listen,multiaddr.Multiaddr(transport_opt) )
await trio_asyncio.run_asyncio(
host.get_network().listen, multiaddr.Multiaddr(transport_opt)
)
if not destination: # its the server
@ -70,7 +72,9 @@ async def run(port: int, destination: str, localhost: bool) -> None:
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await trio_asyncio.run_asyncio(host.new_stream, *(info.peer_id, [PROTOCOL_ID]))
stream = await trio_asyncio.run_asyncio(
host.new_stream, *(info.peer_id, [PROTOCOL_ID])
)
asyncio.ensure_future(read_data(stream))
asyncio.ensure_future(write_data(stream))
@ -119,5 +123,6 @@ def main() -> None:
trio_asyncio.run(run, *(args.port, args.destination, args.localhost))
if __name__ == "__main__":
main()

View File

@ -1,9 +1,10 @@
import trio
from trio import SocketStream
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
import logging
import trio
from trio import SocketStream
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger("libp2p.io.trio")

View File

@ -1,9 +1,10 @@
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
from .exceptions import RawConnError
from .raw_connection_interface import IRawConnection
from libp2p.io.abc import ReadWriteCloser
class RawConnection(IRawConnection):

View File

@ -3,6 +3,7 @@ import logging
from typing import Dict, List, Optional
from multiaddr import Multiaddr
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.net_connection_interface import INetConn
@ -69,7 +70,7 @@ class Swarm(INetwork):
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler
async def dial_peer(self, peer_id: ID) -> INetConn:
async def dial_peer(self, peer_id: ID, nursery) -> INetConn:
"""
dial_peer try to create a connection to peer_id.
@ -121,6 +122,7 @@ class Swarm(INetwork):
try:
muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
muxed_conn.run(nursery)
except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id)
@ -135,7 +137,7 @@ class Swarm(INetwork):
return swarm_conn
async def new_stream(self, peer_id: ID) -> INetStream:
async def new_stream(self, peer_id: ID, nursery) -> INetStream:
"""
:param peer_id: peer_id of destination
:param protocol_id: protocol id
@ -144,7 +146,7 @@ class Swarm(INetwork):
"""
logger.debug("attempting to open a stream to peer %s", peer_id)
swarm_conn = await self.dial_peer(peer_id)
swarm_conn = await self.dial_peer(peer_id, nursery)
net_stream = await swarm_conn.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
@ -183,11 +185,11 @@ class Swarm(INetwork):
raise SwarmException() from error
peer_id = secured_conn.get_remote_peer()
try:
muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, peer_id
)
muxed_conn.run(nursery)
except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id)
@ -198,6 +200,8 @@ class Swarm(INetwork):
await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id)
event = trio.Event()
await event.wait()
try:
# Success

View File

@ -3,6 +3,8 @@ import logging
from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple
import trio
from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError
from libp2p.network.connection.exceptions import RawConnError
@ -41,8 +43,6 @@ class Mplex(IMuxedConn):
event_shutting_down: asyncio.Event
event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"]
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
"""
create a new muxed connection.
@ -66,10 +66,8 @@ class Mplex(IMuxedConn):
self.event_shutting_down = asyncio.Event()
self.event_closed = asyncio.Event()
self._tasks = []
# Kick off reading
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
def run(self, nursery):
nursery.start_soon(self.handle_incoming)
@property
def is_initiator(self) -> bool:
@ -123,7 +121,6 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream
async def accept_stream(self) -> IMuxedStream:
"""accepts a muxed stream opened by the other end."""
return await self.new_stream_queue.get()
@ -169,7 +166,7 @@ class Mplex(IMuxedConn):
logger.debug("mplex unavailable while waiting for incoming: %s", e)
break
# Force context switch
await asyncio.sleep(0)
await trio.sleep(0)
# If we enter here, it means this connection is shutting down.
# We should clean things up.
await self._cleanup()
@ -184,9 +181,7 @@ class Mplex(IMuxedConn):
# FIXME: No timeout is used in Go implementation.
try:
header = await decode_uvarint_from_stream(self.secured_conn)
message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5
)
message = await read_varint_prefixed_bytes(self.secured_conn)
except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable(
"failed to read messages correctly from the underlying connection"

View File

@ -1,8 +1,10 @@
import trio
import asyncio
from typing import TYPE_CHECKING
import trio
from libp2p.stream_muxer.abc import IMuxedStream
from libp2p.utils import IQueue, TrioQueue
from .constants import HeaderTags
from .datastructures import StreamID
@ -26,7 +28,7 @@ class MplexStream(IMuxedStream):
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: "asyncio.Queue[bytes]"
incoming_data: IQueue[bytes]
event_local_closed: trio.Event
event_remote_closed: trio.Event
@ -50,69 +52,13 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.incoming_data = asyncio.Queue()
self.incoming_data = TrioQueue()
self._buf = bytearray()
@property
def is_initiator(self) -> bool:
return self.stream_id.is_initiator
async def _wait_for_data(self) -> None:
task_event_reset = asyncio.ensure_future(self.event_reset.wait())
task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get())
task_event_remote_closed = asyncio.ensure_future(
self.event_remote_closed.wait()
)
done, pending = await asyncio.wait( # type: ignore
[ # type: ignore
task_event_reset,
task_incoming_data_get,
task_event_remote_closed,
],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_event_reset in done:
if self.event_reset.is_set():
raise MplexStreamReset
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
if task_incoming_data_get in done:
data = task_incoming_data_get.result()
self._buf.extend(data)
return
if task_event_remote_closed in done:
if self.event_remote_closed.is_set():
raise MplexStreamEOF
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
# TODO: Handle timeout when deadline is used.
async def _read_until_eof(self) -> bytes:
while True:
try:
await self._wait_for_data()
except MplexStreamEOF:
break
payload = self._buf
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def read(self, n: int = -1) -> bytes:
"""
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
@ -128,20 +74,7 @@ class MplexStream(IMuxedStream):
)
if self.event_reset.is_set():
raise MplexStreamReset
if n == -1:
return await self._read_until_eof()
if len(self._buf) == 0 and self.incoming_data.empty():
await self._wait_for_data()
# Now we are sure we have something to read.
# Try to put enough incoming data into `self._buf`.
while len(self._buf) < n:
try:
self._buf.extend(self.incoming_data.get_nowait())
except asyncio.QueueEmpty:
break
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
return await self.incoming_data.get()
async def write(self, data: bytes) -> int:
"""

View File

@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol
from .constants import MAX_READ_LEN
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None:
peer_id = swarm_1.get_peer_id()
addrs = tuple(
addr
@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
for addr in transport.get_addrs()
)
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id)
await swarm_0.dial_peer(peer_id, nursery)
assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections
@ -43,7 +43,9 @@ async def set_up_nodes_by_transport_opt(
nodes_list = []
for transport_opt in transport_opt_list:
node = new_node(transport_opt=transport_opt)
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery)
await node.get_network().listen(
multiaddr.Multiaddr(transport_opt[0]), nursery=nursery
)
nodes_list.append(node)
return tuple(nodes_list)

View File

@ -1,18 +1,18 @@
import asyncio
import trio
import logging
from socket import socket
from typing import List
from multiaddr import Multiaddr
import trio
from libp2p.io.trio import TrioReadWriteCloser
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler
from libp2p.io.trio import TrioReadWriteCloser
import logging
logger = logging.getLogger("libp2p.transport.tcp")
@ -44,11 +44,9 @@ class TCPListener(IListener):
listeners = await nursery.start(
serve_tcp,
*(
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
),
)
# self.server = await asyncio.start_server(
# self.handler,
@ -57,7 +55,6 @@ class TCPListener(IListener):
# )
socket = listeners[0].socket
self.multiaddrs.append(_multiaddr_from_socket(socket))
logger.debug("Multiaddrs %s", self.multiaddrs)
return True

View File

@ -1,10 +1,9 @@
from typing import Awaitable, Callable, Mapping, Type
from libp2p.io.abc import ReadWriteCloser
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.typing import TProtocol
from libp2p.io.abc import ReadWriteCloser
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]

View File

@ -1,14 +1,14 @@
import itertools
import math
from typing import Generic, TypeVar
import trio
from libp2p.exceptions import ParseError
from libp2p.io.abc import Reader
from .io.utils import read_exactly
from typing import Generic, TypeVar
import trio
# Unsigned LEB128(varint codec)
# Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py

View File

@ -1,6 +1,6 @@
import trio
import multiaddr
import pytest
import trio
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.constants import MAX_READ_LEN
@ -24,11 +24,11 @@ async def test_simple_messages(nursery):
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()

View File

@ -1,20 +1,31 @@
import asyncio
import pytest
import trio
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR
from libp2p.tools.factories import SwarmFactory
from libp2p.tools.utils import connect_swarm
DATA = b"data_123"
@pytest.mark.asyncio
async def test_mplex_stream_read_write(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
@pytest.mark.trio
async def test_mplex_stream_read_write(nursery):
swarm0, swarm1 = SwarmFactory(), SwarmFactory()
await swarm0.listen(LISTEN_MADDR, nursery=nursery)
await swarm1.listen(LISTEN_MADDR, nursery=nursery)
await connect_swarm(swarm0, swarm1, nursery)
conn_0 = swarm0.connections[swarm1.get_peer_id()]
conn_1 = swarm1.connections[swarm0.get_peer_id()]
stream_0 = await conn_0.muxed_conn.open_stream()
await trio.sleep(1)
stream_1 = tuple(conn_1.muxed_conn.streams.values())[0]
await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA

View File

@ -1,5 +1,6 @@
import trio
import pytest
import trio
from libp2p.utils import TrioQueue
@ -16,4 +17,3 @@ async def test_trio_queue():
result = await nursery.start(queue_get)
assert result == 123