Refine Mplex.close and SwarmConn.close

Ensure `close` cleans up things and cancel the service finally.
This commit is contained in:
mhchia 2019-12-17 15:50:55 +08:00
parent d847e78a83
commit fb0519129d
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
13 changed files with 71 additions and 51 deletions

View File

@ -29,10 +29,19 @@ class SwarmConn(INetConn, Service):
self.streams = set()
self.event_closed = trio.Event()
@property
def is_closed(self) -> bool:
return self.event_closed.is_set()
async def close(self) -> None:
if self.event_closed.is_set():
return
self.event_closed.set()
await self._cleanup()
# Cancel service
await self.manager.stop()
async def _cleanup(self) -> None:
self.swarm.remove_conn(self)
await self.muxed_conn.close()
@ -51,28 +60,23 @@ class SwarmConn(INetConn, Service):
while self.manager.is_running:
try:
stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
# we should break the loop and close the connection.
break
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
except MuxedConnUnavailable:
break
self.manager.run_task(self._handle_muxed_stream, stream)
await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None:
try:
await self.swarm.common_stream_handler(net_stream)
# TODO: More exact exceptions
except Exception:
# TODO: Emit logs.
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
self.remove_stream(net_stream)
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = await self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None:
await self._call_stream_handler(net_stream)
try:
await self.swarm.common_stream_handler(net_stream)
# TODO: More exact exceptions
except Exception:
# TODO: Emit logs.
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
self.remove_stream(net_stream)
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream)
@ -84,7 +88,8 @@ class SwarmConn(INetConn, Service):
await self.swarm.notify_disconnected(self)
async def run(self) -> None:
await self._handle_new_streams()
self.manager.run_task(self._handle_new_streams)
await self.manager.wait_finished()
async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream()

View File

