Refine Mplex.close
and SwarmConn.close
Ensure `close` cleans up things and cancel the service finally.
This commit is contained in:
parent
d847e78a83
commit
fb0519129d
|
@ -29,10 +29,19 @@ class SwarmConn(INetConn, Service):
|
|||
self.streams = set()
|
||||
self.event_closed = trio.Event()
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self.event_closed.is_set()
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.event_closed.is_set():
|
||||
return
|
||||
self.event_closed.set()
|
||||
await self._cleanup()
|
||||
# Cancel service
|
||||
await self.manager.stop()
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
self.swarm.remove_conn(self)
|
||||
|
||||
await self.muxed_conn.close()
|
||||
|
@ -51,16 +60,16 @@ class SwarmConn(INetConn, Service):
|
|||
while self.manager.is_running:
|
||||
try:
|
||||
stream = await self.muxed_conn.accept_stream()
|
||||
except MuxedConnUnavailable:
|
||||
# If there is anything wrong in the MuxedConn,
|
||||
# we should break the loop and close the connection.
|
||||
break
|
||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||
except MuxedConnUnavailable:
|
||||
break
|
||||
self.manager.run_task(self._handle_muxed_stream, stream)
|
||||
|
||||
await self.close()
|
||||
|
||||
async def _call_stream_handler(self, net_stream: NetStream) -> None:
|
||||
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
|
||||
net_stream = await self._add_stream(muxed_stream)
|
||||
if self.swarm.common_stream_handler is not None:
|
||||
try:
|
||||
await self.swarm.common_stream_handler(net_stream)
|
||||
# TODO: More exact exceptions
|
||||
|
@ -69,11 +78,6 @@ class SwarmConn(INetConn, Service):
|
|||
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
|
||||
self.remove_stream(net_stream)
|
||||
|
||||
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
|
||||
net_stream = await self._add_stream(muxed_stream)
|
||||
if self.swarm.common_stream_handler is not None:
|
||||
await self._call_stream_handler(net_stream)
|
||||
|
||||
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
||||
net_stream = NetStream(muxed_stream)
|
||||
self.streams.add(net_stream)
|
||||
|
@ -84,7 +88,8 @@ class SwarmConn(INetConn, Service):
|
|||
await self.swarm.notify_disconnected(self)
|
||||
|
||||
async def run(self) -> None:
|
||||
await self._handle_new_streams()
|
||||
self.manager.run_task(self._handle_new_streams)
|
||||
await self.manager.wait_finished()
|
||||
|
||||
async def new_stream(self) -> NetStream:
|
||||
muxed_stream = await self.muxed_conn.open_stream()
|
||||
|
|
|
@ -44,6 +44,7 @@ class Swarm(INetwork, Service):
|
|||
common_stream_handler: Optional[StreamHandlerFn]
|
||||
|
||||
notifees: List[INotifee]
|
||||
event_closed: trio.Event
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -62,6 +63,8 @@ class Swarm(INetwork, Service):
|
|||
# Create Notifee array
|
||||
self.notifees = []
|
||||
|
||||
self.event_closed = trio.Event()
|
||||
|
||||
self.common_stream_handler = None
|
||||
|
||||
async def run(self) -> None:
|
||||
|
@ -227,10 +230,19 @@ class Swarm(INetwork, Service):
|
|||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
# TODO: Prevent from new listeners and conns being added.
|
||||
if self.event_closed.is_set():
|
||||
return
|
||||
self.event_closed.set()
|
||||
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
|
||||
async with trio.open_nursery() as nursery:
|
||||
for conn in self.connections.values():
|
||||
nursery.start_soon(conn.close)
|
||||
async with trio.open_nursery() as nursery:
|
||||
for listener in self.listeners.values():
|
||||
nursery.start_soon(listener.close)
|
||||
|
||||
# Cancel tasks
|
||||
await self.manager.stop()
|
||||
await self.manager.wait_finished()
|
||||
logger.debug("swarm successfully closed")
|
||||
|
||||
async def close_peer(self, peer_id: ID) -> None:
|
||||
|
@ -270,8 +282,6 @@ class Swarm(INetwork, Service):
|
|||
|
||||
# Notifee
|
||||
|
||||
# TODO: Remeber the spawn notifying tasks and clean them up when closing.
|
||||
|
||||
def register_notifee(self, notifee: INotifee) -> None:
|
||||
"""
|
||||
:param notifee: object implementing Notifee interface
|
||||
|
|
|
@ -58,11 +58,11 @@ class TopicValidator(NamedTuple):
|
|||
|
||||
|
||||
# TODO: Add interface for Pubsub
|
||||
class BasePubsub(ABC):
|
||||
class IPubsub(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class Pubsub(BasePubsub, Service):
|
||||
class Pubsub(IPubsub, Service):
|
||||
|
||||
host: IHost
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ class IMuxedConn(ServiceAPI):
|
|||
async def close(self) -> None:
|
||||
"""close connection."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_closed(self) -> bool:
|
||||
"""
|
||||
|
|
|
@ -91,7 +91,9 @@ class Mplex(IMuxedConn, Service):
|
|||
await self.secured_conn.close()
|
||||
# Blocked until `close` is finally set.
|
||||
await self.event_closed.wait()
|
||||
await self.manager.stop()
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""
|
||||
check connection is fully closed.
|
||||
|
@ -213,10 +215,6 @@ class Mplex(IMuxedConn, Service):
|
|||
|
||||
return channel_id, flag, message
|
||||
|
||||
@property
|
||||
def _id(self) -> int:
|
||||
return 0 if self.is_initiator else 1
|
||||
|
||||
async def _handle_incoming_message(self) -> None:
|
||||
"""
|
||||
Read and handle a new incoming message.
|
||||
|
|
|
@ -156,26 +156,22 @@ class MplexStream(IMuxedStream):
|
|||
if self.event_local_closed.is_set():
|
||||
return
|
||||
|
||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=0")
|
||||
flag = (
|
||||
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
|
||||
)
|
||||
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
|
||||
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||
|
||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=1")
|
||||
_is_remote_closed: bool
|
||||
async with self.close_lock:
|
||||
self.event_local_closed.set()
|
||||
_is_remote_closed = self.event_remote_closed.is_set()
|
||||
|
||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=2")
|
||||
if _is_remote_closed:
|
||||
# Both sides are closed, we can safely remove the buffer from the dict.
|
||||
async with self.muxed_conn.streams_lock:
|
||||
if self.stream_id in self.muxed_conn.streams:
|
||||
del self.muxed_conn.streams[self.stream_id]
|
||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=3")
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""closes both ends of the stream tells this remote side to hang up."""
|
||||
|
|
|
@ -69,9 +69,8 @@ async def raw_conn_factory(
|
|||
tcp_transport = TCP()
|
||||
listener = tcp_transport.create_listener(tcp_stream_handler)
|
||||
await listener.listen(LISTEN_MADDR, nursery)
|
||||
listening_maddr = listener.multiaddrs[0]
|
||||
listening_maddr = listener.get_addrs()[0]
|
||||
conn_0 = await tcp_transport.dial(listening_maddr)
|
||||
print("raw_conn_factory")
|
||||
yield conn_0, conn_1
|
||||
|
||||
|
||||
|
|
|
@ -39,3 +39,6 @@ def create_echo_stream_handler(
|
|||
await stream.write(resp.encode())
|
||||
|
||||
return echo_stream_handler
|
||||
|
||||
|
||||
# TODO: Service `external_api`
|
||||
|
|
|
@ -22,3 +22,7 @@ class IListener(ABC):
|
|||
|
||||
:return: return list of addrs
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
...
|
||||
|
|
|
@ -16,10 +16,10 @@ logger = logging.getLogger("libp2p.transport.tcp")
|
|||
|
||||
|
||||
class TCPListener(IListener):
|
||||
multiaddrs: List[Multiaddr]
|
||||
listeners: List[trio.SocketListener]
|
||||
|
||||
def __init__(self, handler_function: THandler) -> None:
|
||||
self.multiaddrs = []
|
||||
self.listeners = []
|
||||
self.handler = handler_function
|
||||
|
||||
# TODO: Get rid of `nursery`?
|
||||
|
@ -50,8 +50,7 @@ class TCPListener(IListener):
|
|||
int(maddr.value_for_protocol("tcp")),
|
||||
maddr.value_for_protocol("ip4"),
|
||||
)
|
||||
socket = listeners[0].socket
|
||||
self.multiaddrs.append(_multiaddr_from_socket(socket))
|
||||
self.listeners.extend(listeners)
|
||||
|
||||
def get_addrs(self) -> Tuple[Multiaddr, ...]:
|
||||
"""
|
||||
|
@ -59,7 +58,14 @@ class TCPListener(IListener):
|
|||
|
||||
:return: return list of addrs
|
||||
"""
|
||||
return tuple(self.multiaddrs)
|
||||
return tuple(
|
||||
_multiaddr_from_socket(listener.socket) for listener in self.listeners
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for listener in self.listeners:
|
||||
nursery.start_soon(listener.aclose)
|
||||
|
||||
|
||||
class TCP(ITransport):
|
||||
|
|
|
@ -1,20 +1,23 @@
|
|||
import pytest
|
||||
import trio
|
||||
from trio.testing import wait_all_tasks_blocked
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_conn_close(swarm_conn_pair):
|
||||
conn_0, conn_1 = swarm_conn_pair
|
||||
|
||||
assert not conn_0.event_closed.is_set()
|
||||
assert not conn_1.event_closed.is_set()
|
||||
assert not conn_0.is_closed
|
||||
assert not conn_1.is_closed
|
||||
|
||||
await conn_0.close()
|
||||
|
||||
await trio.sleep(0.01)
|
||||
await trio.sleep(0.1)
|
||||
await wait_all_tasks_blocked()
|
||||
await conn_0.manager.wait_finished()
|
||||
|
||||
assert conn_0.event_closed.is_set()
|
||||
assert conn_1.event_closed.is_set()
|
||||
assert conn_0.is_closed
|
||||
assert conn_1.is_closed
|
||||
assert conn_0 not in conn_0.swarm.connections.values()
|
||||
assert conn_1 not in conn_1.swarm.connections.values()
|
||||
|
||||
|
|
|
@ -8,10 +8,6 @@ async def test_mplex_conn(mplex_conn_pair):
|
|||
|
||||
assert len(conn_0.streams) == 0
|
||||
assert len(conn_1.streams) == 0
|
||||
assert not conn_0.event_shutting_down.is_set()
|
||||
assert not conn_1.event_shutting_down.is_set()
|
||||
assert not conn_0.event_closed.is_set()
|
||||
assert not conn_1.event_closed.is_set()
|
||||
|
||||
# Test: Open a stream, and both side get 1 more stream.
|
||||
stream_0 = await conn_0.open_stream()
|
||||
|
@ -29,10 +25,8 @@ async def test_mplex_conn(mplex_conn_pair):
|
|||
# Sleep for a while for both side to handle `close`.
|
||||
await trio.sleep(0.01)
|
||||
# Test: Both side is closed.
|
||||
assert conn_0.event_shutting_down.is_set()
|
||||
assert conn_0.event_closed.is_set()
|
||||
assert conn_1.event_shutting_down.is_set()
|
||||
assert conn_1.event_closed.is_set()
|
||||
assert conn_0.is_closed
|
||||
assert conn_1.is_closed
|
||||
# Test: All streams should have been closed.
|
||||
assert stream_0.event_remote_closed.is_set()
|
||||
assert stream_0.event_reset.is_set()
|
||||
|
|
|
@ -38,8 +38,9 @@ async def test_tcp_dial(nursery):
|
|||
|
||||
listener = transport.create_listener(handler)
|
||||
await listener.listen(LISTEN_MADDR, nursery)
|
||||
assert len(listener.multiaddrs) == 1
|
||||
listen_addr = listener.multiaddrs[0]
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) == 1
|
||||
listen_addr = addrs[0]
|
||||
raw_conn = await transport.dial(listen_addr)
|
||||
|
||||
data = b"123"
|
||||
|
|
Loading…
Reference in New Issue
Block a user