Refine Mplex.close
and SwarmConn.close
Ensure `close` cleans up things and cancel the service finally.
This commit is contained in:
parent
d847e78a83
commit
fb0519129d
|
@ -29,10 +29,19 @@ class SwarmConn(INetConn, Service):
|
||||||
self.streams = set()
|
self.streams = set()
|
||||||
self.event_closed = trio.Event()
|
self.event_closed = trio.Event()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
return self.event_closed.is_set()
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if self.event_closed.is_set():
|
if self.event_closed.is_set():
|
||||||
return
|
return
|
||||||
self.event_closed.set()
|
self.event_closed.set()
|
||||||
|
await self._cleanup()
|
||||||
|
# Cancel service
|
||||||
|
await self.manager.stop()
|
||||||
|
|
||||||
|
async def _cleanup(self) -> None:
|
||||||
self.swarm.remove_conn(self)
|
self.swarm.remove_conn(self)
|
||||||
|
|
||||||
await self.muxed_conn.close()
|
await self.muxed_conn.close()
|
||||||
|
@ -51,16 +60,16 @@ class SwarmConn(INetConn, Service):
|
||||||
while self.manager.is_running:
|
while self.manager.is_running:
|
||||||
try:
|
try:
|
||||||
stream = await self.muxed_conn.accept_stream()
|
stream = await self.muxed_conn.accept_stream()
|
||||||
except MuxedConnUnavailable:
|
|
||||||
# If there is anything wrong in the MuxedConn,
|
|
||||||
# we should break the loop and close the connection.
|
|
||||||
break
|
|
||||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||||
|
except MuxedConnUnavailable:
|
||||||
|
break
|
||||||
self.manager.run_task(self._handle_muxed_stream, stream)
|
self.manager.run_task(self._handle_muxed_stream, stream)
|
||||||
|
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
async def _call_stream_handler(self, net_stream: NetStream) -> None:
|
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
|
||||||
|
net_stream = await self._add_stream(muxed_stream)
|
||||||
|
if self.swarm.common_stream_handler is not None:
|
||||||
try:
|
try:
|
||||||
await self.swarm.common_stream_handler(net_stream)
|
await self.swarm.common_stream_handler(net_stream)
|
||||||
# TODO: More exact exceptions
|
# TODO: More exact exceptions
|
||||||
|
@ -69,11 +78,6 @@ class SwarmConn(INetConn, Service):
|
||||||
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
|
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
|
||||||
self.remove_stream(net_stream)
|
self.remove_stream(net_stream)
|
||||||
|
|
||||||
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
|
|
||||||
net_stream = await self._add_stream(muxed_stream)
|
|
||||||
if self.swarm.common_stream_handler is not None:
|
|
||||||
await self._call_stream_handler(net_stream)
|
|
||||||
|
|
||||||
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
||||||
net_stream = NetStream(muxed_stream)
|
net_stream = NetStream(muxed_stream)
|
||||||
self.streams.add(net_stream)
|
self.streams.add(net_stream)
|
||||||
|
@ -84,7 +88,8 @@ class SwarmConn(INetConn, Service):
|
||||||
await self.swarm.notify_disconnected(self)
|
await self.swarm.notify_disconnected(self)
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
await self._handle_new_streams()
|
self.manager.run_task(self._handle_new_streams)
|
||||||
|
await self.manager.wait_finished()
|
||||||
|
|
||||||
async def new_stream(self) -> NetStream:
|
async def new_stream(self) -> NetStream:
|
||||||
muxed_stream = await self.muxed_conn.open_stream()
|
muxed_stream = await self.muxed_conn.open_stream()
|
||||||
|
|
|
@ -44,6 +44,7 @@ class Swarm(INetwork, Service):
|
||||||
common_stream_handler: Optional[StreamHandlerFn]
|
common_stream_handler: Optional[StreamHandlerFn]
|
||||||
|
|
||||||
notifees: List[INotifee]
|
notifees: List[INotifee]
|
||||||
|
event_closed: trio.Event
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -62,6 +63,8 @@ class Swarm(INetwork, Service):
|
||||||
# Create Notifee array
|
# Create Notifee array
|
||||||
self.notifees = []
|
self.notifees = []
|
||||||
|
|
||||||
|
self.event_closed = trio.Event()
|
||||||
|
|
||||||
self.common_stream_handler = None
|
self.common_stream_handler = None
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
|
@ -227,10 +230,19 @@ class Swarm(INetwork, Service):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
# TODO: Prevent from new listeners and conns being added.
|
if self.event_closed.is_set():
|
||||||
|
return
|
||||||
|
self.event_closed.set()
|
||||||
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
|
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for conn in self.connections.values():
|
||||||
|
nursery.start_soon(conn.close)
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for listener in self.listeners.values():
|
||||||
|
nursery.start_soon(listener.close)
|
||||||
|
|
||||||
|
# Cancel tasks
|
||||||
await self.manager.stop()
|
await self.manager.stop()
|
||||||
await self.manager.wait_finished()
|
|
||||||
logger.debug("swarm successfully closed")
|
logger.debug("swarm successfully closed")
|
||||||
|
|
||||||
async def close_peer(self, peer_id: ID) -> None:
|
async def close_peer(self, peer_id: ID) -> None:
|
||||||
|
@ -270,8 +282,6 @@ class Swarm(INetwork, Service):
|
||||||
|
|
||||||
# Notifee
|
# Notifee
|
||||||
|
|
||||||
# TODO: Remeber the spawn notifying tasks and clean them up when closing.
|
|
||||||
|
|
||||||
def register_notifee(self, notifee: INotifee) -> None:
|
def register_notifee(self, notifee: INotifee) -> None:
|
||||||
"""
|
"""
|
||||||
:param notifee: object implementing Notifee interface
|
:param notifee: object implementing Notifee interface
|
||||||
|
|
|
@ -58,11 +58,11 @@ class TopicValidator(NamedTuple):
|
||||||
|
|
||||||
|
|
||||||
# TODO: Add interface for Pubsub
|
# TODO: Add interface for Pubsub
|
||||||
class BasePubsub(ABC):
|
class IPubsub(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Pubsub(BasePubsub, Service):
|
class Pubsub(IPubsub, Service):
|
||||||
|
|
||||||
host: IHost
|
host: IHost
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ class IMuxedConn(ServiceAPI):
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""close connection."""
|
"""close connection."""
|
||||||
|
|
||||||
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -91,7 +91,9 @@ class Mplex(IMuxedConn, Service):
|
||||||
await self.secured_conn.close()
|
await self.secured_conn.close()
|
||||||
# Blocked until `close` is finally set.
|
# Blocked until `close` is finally set.
|
||||||
await self.event_closed.wait()
|
await self.event_closed.wait()
|
||||||
|
await self.manager.stop()
|
||||||
|
|
||||||
|
@property
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
"""
|
"""
|
||||||
check connection is fully closed.
|
check connection is fully closed.
|
||||||
|
@ -213,10 +215,6 @@ class Mplex(IMuxedConn, Service):
|
||||||
|
|
||||||
return channel_id, flag, message
|
return channel_id, flag, message
|
||||||
|
|
||||||
@property
|
|
||||||
def _id(self) -> int:
|
|
||||||
return 0 if self.is_initiator else 1
|
|
||||||
|
|
||||||
async def _handle_incoming_message(self) -> None:
|
async def _handle_incoming_message(self) -> None:
|
||||||
"""
|
"""
|
||||||
Read and handle a new incoming message.
|
Read and handle a new incoming message.
|
||||||
|
|
|
@ -156,26 +156,22 @@ class MplexStream(IMuxedStream):
|
||||||
if self.event_local_closed.is_set():
|
if self.event_local_closed.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=0")
|
|
||||||
flag = (
|
flag = (
|
||||||
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
|
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
|
||||||
)
|
)
|
||||||
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
|
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
|
||||||
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||||
|
|
||||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=1")
|
|
||||||
_is_remote_closed: bool
|
_is_remote_closed: bool
|
||||||
async with self.close_lock:
|
async with self.close_lock:
|
||||||
self.event_local_closed.set()
|
self.event_local_closed.set()
|
||||||
_is_remote_closed = self.event_remote_closed.is_set()
|
_is_remote_closed = self.event_remote_closed.is_set()
|
||||||
|
|
||||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=2")
|
|
||||||
if _is_remote_closed:
|
if _is_remote_closed:
|
||||||
# Both sides are closed, we can safely remove the buffer from the dict.
|
# Both sides are closed, we can safely remove the buffer from the dict.
|
||||||
async with self.muxed_conn.streams_lock:
|
async with self.muxed_conn.streams_lock:
|
||||||
if self.stream_id in self.muxed_conn.streams:
|
if self.stream_id in self.muxed_conn.streams:
|
||||||
del self.muxed_conn.streams[self.stream_id]
|
del self.muxed_conn.streams[self.stream_id]
|
||||||
print(f"!@# stream.close: {self.muxed_conn._id}: step=3")
|
|
||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
"""closes both ends of the stream tells this remote side to hang up."""
|
"""closes both ends of the stream tells this remote side to hang up."""
|
||||||
|
|
|
@ -69,9 +69,8 @@ async def raw_conn_factory(
|
||||||
tcp_transport = TCP()
|
tcp_transport = TCP()
|
||||||
listener = tcp_transport.create_listener(tcp_stream_handler)
|
listener = tcp_transport.create_listener(tcp_stream_handler)
|
||||||
await listener.listen(LISTEN_MADDR, nursery)
|
await listener.listen(LISTEN_MADDR, nursery)
|
||||||
listening_maddr = listener.multiaddrs[0]
|
listening_maddr = listener.get_addrs()[0]
|
||||||
conn_0 = await tcp_transport.dial(listening_maddr)
|
conn_0 = await tcp_transport.dial(listening_maddr)
|
||||||
print("raw_conn_factory")
|
|
||||||
yield conn_0, conn_1
|
yield conn_0, conn_1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -39,3 +39,6 @@ def create_echo_stream_handler(
|
||||||
await stream.write(resp.encode())
|
await stream.write(resp.encode())
|
||||||
|
|
||||||
return echo_stream_handler
|
return echo_stream_handler
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Service `external_api`
|
||||||
|
|
|
@ -22,3 +22,7 @@ class IListener(ABC):
|
||||||
|
|
||||||
:return: return list of addrs
|
:return: return list of addrs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def close(self) -> None:
|
||||||
|
...
|
||||||
|
|
|
@ -16,10 +16,10 @@ logger = logging.getLogger("libp2p.transport.tcp")
|
||||||
|
|
||||||
|
|
||||||
class TCPListener(IListener):
|
class TCPListener(IListener):
|
||||||
multiaddrs: List[Multiaddr]
|
listeners: List[trio.SocketListener]
|
||||||
|
|
||||||
def __init__(self, handler_function: THandler) -> None:
|
def __init__(self, handler_function: THandler) -> None:
|
||||||
self.multiaddrs = []
|
self.listeners = []
|
||||||
self.handler = handler_function
|
self.handler = handler_function
|
||||||
|
|
||||||
# TODO: Get rid of `nursery`?
|
# TODO: Get rid of `nursery`?
|
||||||
|
@ -50,8 +50,7 @@ class TCPListener(IListener):
|
||||||
int(maddr.value_for_protocol("tcp")),
|
int(maddr.value_for_protocol("tcp")),
|
||||||
maddr.value_for_protocol("ip4"),
|
maddr.value_for_protocol("ip4"),
|
||||||
)
|
)
|
||||||
socket = listeners[0].socket
|
self.listeners.extend(listeners)
|
||||||
self.multiaddrs.append(_multiaddr_from_socket(socket))
|
|
||||||
|
|
||||||
def get_addrs(self) -> Tuple[Multiaddr, ...]:
|
def get_addrs(self) -> Tuple[Multiaddr, ...]:
|
||||||
"""
|
"""
|
||||||
|
@ -59,7 +58,14 @@ class TCPListener(IListener):
|
||||||
|
|
||||||
:return: return list of addrs
|
:return: return list of addrs
|
||||||
"""
|
"""
|
||||||
return tuple(self.multiaddrs)
|
return tuple(
|
||||||
|
_multiaddr_from_socket(listener.socket) for listener in self.listeners
|
||||||
|
)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for listener in self.listeners:
|
||||||
|
nursery.start_soon(listener.aclose)
|
||||||
|
|
||||||
|
|
||||||
class TCP(ITransport):
|
class TCP(ITransport):
|
||||||
|
|
|
@ -1,20 +1,23 @@
|
||||||
import pytest
|
import pytest
|
||||||
import trio
|
import trio
|
||||||
|
from trio.testing import wait_all_tasks_blocked
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_swarm_conn_close(swarm_conn_pair):
|
async def test_swarm_conn_close(swarm_conn_pair):
|
||||||
conn_0, conn_1 = swarm_conn_pair
|
conn_0, conn_1 = swarm_conn_pair
|
||||||
|
|
||||||
assert not conn_0.event_closed.is_set()
|
assert not conn_0.is_closed
|
||||||
assert not conn_1.event_closed.is_set()
|
assert not conn_1.is_closed
|
||||||
|
|
||||||
await conn_0.close()
|
await conn_0.close()
|
||||||
|
|
||||||
await trio.sleep(0.01)
|
await trio.sleep(0.1)
|
||||||
|
await wait_all_tasks_blocked()
|
||||||
|
await conn_0.manager.wait_finished()
|
||||||
|
|
||||||
assert conn_0.event_closed.is_set()
|
assert conn_0.is_closed
|
||||||
assert conn_1.event_closed.is_set()
|
assert conn_1.is_closed
|
||||||
assert conn_0 not in conn_0.swarm.connections.values()
|
assert conn_0 not in conn_0.swarm.connections.values()
|
||||||
assert conn_1 not in conn_1.swarm.connections.values()
|
assert conn_1 not in conn_1.swarm.connections.values()
|
||||||
|
|
||||||
|
|
|
@ -8,10 +8,6 @@ async def test_mplex_conn(mplex_conn_pair):
|
||||||
|
|
||||||
assert len(conn_0.streams) == 0
|
assert len(conn_0.streams) == 0
|
||||||
assert len(conn_1.streams) == 0
|
assert len(conn_1.streams) == 0
|
||||||
assert not conn_0.event_shutting_down.is_set()
|
|
||||||
assert not conn_1.event_shutting_down.is_set()
|
|
||||||
assert not conn_0.event_closed.is_set()
|
|
||||||
assert not conn_1.event_closed.is_set()
|
|
||||||
|
|
||||||
# Test: Open a stream, and both side get 1 more stream.
|
# Test: Open a stream, and both side get 1 more stream.
|
||||||
stream_0 = await conn_0.open_stream()
|
stream_0 = await conn_0.open_stream()
|
||||||
|
@ -29,10 +25,8 @@ async def test_mplex_conn(mplex_conn_pair):
|
||||||
# Sleep for a while for both side to handle `close`.
|
# Sleep for a while for both side to handle `close`.
|
||||||
await trio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
# Test: Both side is closed.
|
# Test: Both side is closed.
|
||||||
assert conn_0.event_shutting_down.is_set()
|
assert conn_0.is_closed
|
||||||
assert conn_0.event_closed.is_set()
|
assert conn_1.is_closed
|
||||||
assert conn_1.event_shutting_down.is_set()
|
|
||||||
assert conn_1.event_closed.is_set()
|
|
||||||
# Test: All streams should have been closed.
|
# Test: All streams should have been closed.
|
||||||
assert stream_0.event_remote_closed.is_set()
|
assert stream_0.event_remote_closed.is_set()
|
||||||
assert stream_0.event_reset.is_set()
|
assert stream_0.event_reset.is_set()
|
||||||
|
|
|
@ -38,8 +38,9 @@ async def test_tcp_dial(nursery):
|
||||||
|
|
||||||
listener = transport.create_listener(handler)
|
listener = transport.create_listener(handler)
|
||||||
await listener.listen(LISTEN_MADDR, nursery)
|
await listener.listen(LISTEN_MADDR, nursery)
|
||||||
assert len(listener.multiaddrs) == 1
|
addrs = listener.get_addrs()
|
||||||
listen_addr = listener.multiaddrs[0]
|
assert len(addrs) == 1
|
||||||
|
listen_addr = addrs[0]
|
||||||
raw_conn = await transport.dial(listen_addr)
|
raw_conn = await transport.dial(listen_addr)
|
||||||
|
|
||||||
data = b"123"
|
data = b"123"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user