Refine Mplex.close and SwarmConn.close

Ensure `close` cleans up things and cancel the service finally.
This commit is contained in:
mhchia 2019-12-17 15:50:55 +08:00
parent d847e78a83
commit fb0519129d
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
13 changed files with 71 additions and 51 deletions

View File

@ -29,10 +29,19 @@ class SwarmConn(INetConn, Service):
self.streams = set() self.streams = set()
self.event_closed = trio.Event() self.event_closed = trio.Event()
@property
def is_closed(self) -> bool:
return self.event_closed.is_set()
async def close(self) -> None: async def close(self) -> None:
if self.event_closed.is_set(): if self.event_closed.is_set():
return return
self.event_closed.set() self.event_closed.set()
await self._cleanup()
# Cancel service
await self.manager.stop()
async def _cleanup(self) -> None:
self.swarm.remove_conn(self) self.swarm.remove_conn(self)
await self.muxed_conn.close() await self.muxed_conn.close()
@ -51,16 +60,16 @@ class SwarmConn(INetConn, Service):
while self.manager.is_running: while self.manager.is_running:
try: try:
stream = await self.muxed_conn.accept_stream() stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
# we should break the loop and close the connection.
break
# Asynchronously handle the accepted stream, to avoid blocking the next stream. # Asynchronously handle the accepted stream, to avoid blocking the next stream.
except MuxedConnUnavailable:
break
self.manager.run_task(self._handle_muxed_stream, stream) self.manager.run_task(self._handle_muxed_stream, stream)
await self.close() await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None: async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = await self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None:
try: try:
await self.swarm.common_stream_handler(net_stream) await self.swarm.common_stream_handler(net_stream)
# TODO: More exact exceptions # TODO: More exact exceptions
@ -69,11 +78,6 @@ class SwarmConn(INetConn, Service):
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
self.remove_stream(net_stream) self.remove_stream(net_stream)
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = await self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None:
await self._call_stream_handler(net_stream)
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)
self.streams.add(net_stream) self.streams.add(net_stream)
@ -84,7 +88,8 @@ class SwarmConn(INetConn, Service):
await self.swarm.notify_disconnected(self) await self.swarm.notify_disconnected(self)
async def run(self) -> None: async def run(self) -> None:
await self._handle_new_streams() self.manager.run_task(self._handle_new_streams)
await self.manager.wait_finished()
async def new_stream(self) -> NetStream: async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream() muxed_stream = await self.muxed_conn.open_stream()

View File

