Fix Mplex and Swarm

This commit is contained in:
mhchia 2019-11-29 19:09:56 +08:00
parent ec43c25b45
commit 1e600ea7e0
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
13 changed files with 232 additions and 122 deletions

View File

@ -23,6 +23,10 @@ class TrioReadWriteCloser(ReadWriteCloser):
raise IOException(error)
async def read(self, n: int = -1) -> bytes:
if n == 0:
# Check point
await trio.sleep(0)
return b""
max_bytes = n if n != -1 else None
try:
return await self.stream.receive_some(max_bytes)

View File

@ -50,8 +50,11 @@ class SwarmConn(INetConn, Service):
await self._notify_disconnected()
async def _handle_new_streams(self) -> None:
while True:
while self.manager.is_running:
try:
print(
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams"
)
stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
@ -60,6 +63,9 @@ class SwarmConn(INetConn, Service):
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
self.manager.run_task(self._handle_muxed_stream, stream)
print(
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop"
)
await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None:

View File

@ -206,8 +206,7 @@ class Swarm(INetwork, Service):
logger.debug("successfully opened connection to peer %s", peer_id)
# FIXME: This is a intentional barrier to prevent from the handler exiting and
# closing the connection.
event = trio.Event()
await event.wait()
await trio.sleep_forever()
try:
# Success
@ -240,7 +239,7 @@ class Swarm(INetwork, Service):
# await asyncio.gather(
# *[connection.close() for connection in self.connections.values()]
# )
self.manager.stop()
await self.manager.stop()
await self.manager.wait_finished()
logger.debug("swarm successfully closed")

View File

