Fix Mplex
and Swarm
This commit is contained in:
parent
ec43c25b45
commit
1e600ea7e0
|
@ -23,6 +23,10 @@ class TrioReadWriteCloser(ReadWriteCloser):
|
|||
raise IOException(error)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
if n == 0:
|
||||
# Check point
|
||||
await trio.sleep(0)
|
||||
return b""
|
||||
max_bytes = n if n != -1 else None
|
||||
try:
|
||||
return await self.stream.receive_some(max_bytes)
|
||||
|
|
|
@ -50,8 +50,11 @@ class SwarmConn(INetConn, Service):
|
|||
await self._notify_disconnected()
|
||||
|
||||
async def _handle_new_streams(self) -> None:
|
||||
while True:
|
||||
while self.manager.is_running:
|
||||
try:
|
||||
print(
|
||||
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams"
|
||||
)
|
||||
stream = await self.muxed_conn.accept_stream()
|
||||
except MuxedConnUnavailable:
|
||||
# If there is anything wrong in the MuxedConn,
|
||||
|
@ -60,6 +63,9 @@ class SwarmConn(INetConn, Service):
|
|||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||
self.manager.run_task(self._handle_muxed_stream, stream)
|
||||
|
||||
print(
|
||||
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop"
|
||||
)
|
||||
await self.close()
|
||||
|
||||
async def _call_stream_handler(self, net_stream: NetStream) -> None:
|
||||
|
|
|
@ -206,8 +206,7 @@ class Swarm(INetwork, Service):
|
|||
logger.debug("successfully opened connection to peer %s", peer_id)
|
||||
# FIXME: This is a intentional barrier to prevent from the handler exiting and
|
||||
# closing the connection.
|
||||
event = trio.Event()
|
||||
await event.wait()
|
||||
await trio.sleep_forever()
|
||||
|
||||
try:
|
||||
# Success
|
||||
|
@ -240,7 +239,7 @@ class Swarm(INetwork, Service):
|
|||
# await asyncio.gather(
|
||||
# *[connection.close() for connection in self.connections.values()]
|
||||
# )
|
||||
self.manager.stop()
|
||||
await self.manager.stop()
|
||||
await self.manager.wait_finished()
|
||||
logger.debug("swarm successfully closed")
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any # noqa: F401
|
||||
|
@ -18,7 +19,6 @@ from libp2p.utils import (
|
|||
encode_uvarint,
|
||||
encode_varint_prefixed,
|
||||
read_varint_prefixed_bytes,
|
||||
TrioQueue,
|
||||
)
|
||||
|
||||
from .constants import HeaderTags
|
||||
|
@ -41,7 +41,10 @@ class Mplex(IMuxedConn, Service):
|
|||
next_channel_id: int
|
||||
streams: Dict[StreamID, MplexStream]
|
||||
streams_lock: trio.Lock
|
||||
new_stream_queue: "TrioQueue[IMuxedStream]"
|
||||
streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"]
|
||||
new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]"
|
||||
new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]"
|
||||
|
||||
event_shutting_down: trio.Event
|
||||
event_closed: trio.Event
|
||||
|
||||
|
@ -64,7 +67,10 @@ class Mplex(IMuxedConn, Service):
|
|||
# Mapping from stream ID -> buffer of messages for that stream
|
||||
self.streams = {}
|
||||
self.streams_lock = trio.Lock()
|
||||
self.new_stream_queue = TrioQueue()
|
||||
self.streams_msg_channels = {}
|
||||
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||
self.new_stream_send_channel = send_channel
|
||||
self.new_stream_receive_channel = receive_channel
|
||||
self.event_shutting_down = trio.Event()
|
||||
self.event_closed = trio.Event()
|
||||
|
||||
|
@ -105,9 +111,13 @@ class Mplex(IMuxedConn, Service):
|
|||
return next_id
|
||||
|
||||
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
|
||||
stream = MplexStream(name, stream_id, self)
|
||||
# Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing
|
||||
# `send_channel.send`.
|
||||
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||
stream = MplexStream(name, stream_id, self, receive_channel)
|
||||
async with self.streams_lock:
|
||||
self.streams[stream_id] = stream
|
||||
self.streams_msg_channels[stream_id] = send_channel
|
||||
return stream
|
||||
|
||||
async def open_stream(self) -> IMuxedStream:
|
||||
|
@ -126,7 +136,10 @@ class Mplex(IMuxedConn, Service):
|
|||
|
||||
async def accept_stream(self) -> IMuxedStream:
|
||||
"""accepts a muxed stream opened by the other end."""
|
||||
return await self.new_stream_queue.get()
|
||||
try:
|
||||
return await self.new_stream_receive_channel.receive()
|
||||
except (trio.ClosedResourceError, trio.EndOfChannel):
|
||||
raise MplexUnavailable
|
||||
|
||||
async def send_message(
|
||||
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
||||
|
@ -138,6 +151,9 @@ class Mplex(IMuxedConn, Service):
|
|||
:param data: data to send in the message
|
||||
:param stream_id: stream the message is in
|
||||
"""
|
||||
print(
|
||||
f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}"
|
||||
)
|
||||
# << by 3, then or with flag
|
||||
header = encode_uvarint((stream_id.channel_id << 3) | flag.value)
|
||||
|
||||
|
@ -162,14 +178,21 @@ class Mplex(IMuxedConn, Service):
|
|||
"""Read a message off of the secured connection and add it to the
|
||||
corresponding message buffer."""
|
||||
|
||||
while True:
|
||||
while self.manager.is_running:
|
||||
try:
|
||||
print(
|
||||
f"!@# handle_incoming: {self._id}: before _handle_incoming_message"
|
||||
)
|
||||
await self._handle_incoming_message()
|
||||
print(
|
||||
f"!@# handle_incoming: {self._id}: after _handle_incoming_message"
|
||||
)
|
||||
except MplexUnavailable as e:
|
||||
logger.debug("mplex unavailable while waiting for incoming: %s", e)
|
||||
print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}")
|
||||
break
|
||||
# Force context switch
|
||||
await trio.sleep(0)
|
||||
|
||||
print(f"!@# handle_incoming: {self._id}: leaving")
|
||||
# If we enter here, it means this connection is shutting down.
|
||||
# We should clean things up.
|
||||
await self._cleanup()
|
||||
|
@ -181,51 +204,73 @@ class Mplex(IMuxedConn, Service):
|
|||
:return: stream_id, flag, message contents
|
||||
"""
|
||||
|
||||
# FIXME: No timeout is used in Go implementation.
|
||||
try:
|
||||
header = await decode_uvarint_from_stream(self.secured_conn)
|
||||
except (ParseError, RawConnError, IncompleteReadError) as error:
|
||||
raise MplexUnavailable(
|
||||
f"failed to read the header correctly from the underlying connection: {error}"
|
||||
)
|
||||
try:
|
||||
message = await read_varint_prefixed_bytes(self.secured_conn)
|
||||
except (ParseError, RawConnError, IncompleteReadError) as error:
|
||||
raise MplexUnavailable(
|
||||
"failed to read messages correctly from the underlying connection"
|
||||
) from error
|
||||
except asyncio.TimeoutError as error:
|
||||
raise MplexUnavailable(
|
||||
"failed to read more message body within the timeout"
|
||||
) from error
|
||||
"failed to read the message body correctly from the underlying connection: "
|
||||
f"{error}"
|
||||
)
|
||||
|
||||
flag = header & 0x07
|
||||
channel_id = header >> 3
|
||||
|
||||
return channel_id, flag, message
|
||||
|
||||
@property
|
||||
def _id(self) -> int:
|
||||
return 0 if self.is_initiator else 1
|
||||
|
||||
async def _handle_incoming_message(self) -> None:
|
||||
"""
|
||||
Read and handle a new incoming message.
|
||||
|
||||
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
|
||||
"""
|
||||
print(f"!@# _handle_incoming_message: {self._id}: before reading")
|
||||
channel_id, flag, message = await self.read_message()
|
||||
print(
|
||||
f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}"
|
||||
)
|
||||
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 2")
|
||||
|
||||
if flag == HeaderTags.NewStream.value:
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 3")
|
||||
await self._handle_new_stream(stream_id, message)
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 4")
|
||||
elif flag in (
|
||||
HeaderTags.MessageInitiator.value,
|
||||
HeaderTags.MessageReceiver.value,
|
||||
):
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 5")
|
||||
await self._handle_message(stream_id, message)
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 6")
|
||||
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 7")
|
||||
await self._handle_close(stream_id)
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 8")
|
||||
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 9")
|
||||
await self._handle_reset(stream_id)
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 10")
|
||||
else:
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 11")
|
||||
# Receives messages with an unknown flag
|
||||
# TODO: logging
|
||||
async with self.streams_lock:
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 12")
|
||||
if stream_id in self.streams:
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 13")
|
||||
stream = self.streams[stream_id]
|
||||
await stream.reset()
|
||||
print(f"!@# _handle_incoming_message: {self._id}: 14")
|
||||
|
||||
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
|
||||
async with self.streams_lock:
|
||||
|
@ -235,43 +280,65 @@ class Mplex(IMuxedConn, Service):
|
|||
f"received NewStream message for existing stream: {stream_id}"
|
||||
)
|
||||
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
||||
await self.new_stream_queue.put(mplex_stream)
|
||||
try:
|
||||
await self.new_stream_send_channel.send(mplex_stream)
|
||||
except (trio.BrokenResourceError, trio.EndOfChannel):
|
||||
raise MplexUnavailable
|
||||
|
||||
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
||||
print(
|
||||
f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}"
|
||||
)
|
||||
async with self.streams_lock:
|
||||
print(f"!@# _handle_message: {self._id}: 1")
|
||||
if stream_id not in self.streams:
|
||||
# We receive a message of the stream `stream_id` which is not accepted
|
||||
# before. It is abnormal. Possibly disconnect?
|
||||
# TODO: Warn and emit logs about this.
|
||||
print(f"!@# _handle_message: {self._id}: 2")
|
||||
return
|
||||
print(f"!@# _handle_message: {self._id}: 3")
|
||||
stream = self.streams[stream_id]
|
||||
send_channel = self.streams_msg_channels[stream_id]
|
||||
async with stream.close_lock:
|
||||
print(f"!@# _handle_message: {self._id}: 4")
|
||||
if stream.event_remote_closed.is_set():
|
||||
print(f"!@# _handle_message: {self._id}: 5")
|
||||
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
||||
return
|
||||
await stream.incoming_data.put(message)
|
||||
print(f"!@# _handle_message: {self._id}: 6")
|
||||
await send_channel.send(message)
|
||||
print(f"!@# _handle_message: {self._id}: 7")
|
||||
|
||||
async def _handle_close(self, stream_id: StreamID) -> None:
|
||||
print(f"!@# _handle_close: {self._id}: step=0")
|
||||
async with self.streams_lock:
|
||||
if stream_id not in self.streams:
|
||||
# Ignore unmatched messages for now.
|
||||
return
|
||||
stream = self.streams[stream_id]
|
||||
send_channel = self.streams_msg_channels[stream_id]
|
||||
print(f"!@# _handle_close: {self._id}: step=1")
|
||||
await send_channel.aclose()
|
||||
print(f"!@# _handle_close: {self._id}: step=2")
|
||||
# NOTE: If remote is already closed, then return: Technically a bug
|
||||
# on the other side. We should consider killing the connection.
|
||||
async with stream.close_lock:
|
||||
if stream.event_remote_closed.is_set():
|
||||
return
|
||||
print(f"!@# _handle_close: {self._id}: step=3")
|
||||
is_local_closed: bool
|
||||
async with stream.close_lock:
|
||||
stream.event_remote_closed.set()
|
||||
is_local_closed = stream.event_local_closed.is_set()
|
||||
print(f"!@# _handle_close: {self._id}: step=4")
|
||||
# If local is also closed, both sides are closed. Then, we should clean up
|
||||
# the entry of this stream, to avoid others from accessing it.
|
||||
if is_local_closed:
|
||||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
del self.streams[stream_id]
|
||||
print(f"!@# _handle_close: {self._id}: step=5")
|
||||
|
||||
async def _handle_reset(self, stream_id: StreamID) -> None:
|
||||
async with self.streams_lock:
|
||||
|
@ -279,11 +346,11 @@ class Mplex(IMuxedConn, Service):
|
|||
# This is *ok*. We forget the stream on reset.
|
||||
return
|
||||
stream = self.streams[stream_id]
|
||||
|
||||
send_channel = self.streams_msg_channels[stream_id]
|
||||
await send_channel.aclose()
|
||||
async with stream.close_lock:
|
||||
if not stream.event_remote_closed.is_set():
|
||||
stream.event_reset.set()
|
||||
|
||||
stream.event_remote_closed.set()
|
||||
# If local is not closed, we should close it.
|
||||
if not stream.event_local_closed.is_set():
|
||||
|
@ -291,16 +358,21 @@ class Mplex(IMuxedConn, Service):
|
|||
async with self.streams_lock:
|
||||
if stream_id in self.streams:
|
||||
del self.streams[stream_id]
|
||||
del self.streams_msg_channels[stream_id]
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
if not self.event_shutting_down.is_set():
|
||||
self.event_shutting_down.set()
|
||||
async with self.streams_lock:
|
||||
for stream in self.streams.values():
|
||||
for stream_id, stream in self.streams.items():
|
||||
async with stream.close_lock:
|
||||
if not stream.event_remote_closed.is_set():
|
||||
stream.event_remote_closed.set()
|
||||
stream.event_reset.set()
|
||||
stream.event_local_closed.set()
|
||||
send_channel = self.streams_msg_channels[stream_id]
|
||||
await send_channel.aclose()
|
||||
self.streams = None
|
||||
self.event_closed.set()
|
||||
await self.new_stream_send_channel.aclose()
|
||||
await self.new_stream_receive_channel.aclose()
|
||||
|
|
|
@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
|
|||
import trio
|
||||
|
||||
from libp2p.stream_muxer.abc import IMuxedStream
|
||||
from libp2p.utils import IQueue, TrioQueue
|
||||
|
||||
from .constants import HeaderTags
|
||||
from .datastructures import StreamID
|
||||
|
@ -24,10 +23,11 @@ class MplexStream(IMuxedStream):
|
|||
read_deadline: int
|
||||
write_deadline: int
|
||||
|
||||
# TODO: Add lock for read/write to avoid interleaving receiving messages?
|
||||
close_lock: trio.Lock
|
||||
|
||||
# NOTE: `dataIn` is size of 8 in Go implementation.
|
||||
incoming_data: IQueue[bytes]
|
||||
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"
|
||||
|
||||
event_local_closed: trio.Event
|
||||
event_remote_closed: trio.Event
|
||||
|
@ -35,7 +35,13 @@ class MplexStream(IMuxedStream):
|
|||
|
||||
_buf: bytearray
|
||||
|
||||
def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
stream_id: StreamID,
|
||||
muxed_conn: "Mplex",
|
||||
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]",
|
||||
) -> None:
|
||||
"""
|
||||
create new MuxedStream in muxer.
|
||||
|
||||
|
@ -51,13 +57,30 @@ class MplexStream(IMuxedStream):
|
|||
self.event_remote_closed = trio.Event()
|
||||
self.event_reset = trio.Event()
|
||||
self.close_lock = trio.Lock()
|
||||
self.incoming_data = TrioQueue()
|
||||
self.incoming_data_channel = incoming_data_channel
|
||||
self._buf = bytearray()
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
return self.stream_id.is_initiator
|
||||
|
||||
async def _read_until_eof(self) -> bytes:
|
||||
async for data in self.incoming_data_channel:
|
||||
self._buf.extend(data)
|
||||
payload = self._buf
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
|
||||
def _read_return_when_blocked(self) -> bytes:
|
||||
buf = bytearray()
|
||||
while True:
|
||||
try:
|
||||
data = self.incoming_data_channel.receive_nowait()
|
||||
buf.extend(data)
|
||||
except (trio.WouldBlock, trio.EndOfChannel):
|
||||
break
|
||||
return buf
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
|
||||
|
@ -73,7 +96,40 @@ class MplexStream(IMuxedStream):
|
|||
)
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
return await self.incoming_data.get()
|
||||
if n == -1:
|
||||
return await self._read_until_eof()
|
||||
if len(self._buf) == 0:
|
||||
data: bytes
|
||||
# Peek whether there is data available. If yes, we just read until there is no data,
|
||||
# and then return.
|
||||
try:
|
||||
data = self.incoming_data_channel.receive_nowait()
|
||||
except trio.EndOfChannel:
|
||||
raise MplexStreamEOF
|
||||
except trio.WouldBlock:
|
||||
# We know `receive` will be blocked here. Wait for data here with `receive` and
|
||||
# catch all kinds of errors here.
|
||||
try:
|
||||
data = await self.incoming_data_channel.receive()
|
||||
except trio.EndOfChannel:
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF
|
||||
except trio.ClosedResourceError as error:
|
||||
# Probably `incoming_data_channel` is closed in `reset` when we are waiting
|
||||
# for `receive`.
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
raise Exception(
|
||||
"`incoming_data_channel` is closed but stream is not reset. "
|
||||
"This should never happen."
|
||||
) from error
|
||||
self._buf.extend(data)
|
||||
self._buf.extend(self._read_return_when_blocked())
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
"""
|
||||
|
@ -99,22 +155,26 @@ class MplexStream(IMuxedStream):
|
|||
if self.event_local_closed.is_set():
|
||||
return
|
||||
|
||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=0")
|
||||
flag = (
|
||||
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
|
||||
)
|
||||
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
|
||||
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||
|
||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=1")
|
||||
_is_remote_closed: bool
|
||||
async with self.close_lock:
|
||||
self.event_local_closed.set()
|
||||
_is_remote_closed = self.event_remote_closed.is_set()
|
||||
|
||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=2")
|
||||
if _is_remote_closed:
|
||||
# Both sides are closed, we can safely remove the buffer from the dict.
|
||||
async with self.muxed_conn.streams_lock:
|
||||
if self.stream_id in self.muxed_conn.streams:
|
||||
del self.muxed_conn.streams[self.stream_id]
|
||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=3")
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""closes both ends of the stream tells this remote side to hang up."""
|
||||
|
@ -132,14 +192,15 @@ class MplexStream(IMuxedStream):
|
|||
if self.is_initiator
|
||||
else HeaderTags.ResetReceiver
|
||||
)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
self.muxed_conn.manager.run_task(
|
||||
self.muxed_conn.send_message, flag, None, self.stream_id
|
||||
)
|
||||
|
||||
self.event_local_closed.set()
|
||||
self.event_remote_closed.set()
|
||||
|
||||
await self.incoming_data_channel.aclose()
|
||||
|
||||
async with self.muxed_conn.streams_lock:
|
||||
if (
|
||||
self.muxed_conn.streams is not None
|
||||
|
|
|
@ -205,7 +205,7 @@ async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, Mplex
|
|||
stream_1: MplexStream
|
||||
async with mplex_conn_1.streams_lock:
|
||||
if len(mplex_conn_1.streams) != 1:
|
||||
raise Exception("Mplex should not have any stream upon connection")
|
||||
raise Exception("Mplex should not have any other stream")
|
||||
stream_1 = tuple(mplex_conn_1.streams.values())[0]
|
||||
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)
|
||||
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
import itertools
|
||||
import math
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.exceptions import ParseError
|
||||
from libp2p.io.abc import Reader
|
||||
|
@ -98,25 +95,3 @@ async def read_fixedint_prefixed(reader: Reader) -> bytes:
|
|||
len_bytes = await reader.read(SIZE_LEN_BYTES)
|
||||
len_int = int.from_bytes(len_bytes, "big")
|
||||
return await reader.read(len_int)
|
||||
|
||||
|
||||
TItem = TypeVar("TItem")
|
||||
|
||||
|
||||
class IQueue(Generic[TItem]):
|
||||
async def put(self, item: TItem):
|
||||
...
|
||||
|
||||
async def get(self) -> TItem:
|
||||
...
|
||||
|
||||
|
||||
class TrioQueue(IQueue):
|
||||
def __init__(self):
|
||||
self.send_channel, self.receive_channel = trio.open_memory_channel(0)
|
||||
|
||||
async def put(self, item: TItem):
|
||||
await self.send_channel.send(item)
|
||||
|
||||
async def get(self) -> TItem:
|
||||
return await self.receive_channel.receive()
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
from libp2p.tools.factories import HostFactory
|
||||
|
||||
|
||||
|
@ -17,17 +14,6 @@ def num_hosts():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def hosts(num_hosts, is_host_secure):
|
||||
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
|
||||
await asyncio.gather(
|
||||
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
|
||||
)
|
||||
try:
|
||||
async def hosts(num_hosts, is_host_secure, nursery):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts:
|
||||
yield _hosts
|
||||
finally:
|
||||
# TODO: It's possible that `close` raises exceptions currently,
|
||||
# due to the connection reset things. Though we don't care much about that when
|
||||
# cleaning up the tasks, it is probably better to handle the exceptions properly.
|
||||
await asyncio.gather(
|
||||
*[_host.close() for _host in _hosts], return_exceptions=True
|
||||
)
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.tools.factories import (
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import asyncio
|
||||
import trio
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_conn(mplex_conn_pair):
|
||||
conn_0, conn_1 = mplex_conn_pair
|
||||
|
||||
|
@ -16,19 +16,19 @@ async def test_mplex_conn(mplex_conn_pair):
|
|||
|
||||
# Test: Open a stream, and both side get 1 more stream.
|
||||
stream_0 = await conn_0.open_stream()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert len(conn_0.streams) == 1
|
||||
assert len(conn_1.streams) == 1
|
||||
# Test: From another side.
|
||||
stream_1 = await conn_1.open_stream()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert len(conn_0.streams) == 2
|
||||
assert len(conn_1.streams) == 2
|
||||
|
||||
# Close from one side.
|
||||
await conn_0.close()
|
||||
# Sleep for a while for both side to handle `close`.
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
# Test: Both side is closed.
|
||||
assert conn_0.event_shutting_down.is_set()
|
||||
assert conn_0.event_closed.is_set()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
import trio
|
||||
from trio.testing import wait_all_tasks_blocked
|
||||
|
||||
from libp2p.stream_muxer.mplex.exceptions import (
|
||||
MplexStreamClosed,
|
||||
|
@ -37,10 +38,10 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
|||
async def read_until_eof():
|
||||
read_bytes.extend(await stream_1.read())
|
||||
|
||||
task = trio.ensure_future(read_until_eof())
|
||||
|
||||
expected_data = bytearray()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(read_until_eof)
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
|
@ -54,10 +55,8 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
|||
|
||||
# Test: Close the stream, `read` returns, and receive previous sent data.
|
||||
await stream_0.close()
|
||||
await trio.sleep(0.01)
|
||||
assert read_bytes == expected_data
|
||||
|
||||
task.cancel()
|
||||
assert read_bytes == expected_data
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
|
@ -65,9 +64,39 @@ async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
|
|||
stream_0, stream_1 = mplex_stream_pair
|
||||
assert not stream_1.event_remote_closed.is_set()
|
||||
await stream_0.write(DATA)
|
||||
await stream_0.close()
|
||||
assert not stream_0.event_local_closed.is_set()
|
||||
await trio.sleep(0.01)
|
||||
await wait_all_tasks_blocked()
|
||||
await stream_0.close()
|
||||
assert stream_0.event_local_closed.is_set()
|
||||
await trio.sleep(0.01)
|
||||
print(
|
||||
"!@# ",
|
||||
stream_0.muxed_conn.event_shutting_down.is_set(),
|
||||
stream_0.muxed_conn.event_closed.is_set(),
|
||||
stream_1.muxed_conn.event_shutting_down.is_set(),
|
||||
stream_1.muxed_conn.event_closed.is_set(),
|
||||
)
|
||||
# await trio.sleep(100000)
|
||||
await wait_all_tasks_blocked()
|
||||
print(
|
||||
"!@# ",
|
||||
stream_0.muxed_conn.event_shutting_down.is_set(),
|
||||
stream_0.muxed_conn.event_closed.is_set(),
|
||||
stream_1.muxed_conn.event_shutting_down.is_set(),
|
||||
stream_1.muxed_conn.event_closed.is_set(),
|
||||
)
|
||||
print("!@# sleeping")
|
||||
print("!@# result=", stream_1.event_remote_closed.is_set())
|
||||
# await trio.sleep_forever()
|
||||
assert stream_1.event_remote_closed.is_set()
|
||||
print(
|
||||
"!@# ",
|
||||
stream_0.muxed_conn.event_shutting_down.is_set(),
|
||||
stream_0.muxed_conn.event_closed.is_set(),
|
||||
stream_1.muxed_conn.event_shutting_down.is_set(),
|
||||
stream_1.muxed_conn.event_closed.is_set(),
|
||||
)
|
||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||
with pytest.raises(MplexStreamEOF):
|
||||
await stream_1.read(MAX_READ_LEN)
|
||||
|
@ -87,7 +116,8 @@ async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
|
|||
await stream_0.write(DATA)
|
||||
await stream_0.reset()
|
||||
# Sleep to let `stream_1` receive the message.
|
||||
await trio.sleep(0.01)
|
||||
await trio.sleep(0.1)
|
||||
await wait_all_tasks_blocked()
|
||||
with pytest.raises(MplexStreamReset):
|
||||
await stream_1.read(MAX_READ_LEN)
|
||||
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.utils import TrioQueue
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_queue():
|
||||
queue = TrioQueue()
|
||||
|
||||
async def queue_get(task_status=None):
|
||||
result = await queue.get()
|
||||
task_status.started(result)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(queue.put, 123)
|
||||
result = await nursery.start(queue_get)
|
||||
|
||||
assert result == 123
|
Loading…
Reference in New Issue
Block a user