@ -44,6 +44,7 @@ class Swarm(INetwork, Service):
common_stream_handler: Optional[StreamHandlerFn] common_stream_handler: Optional[StreamHandlerFn]
notifees: List[INotifee] notifees: List[INotifee]
event_closed: trio.Event
def __init__( def __init__(
self, self,
@ -62,6 +63,8 @@ class Swarm(INetwork, Service):
# Create Notifee array # Create Notifee array
self.notifees = [] self.notifees = []
self.event_closed = trio.Event()
self.common_stream_handler = None self.common_stream_handler = None
async def run(self) -> None: async def run(self) -> None:
@ -227,10 +230,19 @@ class Swarm(INetwork, Service):
return False return False
async def close(self) -> None: async def close(self) -> None:
# TODO: Prevent from new listeners and conns being added. if self.event_closed.is_set():
return
self.event_closed.set()
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
async with trio.open_nursery() as nursery:
for conn in self.connections.values():
nursery.start_soon(conn.close)
async with trio.open_nursery() as nursery:
for listener in self.listeners.values():
nursery.start_soon(listener.close)
# Cancel tasks
await self.manager.stop() await self.manager.stop()
await self.manager.wait_finished()
logger.debug("swarm successfully closed") logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None: async def close_peer(self, peer_id: ID) -> None:
@ -270,8 +282,6 @@ class Swarm(INetwork, Service):
# Notifee # Notifee
# TODO: Remeber the spawn notifying tasks and clean them up when closing.
def register_notifee(self, notifee: INotifee) -> None: def register_notifee(self, notifee: INotifee) -> None:
""" """
:param notifee: object implementing Notifee interface :param notifee: object implementing Notifee interface

View File

@ -58,11 +58,11 @@ class TopicValidator(NamedTuple):
# TODO: Add interface for Pubsub # TODO: Add interface for Pubsub
class BasePubsub(ABC): class IPubsub(ABC):
pass pass
class Pubsub(BasePubsub, Service): class Pubsub(IPubsub, Service):
host: IHost host: IHost

View File

@ -33,6 +33,7 @@ class IMuxedConn(ServiceAPI):
async def close(self) -> None: async def close(self) -> None:
"""close connection.""" """close connection."""
@property
@abstractmethod @abstractmethod
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """

View File

@ -91,7 +91,9 @@ class Mplex(IMuxedConn, Service):
await self.secured_conn.close() await self.secured_conn.close()
# Blocked until `close` is finally set. # Blocked until `close` is finally set.
await self.event_closed.wait() await self.event_closed.wait()
await self.manager.stop()
@property
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
check connection is fully closed. check connection is fully closed.
@ -213,10 +215,6 @@ class Mplex(IMuxedConn, Service):
return channel_id, flag, message return channel_id, flag, message
@property
def _id(self) -> int:
return 0 if self.is_initiator else 1
async def _handle_incoming_message(self) -> None: async def _handle_incoming_message(self) -> None:
""" """
Read and handle a new incoming message. Read and handle a new incoming message.

View File

@ -156,26 +156,22 @@ class MplexStream(IMuxedStream):
if self.event_local_closed.is_set(): if self.event_local_closed.is_set():
return return
print(f"!@# stream.close: {self.muxed_conn._id}: step=0")
flag = ( flag = (
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
) )
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
await self.muxed_conn.send_message(flag, None, self.stream_id) await self.muxed_conn.send_message(flag, None, self.stream_id)
print(f"!@# stream.close: {self.muxed_conn._id}: step=1")
_is_remote_closed: bool _is_remote_closed: bool
async with self.close_lock: async with self.close_lock:
self.event_local_closed.set() self.event_local_closed.set()
_is_remote_closed = self.event_remote_closed.is_set() _is_remote_closed = self.event_remote_closed.is_set()
print(f"!@# stream.close: {self.muxed_conn._id}: step=2")
if _is_remote_closed: if _is_remote_closed:
# Both sides are closed, we can safely remove the buffer from the dict. # Both sides are closed, we can safely remove the buffer from the dict.
async with self.muxed_conn.streams_lock: async with self.muxed_conn.streams_lock:
if self.stream_id in self.muxed_conn.streams: if self.stream_id in self.muxed_conn.streams:
del self.muxed_conn.streams[self.stream_id] del self.muxed_conn.streams[self.stream_id]
print(f"!@# stream.close: {self.muxed_conn._id}: step=3")
async def reset(self) -> None: async def reset(self) -> None:
"""closes both ends of the stream tells this remote side to hang up.""" """closes both ends of the stream tells this remote side to hang up."""

View File

@ -69,9 +69,8 @@ async def raw_conn_factory(
tcp_transport = TCP() tcp_transport = TCP()
listener = tcp_transport.create_listener(tcp_stream_handler) listener = tcp_transport.create_listener(tcp_stream_handler)
await listener.listen(LISTEN_MADDR, nursery) await listener.listen(LISTEN_MADDR, nursery)
listening_maddr = listener.multiaddrs[0] listening_maddr = listener.get_addrs()[0]
conn_0 = await tcp_transport.dial(listening_maddr) conn_0 = await tcp_transport.dial(listening_maddr)
print("raw_conn_factory")
yield conn_0, conn_1 yield conn_0, conn_1

View File

@ -39,3 +39,6 @@ def create_echo_stream_handler(
await stream.write(resp.encode()) await stream.write(resp.encode())
return echo_stream_handler return echo_stream_handler
# TODO: Service `external_api`

View File

@ -22,3 +22,7 @@ class IListener(ABC):
:return: return list of addrs :return: return list of addrs
""" """
@abstractmethod
async def close(self) -> None:
...

View File

@ -16,10 +16,10 @@ logger = logging.getLogger("libp2p.transport.tcp")
class TCPListener(IListener): class TCPListener(IListener):
multiaddrs: List[Multiaddr] listeners: List[trio.SocketListener]
def __init__(self, handler_function: THandler) -> None: def __init__(self, handler_function: THandler) -> None:
self.multiaddrs = [] self.listeners = []
self.handler = handler_function self.handler = handler_function
# TODO: Get rid of `nursery`? # TODO: Get rid of `nursery`?
@ -50,8 +50,7 @@ class TCPListener(IListener):
int(maddr.value_for_protocol("tcp")), int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"), maddr.value_for_protocol("ip4"),
) )
socket = listeners[0].socket self.listeners.extend(listeners)
self.multiaddrs.append(_multiaddr_from_socket(socket))
def get_addrs(self) -> Tuple[Multiaddr, ...]: def get_addrs(self) -> Tuple[Multiaddr, ...]:
""" """
@ -59,7 +58,14 @@ class TCPListener(IListener):
:return: return list of addrs :return: return list of addrs
""" """
return tuple(self.multiaddrs) return tuple(
_multiaddr_from_socket(listener.socket) for listener in self.listeners
)
async def close(self) -> None:
async with trio.open_nursery() as nursery:
for listener in self.listeners:
nursery.start_soon(listener.aclose)
class TCP(ITransport): class TCP(ITransport):

View File

@ -1,20 +1,23 @@
import pytest import pytest
import trio import trio
from trio.testing import wait_all_tasks_blocked
@pytest.mark.trio @pytest.mark.trio
async def test_swarm_conn_close(swarm_conn_pair): async def test_swarm_conn_close(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair conn_0, conn_1 = swarm_conn_pair
assert not conn_0.event_closed.is_set() assert not conn_0.is_closed
assert not conn_1.event_closed.is_set() assert not conn_1.is_closed
await conn_0.close() await conn_0.close()
await trio.sleep(0.01) await trio.sleep(0.1)
await wait_all_tasks_blocked()
await conn_0.manager.wait_finished()
assert conn_0.event_closed.is_set() assert conn_0.is_closed
assert conn_1.event_closed.is_set() assert conn_1.is_closed
assert conn_0 not in conn_0.swarm.connections.values() assert conn_0 not in conn_0.swarm.connections.values()
assert conn_1 not in conn_1.swarm.connections.values() assert conn_1 not in conn_1.swarm.connections.values()

View File

@ -8,10 +8,6 @@ async def test_mplex_conn(mplex_conn_pair):
assert len(conn_0.streams) == 0 assert len(conn_0.streams) == 0
assert len(conn_1.streams) == 0 assert len(conn_1.streams) == 0
assert not conn_0.event_shutting_down.is_set()
assert not conn_1.event_shutting_down.is_set()
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
# Test: Open a stream, and both side get 1 more stream. # Test: Open a stream, and both side get 1 more stream.
stream_0 = await conn_0.open_stream() stream_0 = await conn_0.open_stream()
@ -29,10 +25,8 @@ async def test_mplex_conn(mplex_conn_pair):
# Sleep for a while for both side to handle `close`. # Sleep for a while for both side to handle `close`.
await trio.sleep(0.01) await trio.sleep(0.01)
# Test: Both side is closed. # Test: Both side is closed.
assert conn_0.event_shutting_down.is_set() assert conn_0.is_closed
assert conn_0.event_closed.is_set() assert conn_1.is_closed
assert conn_1.event_shutting_down.is_set()
assert conn_1.event_closed.is_set()
# Test: All streams should have been closed. # Test: All streams should have been closed.
assert stream_0.event_remote_closed.is_set() assert stream_0.event_remote_closed.is_set()
assert stream_0.event_reset.is_set() assert stream_0.event_reset.is_set()

View File

@ -38,8 +38,9 @@ async def test_tcp_dial(nursery):
listener = transport.create_listener(handler) listener = transport.create_listener(handler)
await listener.listen(LISTEN_MADDR, nursery) await listener.listen(LISTEN_MADDR, nursery)
assert len(listener.multiaddrs) == 1 addrs = listener.get_addrs()
listen_addr = listener.multiaddrs[0] assert len(addrs) == 1
listen_addr = addrs[0]
raw_conn = await transport.dial(listen_addr) raw_conn = await transport.dial(listen_addr)
data = b"123" data = b"123"