diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 54de625..bbfcd41 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -79,6 +79,8 @@ class Pubsub: # TODO: Be sure it is increased atomically everytime. counter: int # uint64 + _tasks: List["asyncio.Future[Any]"] + def __init__( self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None ) -> None: @@ -139,9 +141,10 @@ class Pubsub: self.counter = time.time_ns() + self._tasks = [] # Call handle peer to keep waiting for updates to peer queue - asyncio.ensure_future(self.handle_peer_queue()) - asyncio.ensure_future(self.handle_dead_peer_queue()) + self._tasks.append(asyncio.ensure_future(self.handle_peer_queue())) + self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue())) def get_hello_packet(self) -> rpc_pb2.RPC: """Generate subscription message with all topics we are subscribed to @@ -174,7 +177,7 @@ class Pubsub: logger.debug( "received `publish` message %s from peer %s", msg, peer_id ) - asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg)) + self._tasks.append(asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))) if rpc_incoming.subscriptions: # deal with RPC.subscriptions @@ -305,7 +308,7 @@ class Pubsub: while True: peer_id: ID = await self.peer_queue.get() # Add Peer - asyncio.ensure_future(self._handle_new_peer(peer_id)) + self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id))) async def handle_dead_peer_queue(self) -> None: """ @@ -537,3 +540,11 @@ class Pubsub: if not self.my_topics: return False return any(topic in self.my_topics for topic in msg.topicIDs) + + async def close(self) -> None: + for task in self._tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass