diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index b72fd25..50d09e7 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream import NetStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.exceptions import MuxedConnUnavailable if TYPE_CHECKING: from libp2p.network.swarm import Swarm # noqa: F401 @@ -34,17 +35,27 @@ class SwarmConn(INetConn): if self.event_closed.is_set(): return self.event_closed.set() + self.swarm.remove_conn(self) + await self.conn.close() + + # This is just for cleaning up state. The connection has already been closed. + # We *could* optimize this but it really isn't worth it. + for stream in self.streams: + await stream.reset() + # Schedule `self._notify_disconnected` to make it execute after `close` is finished. + asyncio.ensure_future(self._notify_disconnected()) + for task in self._tasks: task.cancel() - # TODO: Reset streams for local. - # TODO: Notify closed. - async def _handle_new_streams(self) -> None: # TODO: Break the loop when anything wrong in the connection. while True: - stream = await self.conn.accept_stream() + try: + stream = await self.conn.accept_stream() + except MuxedConnUnavailable: + break # Asynchronously handle the accepted stream, to avoid blocking the next stream. await self.run_task(self._handle_muxed_stream(stream)) @@ -57,11 +68,16 @@ class SwarmConn(INetConn): async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) + self.streams.add(net_stream) # Call notifiers since event occurred for notifee in self.swarm.notifees: await notifee.opened_stream(self.swarm, net_stream) return net_stream + async def _notify_disconnected(self) -> None: + for notifee in self.swarm.notifees: + await notifee.disconnected(self.swarm, self.conn) + async def start(self) -> None: await self.run_task(self._handle_new_streams()) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index dd4ca6e..5bbbe0a 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -262,7 +262,6 @@ class Swarm(INetwork): if peer_id not in self.connections: return connection = self.connections[peer_id] - del self.connections[peer_id] await connection.close() logger.debug("successfully close the connection to peer %s", peer_id) @@ -277,3 +276,10 @@ class Swarm(INetwork): await notifee.connected(self, muxed_conn) await swarm_conn.start() return swarm_conn + + def remove_conn(self, swarm_conn: SwarmConn) -> None: + print(f"!@# remove_conn: {swarm_conn}") + peer_id = swarm_conn.conn.peer_id + # TODO: Should be changed to remove the exact connection, + # if we have several connections per peer in the future. + del self.connections[peer_id] diff --git a/libp2p/stream_muxer/exceptions.py b/libp2p/stream_muxer/exceptions.py index 8db5cdc..ce0f92e 100644 --- a/libp2p/stream_muxer/exceptions.py +++ b/libp2p/stream_muxer/exceptions.py @@ -5,11 +5,7 @@ class MuxedConnError(BaseLibp2pError): pass -class MuxedConnShuttingDown(MuxedConnError): - pass - - -class MuxedConnClosed(MuxedConnError): +class MuxedConnUnavailable(MuxedConnError): pass diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index f42c561..a7be76e 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,7 +1,6 @@ from libp2p.stream_muxer.exceptions import ( - MuxedConnClosed, MuxedConnError, - MuxedConnShuttingDown, + MuxedConnUnavailable, MuxedStreamClosed, MuxedStreamEOF, MuxedStreamReset, @@ -12,11 +11,7 @@ class MplexError(MuxedConnError): pass -class MplexShuttingDown(MuxedConnShuttingDown): - pass - - -class MplexClosed(MuxedConnClosed): +class MplexUnavailable(MuxedConnUnavailable): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 589a623..7f82292 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,6 +1,6 @@ import asyncio from typing import Any # noqa: F401 -from typing import Dict, List, Optional, Tuple +from typing import Awaitable, Dict, List, Optional, Tuple from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn @@ -15,7 +15,7 @@ from libp2p.utils import ( from .constants import HeaderTags from .datastructures import StreamID -from .exceptions import MplexClosed, MplexShuttingDown +from .exceptions import MplexUnavailable from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") @@ -76,13 +76,13 @@ class Mplex(IMuxedConn): """ close the stream muxer and underlying secured connection """ - # for task in self._tasks: - # task.cancel() - await self.secured_conn.close() + if self.event_shutting_down.is_set(): + return # Set the `event_shutting_down`, to allow graceful shutdown. self.event_shutting_down.set() + await self.secured_conn.close() # Blocked until `close` is finally set. - # await self.event_closed.wait() + await self.event_closed.wait() def is_closed(self) -> bool: """ @@ -119,31 +119,29 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def _wait_until_closed(self, coro) -> Any: + async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any: task_coro = asyncio.ensure_future(coro) task_wait_closed = asyncio.ensure_future(self.event_closed.wait()) - done, pending = await asyncio.wait( - [task_coro, task_wait_closed], return_when=asyncio.FIRST_COMPLETED - ) - if task_wait_closed in done: - raise MplexClosed - return task_coro.result() - - async def _wait_until_shutting_down(self, coro) -> Any: - task_coro = asyncio.ensure_future(coro) task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait()) done, pending = await asyncio.wait( - [task_coro, task_wait_shutting_down], return_when=asyncio.FIRST_COMPLETED + [task_coro, task_wait_closed, task_wait_shutting_down], + return_when=asyncio.FIRST_COMPLETED, ) + for fut in pending: + fut.cancel() + if task_wait_closed in done: + raise MplexUnavailable("Mplex is closed") if task_wait_shutting_down in done: - raise MplexShuttingDown + raise MplexUnavailable("Mplex is shutting down") return task_coro.result() async def accept_stream(self) -> IMuxedStream: """ accepts a muxed stream opened by the other end """ - return await self._wait_until_closed(self.new_stream_queue.get()) + return await self._wait_until_shutting_down_or_closed( + self.new_stream_queue.get() + ) async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -162,7 +160,9 @@ class Mplex(IMuxedConn): _bytes = header + encode_varint_prefixed(data) - return await self.write_to_stream(_bytes) + return await self._wait_until_shutting_down_or_closed( + self.write_to_stream(_bytes) + ) async def write_to_stream(self, _bytes: bytes) -> int: """ @@ -180,7 +180,13 @@ class Mplex(IMuxedConn): # TODO Deal with other types of messages using flag (currently _) while True: - channel_id, flag, message = await self.read_message() + try: + channel_id, flag, message = await self._wait_until_shutting_down_or_closed( + self.read_message() + ) + except (MplexUnavailable, ConnectionResetError) as error: + print(f"!@# handle_incoming: read_message: exception={error}") + break if channel_id is not None and flag is not None and message is not None: stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) is_stream_id_seen: bool @@ -199,8 +205,12 @@ class Mplex(IMuxedConn): mplex_stream = await self._initialize_stream( stream_id, message.decode() ) - # TODO: Check if `self` is shutdown. - await self.new_stream_queue.put(mplex_stream) + try: + await self._wait_until_shutting_down_or_closed( + self.new_stream_queue.put(mplex_stream) + ) + except MplexUnavailable: + break elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, @@ -214,7 +224,12 @@ class Mplex(IMuxedConn): if stream.event_remote_closed.is_set(): # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 continue - await stream.incoming_data.put(message) + try: + await self._wait_until_shutting_down_or_closed( + stream.incoming_data.put(message) + ) + except MplexUnavailable: + break elif flag in ( HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value, @@ -244,7 +259,6 @@ class Mplex(IMuxedConn): continue async with stream.close_lock: if not stream.event_remote_closed.is_set(): - # TODO: Why? Only if remote is not closed before then reset. stream.event_reset.set() stream.event_remote_closed.set() @@ -260,6 +274,7 @@ class Mplex(IMuxedConn): # Force context switch await asyncio.sleep(0) + await self._cleanup() async def read_message(self) -> Tuple[int, int, bytes]: """ @@ -284,3 +299,16 @@ class Mplex(IMuxedConn): channel_id = header >> 3 return channel_id, flag, message + + async def _cleanup(self) -> None: + if not self.event_shutting_down.is_set(): + self.event_shutting_down.set() + async with self.streams_lock: + for stream in self.streams.values(): + async with stream.close_lock: + if not stream.event_remote_closed.is_set(): + stream.event_remote_closed.set() + stream.event_reset.set() + stream.event_local_closed.set() + self.streams = None + self.event_closed.set() diff --git a/libp2p/utils.py b/libp2p/utils.py index c69f61b..0c1eea8 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -41,14 +41,8 @@ async def decode_uvarint_from_stream(reader: Reader) -> int: if shift > SHIFT_64_BIT_MAX: raise ParseError("TODO: better exception msg: Integer is too large...") - byte = await reader.read(1) - - try: - value = byte[0] - except IndexError: - raise ParseError( - "Unexpected end of stream while parsing LEB128 encoded integer" - ) + byte = await read_exactly(reader, 1) + value = byte[0] res += (value & LOW_MASK) << shift diff --git a/tests/interop/test_bindings.py b/tests/interop/test_bindings.py index 1189e0b..1e78ff4 100644 --- a/tests/interop/test_bindings.py +++ b/tests/interop/test_bindings.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from .utils import connect @@ -21,4 +23,5 @@ async def test_connect(hosts, p2pds): # Test: `disconnect` from Go await p2pd.control.disconnect(host.get_id()) # FIXME: Failed to handle disconnect - # assert len(host.get_network().connections) == 0 + await asyncio.sleep(0.01) + assert len(host.get_network().connections) == 0