@ -1,3 +1,4 @@
import math
import asyncio
import logging
from typing import Any # noqa: F401
@ -18,7 +19,6 @@ from libp2p.utils import (
encode_uvarint,
encode_varint_prefixed,
read_varint_prefixed_bytes,
TrioQueue,
)
from .constants import HeaderTags
@ -41,7 +41,10 @@ class Mplex(IMuxedConn, Service):
next_channel_id: int
streams: Dict[StreamID, MplexStream]
streams_lock: trio.Lock
new_stream_queue: "TrioQueue[IMuxedStream]"
streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"]
new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]"
new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]"
event_shutting_down: trio.Event
event_closed: trio.Event
@ -64,7 +67,10 @@ class Mplex(IMuxedConn, Service):
# Mapping from stream ID -> buffer of messages for that stream
self.streams = {}
self.streams_lock = trio.Lock()
self.new_stream_queue = TrioQueue()
self.streams_msg_channels = {}
send_channel, receive_channel = trio.open_memory_channel(math.inf)
self.new_stream_send_channel = send_channel
self.new_stream_receive_channel = receive_channel
self.event_shutting_down = trio.Event()
self.event_closed = trio.Event()
@ -105,9 +111,13 @@ class Mplex(IMuxedConn, Service):
return next_id
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
stream = MplexStream(name, stream_id, self)
# Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing
# `send_channel.send`.
send_channel, receive_channel = trio.open_memory_channel(math.inf)
stream = MplexStream(name, stream_id, self, receive_channel)
async with self.streams_lock:
self.streams[stream_id] = stream
self.streams_msg_channels[stream_id] = send_channel
return stream
async def open_stream(self) -> IMuxedStream:
@ -126,7 +136,10 @@ class Mplex(IMuxedConn, Service):
async def accept_stream(self) -> IMuxedStream:
"""accepts a muxed stream opened by the other end."""
return await self.new_stream_queue.get()
try:
return await self.new_stream_receive_channel.receive()
except (trio.ClosedResourceError, trio.EndOfChannel):
raise MplexUnavailable
async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -138,6 +151,9 @@ class Mplex(IMuxedConn, Service):
:param data: data to send in the message
:param stream_id: stream the message is in
"""
print(
f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}"
)
# << by 3, then or with flag
header = encode_uvarint((stream_id.channel_id << 3) | flag.value)
@ -162,14 +178,21 @@ class Mplex(IMuxedConn, Service):
"""Read a message off of the secured connection and add it to the
corresponding message buffer."""
while True:
while self.manager.is_running:
try:
print(
f"!@# handle_incoming: {self._id}: before _handle_incoming_message"
)
await self._handle_incoming_message()
print(
f"!@# handle_incoming: {self._id}: after _handle_incoming_message"
)
except MplexUnavailable as e:
logger.debug("mplex unavailable while waiting for incoming: %s", e)
print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}")
break
# Force context switch
await trio.sleep(0)
print(f"!@# handle_incoming: {self._id}: leaving")
# If we enter here, it means this connection is shutting down.
# We should clean things up.
await self._cleanup()
@ -181,51 +204,73 @@ class Mplex(IMuxedConn, Service):
:return: stream_id, flag, message contents
"""
# FIXME: No timeout is used in Go implementation.
try:
header = await decode_uvarint_from_stream(self.secured_conn)
except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable(
f"failed to read the header correctly from the underlying connection: {error}"
)
try:
message = await read_varint_prefixed_bytes(self.secured_conn)
except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable(
"failed to read messages correctly from the underlying connection"
) from error
except asyncio.TimeoutError as error:
raise MplexUnavailable(
"failed to read more message body within the timeout"
) from error
"failed to read the message body correctly from the underlying connection: "
f"{error}"
)
flag = header & 0x07
channel_id = header >> 3
return channel_id, flag, message
@property
def _id(self) -> int:
return 0 if self.is_initiator else 1
async def _handle_incoming_message(self) -> None:
"""
Read and handle a new incoming message.
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
"""
print(f"!@# _handle_incoming_message: {self._id}: before reading")
channel_id, flag, message = await self.read_message()
print(
f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}"
)
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
print(f"!@# _handle_incoming_message: {self._id}: 2")
if flag == HeaderTags.NewStream.value:
print(f"!@# _handle_incoming_message: {self._id}: 3")
await self._handle_new_stream(stream_id, message)
print(f"!@# _handle_incoming_message: {self._id}: 4")
elif flag in (
HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value,
):
print(f"!@# _handle_incoming_message: {self._id}: 5")
await self._handle_message(stream_id, message)
print(f"!@# _handle_incoming_message: {self._id}: 6")
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
print(f"!@# _handle_incoming_message: {self._id}: 7")
await self._handle_close(stream_id)
print(f"!@# _handle_incoming_message: {self._id}: 8")
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
print(f"!@# _handle_incoming_message: {self._id}: 9")
await self._handle_reset(stream_id)
print(f"!@# _handle_incoming_message: {self._id}: 10")
else:
print(f"!@# _handle_incoming_message: {self._id}: 11")
# Receives messages with an unknown flag
# TODO: logging
async with self.streams_lock:
print(f"!@# _handle_incoming_message: {self._id}: 12")
if stream_id in self.streams:
print(f"!@# _handle_incoming_message: {self._id}: 13")
stream = self.streams[stream_id]
await stream.reset()
print(f"!@# _handle_incoming_message: {self._id}: 14")
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
@ -235,43 +280,65 @@ class Mplex(IMuxedConn, Service):
f"received NewStream message for existing stream: {stream_id}"
)
mplex_stream = await self._initialize_stream(stream_id, message.decode())
await self.new_stream_queue.put(mplex_stream)
try:
await self.new_stream_send_channel.send(mplex_stream)
except (trio.BrokenResourceError, trio.EndOfChannel):
raise MplexUnavailable
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
print(
f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}"
)
async with self.streams_lock:
print(f"!@# _handle_message: {self._id}: 1")
if stream_id not in self.streams:
# We receive a message of the stream `stream_id` which is not accepted
# before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this.
print(f"!@# _handle_message: {self._id}: 2")
return
print(f"!@# _handle_message: {self._id}: 3")
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
async with stream.close_lock:
print(f"!@# _handle_message: {self._id}: 4")
if stream.event_remote_closed.is_set():
print(f"!@# _handle_message: {self._id}: 5")
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
return
await stream.incoming_data.put(message)
print(f"!@# _handle_message: {self._id}: 6")
await send_channel.send(message)
print(f"!@# _handle_message: {self._id}: 7")
async def _handle_close(self, stream_id: StreamID) -> None:
print(f"!@# _handle_close: {self._id}: step=0")
async with self.streams_lock:
if stream_id not in self.streams:
# Ignore unmatched messages for now.
return
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
print(f"!@# _handle_close: {self._id}: step=1")
await send_channel.aclose()
print(f"!@# _handle_close: {self._id}: step=2")
# NOTE: If remote is already closed, then return: Technically a bug
# on the other side. We should consider killing the connection.
async with stream.close_lock:
if stream.event_remote_closed.is_set():
return
print(f"!@# _handle_close: {self._id}: step=3")
is_local_closed: bool
async with stream.close_lock:
stream.event_remote_closed.set()
is_local_closed = stream.event_local_closed.is_set()
print(f"!@# _handle_close: {self._id}: step=4")
# If local is also closed, both sides are closed. Then, we should clean up
# the entry of this stream, to avoid others from accessing it.
if is_local_closed:
async with self.streams_lock:
if stream_id in self.streams:
del self.streams[stream_id]
print(f"!@# _handle_close: {self._id}: step=5")
async def _handle_reset(self, stream_id: StreamID) -> None:
async with self.streams_lock:
@ -279,11 +346,11 @@ class Mplex(IMuxedConn, Service):
# This is *ok*. We forget the stream on reset.
return
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
await send_channel.aclose()
async with stream.close_lock:
if not stream.event_remote_closed.is_set():
stream.event_reset.set()
stream.event_remote_closed.set()
# If local is not closed, we should close it.
if not stream.event_local_closed.is_set():
@ -291,16 +358,21 @@ class Mplex(IMuxedConn, Service):
async with self.streams_lock:
if stream_id in self.streams:
del self.streams[stream_id]
del self.streams_msg_channels[stream_id]
async def _cleanup(self) -> None:
if not self.event_shutting_down.is_set():
self.event_shutting_down.set()
async with self.streams_lock:
for stream in self.streams.values():
for stream_id, stream in self.streams.items():
async with stream.close_lock:
if not stream.event_remote_closed.is_set():
stream.event_remote_closed.set()
stream.event_reset.set()
stream.event_local_closed.set()
send_channel = self.streams_msg_channels[stream_id]
await send_channel.aclose()
self.streams = None
self.event_closed.set()
await self.new_stream_send_channel.aclose()
await self.new_stream_receive_channel.aclose()

