makes test_mplex_stream.py::test_mplex_stream_read_write work
This commit is contained in:
parent
c55ea0e5bb
commit
a397ccdc04
|
@ -1,11 +1,11 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import trio_asyncio
|
||||
import trio
|
||||
import sys
|
||||
import urllib.request
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
import trio_asyncio
|
||||
|
||||
from libp2p import new_node
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
|
@ -42,7 +42,9 @@ async def run(port: int, destination: str, localhost: bool) -> None:
|
|||
transport_opt = f"/ip4/{ip}/tcp/{port}"
|
||||
host = new_node(transport_opt=[transport_opt])
|
||||
|
||||
await trio_asyncio.run_asyncio(host.get_network().listen,multiaddr.Multiaddr(transport_opt) )
|
||||
await trio_asyncio.run_asyncio(
|
||||
host.get_network().listen, multiaddr.Multiaddr(transport_opt)
|
||||
)
|
||||
|
||||
if not destination: # its the server
|
||||
|
||||
|
@ -70,7 +72,9 @@ async def run(port: int, destination: str, localhost: bool) -> None:
|
|||
|
||||
# Start a stream with the destination.
|
||||
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
|
||||
stream = await trio_asyncio.run_asyncio(host.new_stream, *(info.peer_id, [PROTOCOL_ID]))
|
||||
stream = await trio_asyncio.run_asyncio(
|
||||
host.new_stream, *(info.peer_id, [PROTOCOL_ID])
|
||||
)
|
||||
|
||||
asyncio.ensure_future(read_data(stream))
|
||||
asyncio.ensure_future(write_data(stream))
|
||||
|
@ -119,5 +123,6 @@ def main() -> None:
|
|||
|
||||
trio_asyncio.run(run, *(args.port, args.destination, args.localhost))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import trio
|
||||
from trio import SocketStream
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
import logging
|
||||
|
||||
import trio
|
||||
from trio import SocketStream
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
logger = logging.getLogger("libp2p.io.trio")
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
from .exceptions import RawConnError
|
||||
from .raw_connection_interface import IRawConnection
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
|
||||
|
||||
class RawConnection(IRawConnection):
|
||||
|
|
|
@ -3,6 +3,7 @@ import logging
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.network.connection.net_connection_interface import INetConn
|
||||
|
@ -69,7 +70,7 @@ class Swarm(INetwork):
|
|||
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
||||
self.common_stream_handler = stream_handler
|
||||
|
||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||
async def dial_peer(self, peer_id: ID, nursery) -> INetConn:
|
||||
"""
|
||||
dial_peer try to create a connection to peer_id.
|
||||
|
||||
|
@ -121,6 +122,7 @@ class Swarm(INetwork):
|
|||
|
||||
try:
|
||||
muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
|
||||
muxed_conn.run(nursery)
|
||||
except MuxerUpgradeFailure as error:
|
||||
error_msg = "fail to upgrade mux for peer %s"
|
||||
logger.debug(error_msg, peer_id)
|
||||
|
@ -135,7 +137,7 @@ class Swarm(INetwork):
|
|||
|
||||
return swarm_conn
|
||||
|
||||
async def new_stream(self, peer_id: ID) -> INetStream:
|
||||
async def new_stream(self, peer_id: ID, nursery) -> INetStream:
|
||||
"""
|
||||
:param peer_id: peer_id of destination
|
||||
:param protocol_id: protocol id
|
||||
|
@ -144,7 +146,7 @@ class Swarm(INetwork):
|
|||
"""
|
||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||
|
||||
swarm_conn = await self.dial_peer(peer_id)
|
||||
swarm_conn = await self.dial_peer(peer_id, nursery)
|
||||
|
||||
net_stream = await swarm_conn.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
|
@ -183,11 +185,11 @@ class Swarm(INetwork):
|
|||
raise SwarmException() from error
|
||||
peer_id = secured_conn.get_remote_peer()
|
||||
|
||||
|
||||
try:
|
||||
muxed_conn = await self.upgrader.upgrade_connection(
|
||||
secured_conn, peer_id
|
||||
)
|
||||
muxed_conn.run(nursery)
|
||||
except MuxerUpgradeFailure as error:
|
||||
error_msg = "fail to upgrade mux for peer %s"
|
||||
logger.debug(error_msg, peer_id)
|
||||
|
@ -198,6 +200,8 @@ class Swarm(INetwork):
|
|||
await self.add_conn(muxed_conn)
|
||||
|
||||
logger.debug("successfully opened connection to peer %s", peer_id)
|
||||
event = trio.Event()
|
||||
await event.wait()
|
||||
|
||||
try:
|
||||
# Success
|
||||
|
|
|
@ -3,6 +3,8 @@ import logging
|
|||
from typing import Any # noqa: F401
|
||||
from typing import Awaitable, Dict, List, Optional, Tuple
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.exceptions import ParseError
|
||||
from libp2p.io.exceptions import IncompleteReadError
|
||||
from libp2p.network.connection.exceptions import RawConnError
|
||||
|
@ -41,8 +43,6 @@ class Mplex(IMuxedConn):
|
|||
event_shutting_down: asyncio.Event
|
||||
event_closed: asyncio.Event
|
||||
|
||||
_tasks: List["asyncio.Future[Any]"]
|
||||
|
||||
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
|
||||
"""
|
||||
create a new muxed connection.
|
||||
|
@ -66,10 +66,8 @@ class Mplex(IMuxedConn):
|
|||
self.event_shutting_down = asyncio.Event()
|
||||
self.event_closed = asyncio.Event()
|
||||
|
||||
self._tasks = []
|
||||
|
||||
# Kick off reading
|
||||
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
|
||||
def run(self, nursery):
|
||||
nursery.start_soon(self.handle_incoming)
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
|
@ -123,7 +121,6 @@ class Mplex(IMuxedConn):
|
|||
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
|
||||
return stream
|
||||
|
||||
|
||||
async def accept_stream(self) -> IMuxedStream:
|
||||
"""accepts a muxed stream opened by the other end."""
|
||||
return await self.new_stream_queue.get()
|
||||
|
@ -169,7 +166,7 @@ class Mplex(IMuxedConn):
|
|||
logger.debug("mplex unavailable while waiting for incoming: %s", e)
|
||||
break
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
await trio.sleep(0)
|
||||
# If we enter here, it means this connection is shutting down.
|
||||
# We should clean things up.
|
||||
await self._cleanup()
|
||||
|
@ -184,9 +181,7 @@ class Mplex(IMuxedConn):
|
|||
# FIXME: No timeout is used in Go implementation.
|
||||
try:
|
||||
header = await decode_uvarint_from_stream(self.secured_conn)
|
||||
message = await asyncio.wait_for(
|
||||
read_varint_prefixed_bytes(self.secured_conn), timeout=5
|
||||
)
|
||||
message = await read_varint_prefixed_bytes(self.secured_conn)
|
||||
except (ParseError, RawConnError, IncompleteReadError) as error:
|
||||
raise MplexUnavailable(
|
||||
"failed to read messages correctly from the underlying connection"
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import trio
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.stream_muxer.abc import IMuxedStream
|
||||
from libp2p.utils import IQueue, TrioQueue
|
||||
|
||||
from .constants import HeaderTags
|
||||
from .datastructures import StreamID
|
||||
|
@ -26,7 +28,7 @@ class MplexStream(IMuxedStream):
|
|||
close_lock: trio.Lock
|
||||
|
||||
# NOTE: `dataIn` is size of 8 in Go implementation.
|
||||
incoming_data: "asyncio.Queue[bytes]"
|
||||
incoming_data: IQueue[bytes]
|
||||
|
||||
event_local_closed: trio.Event
|
||||
event_remote_closed: trio.Event
|
||||
|
@ -50,69 +52,13 @@ class MplexStream(IMuxedStream):
|
|||
self.event_remote_closed = trio.Event()
|
||||
self.event_reset = trio.Event()
|
||||
self.close_lock = trio.Lock()
|
||||
self.incoming_data = asyncio.Queue()
|
||||
self.incoming_data = TrioQueue()
|
||||
self._buf = bytearray()
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
return self.stream_id.is_initiator
|
||||
|
||||
async def _wait_for_data(self) -> None:
|
||||
task_event_reset = asyncio.ensure_future(self.event_reset.wait())
|
||||
task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get())
|
||||
task_event_remote_closed = asyncio.ensure_future(
|
||||
self.event_remote_closed.wait()
|
||||
)
|
||||
done, pending = await asyncio.wait( # type: ignore
|
||||
[ # type: ignore
|
||||
task_event_reset,
|
||||
task_incoming_data_get,
|
||||
task_event_remote_closed,
|
||||
],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for fut in pending:
|
||||
fut.cancel()
|
||||
|
||||
if task_event_reset in done:
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
else:
|
||||
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
|
||||
# is set. The task is probably cancelled.
|
||||
raise Exception(
|
||||
"Should not enter here. "
|
||||
f"It is probably because {task_event_remote_closed} is cancelled."
|
||||
)
|
||||
|
||||
if task_incoming_data_get in done:
|
||||
data = task_incoming_data_get.result()
|
||||
self._buf.extend(data)
|
||||
return
|
||||
|
||||
if task_event_remote_closed in done:
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF
|
||||
else:
|
||||
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
|
||||
# is set. The task is probably cancelled.
|
||||
raise Exception(
|
||||
"Should not enter here. "
|
||||
f"It is probably because {task_event_remote_closed} is cancelled."
|
||||
)
|
||||
|
||||
# TODO: Handle timeout when deadline is used.
|
||||
|
||||
async def _read_until_eof(self) -> bytes:
|
||||
while True:
|
||||
try:
|
||||
await self._wait_for_data()
|
||||
except MplexStreamEOF:
|
||||
break
|
||||
payload = self._buf
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
|
||||
|
@ -128,20 +74,7 @@ class MplexStream(IMuxedStream):
|
|||
)
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if n == -1:
|
||||
return await self._read_until_eof()
|
||||
if len(self._buf) == 0 and self.incoming_data.empty():
|
||||
await self._wait_for_data()
|
||||
# Now we are sure we have something to read.
|
||||
# Try to put enough incoming data into `self._buf`.
|
||||
while len(self._buf) < n:
|
||||
try:
|
||||
self._buf.extend(self.incoming_data.get_nowait())
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
return await self.incoming_data.get()
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
"""
|
||||
|
|
|
@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol
|
|||
from .constants import MAX_READ_LEN
|
||||
|
||||
|
||||
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
|
||||
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None:
|
||||
peer_id = swarm_1.get_peer_id()
|
||||
addrs = tuple(
|
||||
addr
|
||||
|
@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
|
|||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
|
||||
await swarm_0.dial_peer(peer_id)
|
||||
await swarm_0.dial_peer(peer_id, nursery)
|
||||
assert swarm_0.get_peer_id() in swarm_1.connections
|
||||
assert swarm_1.get_peer_id() in swarm_0.connections
|
||||
|
||||
|
@ -43,7 +43,9 @@ async def set_up_nodes_by_transport_opt(
|
|||
nodes_list = []
|
||||
for transport_opt in transport_opt_list:
|
||||
node = new_node(transport_opt=transport_opt)
|
||||
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]), nursery=nursery)
|
||||
await node.get_network().listen(
|
||||
multiaddr.Multiaddr(transport_opt[0]), nursery=nursery
|
||||
)
|
||||
nodes_list.append(node)
|
||||
return tuple(nodes_list)
|
||||
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
import asyncio
|
||||
import trio
|
||||
import logging
|
||||
from socket import socket
|
||||
from typing import List
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.io.trio import TrioReadWriteCloser
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
from libp2p.transport.exceptions import OpenConnectionError
|
||||
from libp2p.transport.listener_interface import IListener
|
||||
from libp2p.transport.transport_interface import ITransport
|
||||
from libp2p.transport.typing import THandler
|
||||
from libp2p.io.trio import TrioReadWriteCloser
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.tcp")
|
||||
|
||||
|
@ -44,11 +44,9 @@ class TCPListener(IListener):
|
|||
|
||||
listeners = await nursery.start(
|
||||
serve_tcp,
|
||||
*(
|
||||
handler,
|
||||
int(maddr.value_for_protocol("tcp")),
|
||||
maddr.value_for_protocol("ip4"),
|
||||
),
|
||||
)
|
||||
# self.server = await asyncio.start_server(
|
||||
# self.handler,
|
||||
|
@ -57,7 +55,6 @@ class TCPListener(IListener):
|
|||
# )
|
||||
socket = listeners[0].socket
|
||||
self.multiaddrs.append(_multiaddr_from_socket(socket))
|
||||
logger.debug("Multiaddrs %s", self.multiaddrs)
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
|
||||
from typing import Awaitable, Callable, Mapping, Type
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||
from libp2p.stream_muxer.abc import IMuxedConn
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
|
||||
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
|
||||
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
import itertools
|
||||
import math
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.exceptions import ParseError
|
||||
from libp2p.io.abc import Reader
|
||||
|
||||
from .io.utils import read_exactly
|
||||
|
||||
from typing import Generic, TypeVar
|
||||
import trio
|
||||
|
||||
# Unsigned LEB128(varint codec)
|
||||
# Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import trio
|
||||
import multiaddr
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.tools.constants import MAX_READ_LEN
|
||||
|
@ -24,11 +24,11 @@ async def test_simple_messages(nursery):
|
|||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
|
||||
|
||||
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
|
||||
await stream.write(message.encode())
|
||||
|
||||
response = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
|
|
@ -1,20 +1,31 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.stream_muxer.mplex.exceptions import (
|
||||
MplexStreamClosed,
|
||||
MplexStreamEOF,
|
||||
MplexStreamReset,
|
||||
)
|
||||
from libp2p.tools.constants import MAX_READ_LEN
|
||||
from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR
|
||||
from libp2p.tools.factories import SwarmFactory
|
||||
from libp2p.tools.utils import connect_swarm
|
||||
|
||||
DATA = b"data_123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mplex_stream_read_write(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_read_write(nursery):
|
||||
swarm0, swarm1 = SwarmFactory(), SwarmFactory()
|
||||
await swarm0.listen(LISTEN_MADDR, nursery=nursery)
|
||||
await swarm1.listen(LISTEN_MADDR, nursery=nursery)
|
||||
await connect_swarm(swarm0, swarm1, nursery)
|
||||
conn_0 = swarm0.connections[swarm1.get_peer_id()]
|
||||
conn_1 = swarm1.connections[swarm0.get_peer_id()]
|
||||
stream_0 = await conn_0.muxed_conn.open_stream()
|
||||
await trio.sleep(1)
|
||||
stream_1 = tuple(conn_1.muxed_conn.streams.values())[0]
|
||||
await stream_0.write(DATA)
|
||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import trio
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.utils import TrioQueue
|
||||
|
||||
|
||||
|
@ -16,4 +17,3 @@ async def test_trio_queue():
|
|||
result = await nursery.start(queue_get)
|
||||
|
||||
assert result == 123
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user