@ -44,6 +44,7 @@ class Swarm(INetwork, Service):
common_stream_handler: Optional[StreamHandlerFn]
notifees: List[INotifee]
event_closed: trio.Event
def __init__(
self,
@ -62,6 +63,8 @@ class Swarm(INetwork, Service):
# Create Notifee array
self.notifees = []
self.event_closed = trio.Event()
self.common_stream_handler = None
async def run(self) -> None:
@ -227,10 +230,19 @@ class Swarm(INetwork, Service):
return False
async def close(self) -> None:
# TODO: Prevent from new listeners and conns being added.
if self.event_closed.is_set():
return
self.event_closed.set()
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
async with trio.open_nursery() as nursery:
for conn in self.connections.values():
nursery.start_soon(conn.close)
async with trio.open_nursery() as nursery:
for listener in self.listeners.values():
nursery.start_soon(listener.close)
# Cancel tasks
await self.manager.stop()
await self.manager.wait_finished()
logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None:
@ -270,8 +282,6 @@ class Swarm(INetwork, Service):
# Notifee
# TODO: Remeber the spawn notifying tasks and clean them up when closing.
def register_notifee(self, notifee: INotifee) -> None:
"""
:param notifee: object implementing Notifee interface

View File

@ -58,11 +58,11 @@ class TopicValidator(NamedTuple):
# TODO: Add interface for Pubsub
class BasePubsub(ABC):
class IPubsub(ABC):
pass
class Pubsub(BasePubsub, Service):
class Pubsub(IPubsub, Service):
host: IHost

View File

@ -33,6 +33,7 @@ class IMuxedConn(ServiceAPI):
async def close(self) -> None:
"""close connection."""
@property
@abstractmethod
def is_closed(self) -> bool:
"""

View File

@ -91,7 +91,9 @@ class Mplex(IMuxedConn, Service):
await self.secured_conn.close()
# Blocked until `close` is finally set.
await self.event_closed.wait()
await self.manager.stop()
@property
def is_closed(self) -> bool:
"""
check connection is fully closed.
@ -213,10 +215,6 @@ class Mplex(IMuxedConn, Service):
return channel_id, flag, message
@property
def _id(self) -> int:
return 0 if self.is_initiator else 1
async def _handle_incoming_message(self) -> None:
"""
Read and handle a new incoming message.

View File

@ -156,26 +156,22 @@ class MplexStream(IMuxedStream):
if self.event_local_closed.is_set():
return
print(f"!@# stream.close: {self.muxed_conn._id}: step=0")
flag = (
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
)
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
await self.muxed_conn.send_message(flag, None, self.stream_id)
print(f"!@# stream.close: {self.muxed_conn._id}: step=1")
_is_remote_closed: bool
async with self.close_lock:
self.event_local_closed.set()
_is_remote_closed = self.event_remote_closed.is_set()
print(f"!@# stream.close: {self.muxed_conn._id}: step=2")
if _is_remote_closed:
# Both sides are closed, we can safely remove the buffer from the dict.
async with self.muxed_conn.streams_lock:
if self.stream_id in self.muxed_conn.streams:
del self.muxed_conn.streams[self.stream_id]
print(f"!@# stream.close: {self.muxed_conn._id}: step=3")
async def reset(self) -> None:
"""closes both ends of the stream tells this remote side to hang up."""

View File

@ -69,9 +69,8 @@ async def raw_conn_factory(
tcp_transport = TCP()
listener = tcp_transport.create_listener(tcp_stream_handler)
await listener.listen(LISTEN_MADDR, nursery)
listening_maddr = listener.multiaddrs[0]
listening_maddr = listener.get_addrs()[0]
conn_0 = await tcp_transport.dial(listening_maddr)
print("raw_conn_factory")
yield conn_0, conn_1

View File

@ -39,3 +39,6 @@ def create_echo_stream_handler(
await stream.write(resp.encode())
return echo_stream_handler
# TODO: Service `external_api`

View File

@ -22,3 +22,7 @@ class IListener(ABC):
:return: return list of addrs
"""
@abstractmethod
async def close(self) -> None:
...

View File

@ -16,10 +16,10 @@ logger = logging.getLogger("libp2p.transport.tcp")
class TCPListener(IListener):
multiaddrs: List[Multiaddr]
listeners: List[trio.SocketListener]
def __init__(self, handler_function: THandler) -> None:
self.multiaddrs = []
self.listeners = []
self.handler = handler_function
# TODO: Get rid of `nursery`?
@ -50,8 +50,7 @@ class TCPListener(IListener):
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
)
socket = listeners[0].socket
self.multiaddrs.append(_multiaddr_from_socket(socket))
self.listeners.extend(listeners)
def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""
@ -59,7 +58,14 @@ class TCPListener(IListener):
:return: return list of addrs
"""
return tuple(self.multiaddrs)
return tuple(
_multiaddr_from_socket(listener.socket) for listener in self.listeners
)
async def close(self) -> None:
async with trio.open_nursery() as nursery:
for listener in self.listeners:
nursery.start_soon(listener.aclose)
class TCP(ITransport):

View File

@ -1,20 +1,23 @@
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
@pytest.mark.trio
async def test_swarm_conn_close(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
assert not conn_0.is_closed
assert not conn_1.is_closed
await conn_0.close()
await trio.sleep(0.01)
await trio.sleep(0.1)
await wait_all_tasks_blocked()
await conn_0.manager.wait_finished()
assert conn_0.event_closed.is_set()
assert conn_1.event_closed.is_set()
assert conn_0.is_closed
assert conn_1.is_closed
assert conn_0 not in conn_0.swarm.connections.values()
assert conn_1 not in conn_1.swarm.connections.values()

View File

@ -8,10 +8,6 @@ async def test_mplex_conn(mplex_conn_pair):
assert len(conn_0.streams) == 0
assert len(conn_1.streams) == 0
assert not conn_0.event_shutting_down.is_set()
assert not conn_1.event_shutting_down.is_set()
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
# Test: Open a stream, and both side get 1 more stream.
stream_0 = await conn_0.open_stream()
@ -29,10 +25,8 @@ async def test_mplex_conn(mplex_conn_pair):
# Sleep for a while for both side to handle `close`.
await trio.sleep(0.01)
# Test: Both side is closed.
assert conn_0.event_shutting_down.is_set()
assert conn_0.event_closed.is_set()
assert conn_1.event_shutting_down.is_set()
assert conn_1.event_closed.is_set()
assert conn_0.is_closed
assert conn_1.is_closed
# Test: All streams should have been closed.
assert stream_0.event_remote_closed.is_set()
assert stream_0.event_reset.is_set()

View File

@ -38,8 +38,9 @@ async def test_tcp_dial(nursery):
listener = transport.create_listener(handler)
await listener.listen(LISTEN_MADDR, nursery)
assert len(listener.multiaddrs) == 1
listen_addr = listener.multiaddrs[0]
addrs = listener.get_addrs()
assert len(addrs) == 1
listen_addr = addrs[0]
raw_conn = await transport.dial(listen_addr)
data = b"123"