View File

@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
import trio
from libp2p.stream_muxer.abc import IMuxedStream
from libp2p.utils import IQueue, TrioQueue
from .constants import HeaderTags
from .datastructures import StreamID
@ -24,10 +23,11 @@ class MplexStream(IMuxedStream):
read_deadline: int
write_deadline: int
# TODO: Add lock for read/write to avoid interleaving receiving messages?
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: IQueue[bytes]
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"
event_local_closed: trio.Event
event_remote_closed: trio.Event
@ -35,7 +35,13 @@ class MplexStream(IMuxedStream):
_buf: bytearray
def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None:
def __init__(
self,
name: str,
stream_id: StreamID,
muxed_conn: "Mplex",
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]",
) -> None:
"""
create new MuxedStream in muxer.
@ -51,13 +57,30 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.incoming_data = TrioQueue()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()
@property
def is_initiator(self) -> bool:
return self.stream_id.is_initiator
async def _read_until_eof(self) -> bytes:
async for data in self.incoming_data_channel:
self._buf.extend(data)
payload = self._buf
self._buf = self._buf[len(payload) :]
return bytes(payload)
def _read_return_when_blocked(self) -> bytes:
buf = bytearray()
while True:
try:
data = self.incoming_data_channel.receive_nowait()
buf.extend(data)
except (trio.WouldBlock, trio.EndOfChannel):
break
return buf
async def read(self, n: int = -1) -> bytes:
"""
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
@ -73,7 +96,40 @@ class MplexStream(IMuxedStream):
)
if self.event_reset.is_set():
raise MplexStreamReset
return await self.incoming_data.get()
if n == -1:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is no data,
# and then return.
try:
data = self.incoming_data_channel.receive_nowait()
except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with `receive` and
# catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are waiting
# for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(data)
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def write(self, data: bytes) -> int:
"""
@ -99,22 +155,26 @@ class MplexStream(IMuxedStream):
if self.event_local_closed.is_set():
return
print(f"!@# stream.close: {self.muxed_conn._id}: step=0")
flag = (
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
)
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
await self.muxed_conn.send_message(flag, None, self.stream_id)
print(f"!@# stream.close: {self.muxed_conn._id}: step=1")
_is_remote_closed: bool
async with self.close_lock:
self.event_local_closed.set()
_is_remote_closed = self.event_remote_closed.is_set()
print(f"!@# stream.close: {self.muxed_conn._id}: step=2")
if _is_remote_closed:
# Both sides are closed, we can safely remove the buffer from the dict.
async with self.muxed_conn.streams_lock:
if self.stream_id in self.muxed_conn.streams:
del self.muxed_conn.streams[self.stream_id]
print(f"!@# stream.close: {self.muxed_conn._id}: step=3")
async def reset(self) -> None:
"""closes both ends of the stream tells this remote side to hang up."""
@ -132,14 +192,15 @@ class MplexStream(IMuxedStream):
if self.is_initiator
else HeaderTags.ResetReceiver
)
async with trio.open_nursery() as nursery:
nursery.start_soon(
self.muxed_conn.manager.run_task(
self.muxed_conn.send_message, flag, None, self.stream_id
)
self.event_local_closed.set()
self.event_remote_closed.set()
await self.incoming_data_channel.aclose()
async with self.muxed_conn.streams_lock:
if (
self.muxed_conn.streams is not None

View File

@ -205,7 +205,7 @@ async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, Mplex
stream_1: MplexStream
async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any stream upon connection")
raise Exception("Mplex should not have any other stream")
stream_1 = tuple(mplex_conn_1.streams.values())[0]
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)

View File

@ -1,8 +1,5 @@
import itertools
import math
from typing import Generic, TypeVar
import trio
from libp2p.exceptions import ParseError
from libp2p.io.abc import Reader
@ -98,25 +95,3 @@ async def read_fixedint_prefixed(reader: Reader) -> bytes:
len_bytes = await reader.read(SIZE_LEN_BYTES)
len_int = int.from_bytes(len_bytes, "big")
return await reader.read(len_int)
TItem = TypeVar("TItem")
class IQueue(Generic[TItem]):
async def put(self, item: TItem):
...
async def get(self) -> TItem:
...
class TrioQueue(IQueue):
def __init__(self):
self.send_channel, self.receive_channel = trio.open_memory_channel(0)
async def put(self, item: TItem):
await self.send_channel.send(item)
async def get(self) -> TItem:
return await self.receive_channel.receive()

View File

@ -1,8 +1,5 @@
import asyncio
import pytest
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.tools.factories import HostFactory
@ -17,17 +14,6 @@ def num_hosts():
@pytest.fixture
async def hosts(num_hosts, is_host_secure):
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
await asyncio.gather(
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
)
try:
async def hosts(num_hosts, is_host_secure, nursery):
async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts:
yield _hosts
finally:
# TODO: It's possible that `close` raises exceptions currently,
# due to the connection reset things. Though we don't care much about that when
# cleaning up the tasks, it is probably better to handle the exceptions properly.
await asyncio.gather(
*[_host.close() for _host in _hosts], return_exceptions=True
)

View File

@ -1,5 +1,3 @@
import asyncio
import pytest
from libp2p.tools.factories import (

View File

@ -1,5 +1,3 @@
import asyncio
import pytest
from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory

View File

@ -1,9 +1,9 @@
import asyncio
import trio
import pytest
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_conn(mplex_conn_pair):
conn_0, conn_1 = mplex_conn_pair
@ -16,19 +16,19 @@ async def test_mplex_conn(mplex_conn_pair):
# Test: Open a stream, and both side get 1 more stream.
stream_0 = await conn_0.open_stream()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert len(conn_0.streams) == 1
assert len(conn_1.streams) == 1
# Test: From another side.
stream_1 = await conn_1.open_stream()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert len(conn_0.streams) == 2
assert len(conn_1.streams) == 2
# Close from one side.
await conn_0.close()
# Sleep for a while for both side to handle `close`.
await asyncio.sleep(0.01)
await trio.sleep(0.01)
# Test: Both side is closed.
assert conn_0.event_shutting_down.is_set()
assert conn_0.event_closed.is_set()

View File

@ -1,5 +1,6 @@
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
@ -37,10 +38,10 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
async def read_until_eof():
read_bytes.extend(await stream_1.read())
task = trio.ensure_future(read_until_eof())
expected_data = bytearray()
async with trio.open_nursery() as nursery:
nursery.start_soon(read_until_eof)
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
@ -54,10 +55,8 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
await trio.sleep(0.01)
assert read_bytes == expected_data
task.cancel()
assert read_bytes == expected_data
@pytest.mark.trio
@ -65,9 +64,39 @@ async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
assert not stream_1.event_remote_closed.is_set()
await stream_0.write(DATA)
await stream_0.close()
assert not stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
await wait_all_tasks_blocked()
await stream_0.close()
assert stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
# await trio.sleep(100000)
await wait_all_tasks_blocked()
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
print("!@# sleeping")
print("!@# result=", stream_1.event_remote_closed.is_set())
# await trio.sleep_forever()
assert stream_1.event_remote_closed.is_set()
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(MplexStreamEOF):
await stream_1.read(MAX_READ_LEN)
@ -87,7 +116,8 @@ async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
await stream_0.write(DATA)
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await trio.sleep(0.01)
await trio.sleep(0.1)
await wait_all_tasks_blocked()
with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN)

View File

@ -1,19 +0,0 @@
import pytest
import trio
from libp2p.utils import TrioQueue
@pytest.mark.trio
async def test_trio_queue():
queue = TrioQueue()
async def queue_get(task_status=None):
result = await queue.get()
task_status.started(result)
async with trio.open_nursery() as nursery:
nursery.start_soon(queue.put, 123)
result = await nursery.start(queue_get)
assert result == 123