Track tasks created in pubsub and add close()

This commit is contained in:
NIC619 2019-11-05 15:22:31 +08:00
parent 93ef36bd86
commit 5dfa29a0df
No known key found for this signature in database
GPG Key ID: 570C35F5C2D51B17

View File

@ -79,6 +79,8 @@ class Pubsub:
# TODO: Be sure it is increased atomically everytime. # TODO: Be sure it is increased atomically everytime.
counter: int # uint64 counter: int # uint64
_tasks: List["asyncio.Future[Any]"]
def __init__( def __init__(
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
) -> None: ) -> None:
@ -139,9 +141,10 @@ class Pubsub:
self.counter = time.time_ns() self.counter = time.time_ns()
self._tasks = []
# Call handle peer to keep waiting for updates to peer queue # Call handle peer to keep waiting for updates to peer queue
asyncio.ensure_future(self.handle_peer_queue()) self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
asyncio.ensure_future(self.handle_dead_peer_queue()) self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
def get_hello_packet(self) -> rpc_pb2.RPC: def get_hello_packet(self) -> rpc_pb2.RPC:
"""Generate subscription message with all topics we are subscribed to """Generate subscription message with all topics we are subscribed to
@ -174,7 +177,7 @@ class Pubsub:
logger.debug( logger.debug(
"received `publish` message %s from peer %s", msg, peer_id "received `publish` message %s from peer %s", msg, peer_id
) )
asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg)) self._tasks.append(asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg)))
if rpc_incoming.subscriptions: if rpc_incoming.subscriptions:
# deal with RPC.subscriptions # deal with RPC.subscriptions
@ -305,7 +308,7 @@ class Pubsub:
while True: while True:
peer_id: ID = await self.peer_queue.get() peer_id: ID = await self.peer_queue.get()
# Add Peer # Add Peer
asyncio.ensure_future(self._handle_new_peer(peer_id)) self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
async def handle_dead_peer_queue(self) -> None: async def handle_dead_peer_queue(self) -> None:
""" """
@ -537,3 +540,11 @@ class Pubsub:
if not self.my_topics: if not self.my_topics:
return False return False
return any(topic in self.my_topics for topic in msg.topicIDs) return any(topic in self.my_topics for topic in msg.topicIDs)
async def close(self) -> None:
for task in self._tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass