Track tasks created in pubsub and add close()

This commit is contained in:
NIC619 2019-11-05 15:22:31 +08:00
parent 93ef36bd86
commit 5dfa29a0df
No known key found for this signature in database
GPG Key ID: 570C35F5C2D51B17

View File

@ -79,6 +79,8 @@ class Pubsub:
# TODO: Be sure it is increased atomically everytime.
counter: int # uint64
_tasks: List["asyncio.Future[Any]"]
def __init__(
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
) -> None:
@ -139,9 +141,10 @@ class Pubsub:
self.counter = time.time_ns()
self._tasks = []
# Call handle peer to keep waiting for updates to peer queue
asyncio.ensure_future(self.handle_peer_queue())
asyncio.ensure_future(self.handle_dead_peer_queue())
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
def get_hello_packet(self) -> rpc_pb2.RPC:
"""Generate subscription message with all topics we are subscribed to
@ -174,7 +177,7 @@ class Pubsub:
logger.debug(
"received `publish` message %s from peer %s", msg, peer_id
)
asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))
self._tasks.append(asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg)))
if rpc_incoming.subscriptions:
# deal with RPC.subscriptions
@ -305,7 +308,7 @@ class Pubsub:
while True:
peer_id: ID = await self.peer_queue.get()
# Add Peer
asyncio.ensure_future(self._handle_new_peer(peer_id))
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
async def handle_dead_peer_queue(self) -> None:
"""
@ -537,3 +540,11 @@ class Pubsub:
if not self.my_topics:
return False
return any(topic in self.my_topics for topic in msg.topicIDs)
async def close(self) -> None:
for task in self._tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass