Merge pull request #340 from NIC619/fix_pubsub_stream_to_disconnected_peer
Register for disconnected event notification by pubsub
This commit is contained in:
commit
a5c3b8dec2
|
@ -43,6 +43,9 @@ class SwarmConn(INetConn):
|
||||||
# We *could* optimize this but it really isn't worth it.
|
# We *could* optimize this but it really isn't worth it.
|
||||||
for stream in self.streams:
|
for stream in self.streams:
|
||||||
await stream.reset()
|
await stream.reset()
|
||||||
|
# Force context switch for stream handlers to process the stream reset event we just emit
|
||||||
|
# before we cancel the stream handler tasks.
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
for task in self._tasks:
|
for task in self._tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
|
@ -248,8 +248,8 @@ class Swarm(INetwork):
|
||||||
# TODO: Should be changed to close multisple connections,
|
# TODO: Should be changed to close multisple connections,
|
||||||
# if we have several connections per peer in the future.
|
# if we have several connections per peer in the future.
|
||||||
connection = self.connections[peer_id]
|
connection = self.connections[peer_id]
|
||||||
# NOTE: `connection.close` performs `del self.connections[peer_id]` for us,
|
# NOTE: `connection.close` will perform `del self.connections[peer_id]`
|
||||||
# so we don't need to remove the entry here.
|
# and `notify_disconnected` for us.
|
||||||
await connection.close()
|
await connection.close()
|
||||||
|
|
||||||
logger.debug("successfully close the connection to peer %s", peer_id)
|
logger.debug("successfully close the connection to peer %s", peer_id)
|
||||||
|
|
|
@ -32,6 +32,7 @@ from .validators import signature_validator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .pubsub_router_interface import IPubsubRouter # noqa: F401
|
from .pubsub_router_interface import IPubsubRouter # noqa: F401
|
||||||
|
from typing import Any # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.pubsub")
|
logger = logging.getLogger("libp2p.pubsub")
|
||||||
|
@ -60,6 +61,7 @@ class Pubsub:
|
||||||
router: "IPubsubRouter"
|
router: "IPubsubRouter"
|
||||||
|
|
||||||
peer_queue: "asyncio.Queue[ID]"
|
peer_queue: "asyncio.Queue[ID]"
|
||||||
|
dead_peer_queue: "asyncio.Queue[ID]"
|
||||||
|
|
||||||
protocols: List[TProtocol]
|
protocols: List[TProtocol]
|
||||||
|
|
||||||
|
@ -78,6 +80,8 @@ class Pubsub:
|
||||||
# TODO: Be sure it is increased atomically everytime.
|
# TODO: Be sure it is increased atomically everytime.
|
||||||
counter: int # uint64
|
counter: int # uint64
|
||||||
|
|
||||||
|
_tasks: List["asyncio.Future[Any]"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
|
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -100,7 +104,10 @@ class Pubsub:
|
||||||
|
|
||||||
# Register a notifee
|
# Register a notifee
|
||||||
self.peer_queue = asyncio.Queue()
|
self.peer_queue = asyncio.Queue()
|
||||||
self.host.get_network().register_notifee(PubsubNotifee(self.peer_queue))
|
self.dead_peer_queue = asyncio.Queue()
|
||||||
|
self.host.get_network().register_notifee(
|
||||||
|
PubsubNotifee(self.peer_queue, self.dead_peer_queue)
|
||||||
|
)
|
||||||
|
|
||||||
# Register stream handlers for each pubsub router protocol to handle
|
# Register stream handlers for each pubsub router protocol to handle
|
||||||
# the pubsub streams opened on those protocols
|
# the pubsub streams opened on those protocols
|
||||||
|
@ -135,8 +142,10 @@ class Pubsub:
|
||||||
|
|
||||||
self.counter = time.time_ns()
|
self.counter = time.time_ns()
|
||||||
|
|
||||||
|
self._tasks = []
|
||||||
# Call handle peer to keep waiting for updates to peer queue
|
# Call handle peer to keep waiting for updates to peer queue
|
||||||
asyncio.ensure_future(self.handle_peer_queue())
|
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
|
||||||
|
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
|
||||||
|
|
||||||
def get_hello_packet(self) -> rpc_pb2.RPC:
|
def get_hello_packet(self) -> rpc_pb2.RPC:
|
||||||
"""Generate subscription message with all topics we are subscribed to
|
"""Generate subscription message with all topics we are subscribed to
|
||||||
|
@ -158,13 +167,7 @@ class Pubsub:
|
||||||
peer_id = stream.muxed_conn.peer_id
|
peer_id = stream.muxed_conn.peer_id
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
|
||||||
except (ParseError, IncompleteReadError) as error:
|
|
||||||
logger.debug(
|
|
||||||
"read corrupted data from peer %s, error=%s", peer_id, error
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||||
rpc_incoming.ParseFromString(incoming)
|
rpc_incoming.ParseFromString(incoming)
|
||||||
if rpc_incoming.publish:
|
if rpc_incoming.publish:
|
||||||
|
@ -175,7 +178,11 @@ class Pubsub:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"received `publish` message %s from peer %s", msg, peer_id
|
"received `publish` message %s from peer %s", msg, peer_id
|
||||||
)
|
)
|
||||||
asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))
|
self._tasks.append(
|
||||||
|
asyncio.ensure_future(
|
||||||
|
self.push_msg(msg_forwarder=peer_id, msg=msg)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if rpc_incoming.subscriptions:
|
if rpc_incoming.subscriptions:
|
||||||
# deal with RPC.subscriptions
|
# deal with RPC.subscriptions
|
||||||
|
@ -247,13 +254,19 @@ class Pubsub:
|
||||||
|
|
||||||
:param stream: newly created stream
|
:param stream: newly created stream
|
||||||
"""
|
"""
|
||||||
|
peer_id = stream.muxed_conn.peer_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.continuously_read_stream(stream)
|
await self.continuously_read_stream(stream)
|
||||||
except (StreamEOF, StreamReset) as error:
|
except (StreamEOF, StreamReset, ParseError, IncompleteReadError) as error:
|
||||||
logger.debug("fail to read from stream, error=%s", error)
|
logger.debug(
|
||||||
|
"fail to read from peer %s, error=%s,"
|
||||||
|
"closing the stream and remove the peer from record",
|
||||||
|
peer_id,
|
||||||
|
error,
|
||||||
|
)
|
||||||
await stream.reset()
|
await stream.reset()
|
||||||
# TODO: what to do when the stream is terminated?
|
self._handle_dead_peer(peer_id)
|
||||||
# disconnect the peer?
|
|
||||||
|
|
||||||
async def _handle_new_peer(self, peer_id: ID) -> None:
|
async def _handle_new_peer(self, peer_id: ID) -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -277,6 +290,19 @@ class Pubsub:
|
||||||
|
|
||||||
logger.debug("added new peer %s", peer_id)
|
logger.debug("added new peer %s", peer_id)
|
||||||
|
|
||||||
|
def _handle_dead_peer(self, peer_id: ID) -> None:
|
||||||
|
if peer_id not in self.peers:
|
||||||
|
return
|
||||||
|
del self.peers[peer_id]
|
||||||
|
|
||||||
|
for topic in self.peer_topics:
|
||||||
|
if peer_id in self.peer_topics[topic]:
|
||||||
|
self.peer_topics[topic].remove(peer_id)
|
||||||
|
|
||||||
|
self.router.remove_peer(peer_id)
|
||||||
|
|
||||||
|
logger.debug("removed dead peer %s", peer_id)
|
||||||
|
|
||||||
async def handle_peer_queue(self) -> None:
|
async def handle_peer_queue(self) -> None:
|
||||||
"""
|
"""
|
||||||
Continuously read from peer queue and each time a new peer is found,
|
Continuously read from peer queue and each time a new peer is found,
|
||||||
|
@ -285,14 +311,17 @@ class Pubsub:
|
||||||
pubsub protocols we support
|
pubsub protocols we support
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
peer_id: ID = await self.peer_queue.get()
|
peer_id: ID = await self.peer_queue.get()
|
||||||
|
|
||||||
# Add Peer
|
# Add Peer
|
||||||
|
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
|
||||||
|
|
||||||
asyncio.ensure_future(self._handle_new_peer(peer_id))
|
async def handle_dead_peer_queue(self) -> None:
|
||||||
# Force context switch
|
"""Continuously read from dead peer queue and close the stream between
|
||||||
await asyncio.sleep(0)
|
that peer and remove peer info from pubsub and pubsub router."""
|
||||||
|
while True:
|
||||||
|
peer_id: ID = await self.dead_peer_queue.get()
|
||||||
|
# Remove Peer
|
||||||
|
self._handle_dead_peer(peer_id)
|
||||||
|
|
||||||
def handle_subscription(
|
def handle_subscription(
|
||||||
self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
|
self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
|
||||||
|
@ -514,3 +543,11 @@ class Pubsub:
|
||||||
if not self.my_topics:
|
if not self.my_topics:
|
||||||
return False
|
return False
|
||||||
return any(topic in self.my_topics for topic in msg.topicIDs)
|
return any(topic in self.my_topics for topic in msg.topicIDs)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
for task in self._tasks:
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
|
@ -15,13 +15,21 @@ if TYPE_CHECKING:
|
||||||
class PubsubNotifee(INotifee):
|
class PubsubNotifee(INotifee):
|
||||||
|
|
||||||
initiator_peers_queue: "asyncio.Queue[ID]"
|
initiator_peers_queue: "asyncio.Queue[ID]"
|
||||||
|
dead_peers_queue: "asyncio.Queue[ID]"
|
||||||
|
|
||||||
def __init__(self, initiator_peers_queue: "asyncio.Queue[ID]") -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
initiator_peers_queue: "asyncio.Queue[ID]",
|
||||||
|
dead_peers_queue: "asyncio.Queue[ID]",
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
:param initiator_peers_queue: queue to add new peers to so that pubsub
|
:param initiator_peers_queue: queue to add new peers to so that pubsub
|
||||||
can process new peers after we connect to them
|
can process new peers after we connect to them
|
||||||
|
:param dead_peers_queue: queue to add dead peers to so that pubsub
|
||||||
|
can process dead peers after we disconnect from each other
|
||||||
"""
|
"""
|
||||||
self.initiator_peers_queue = initiator_peers_queue
|
self.initiator_peers_queue = initiator_peers_queue
|
||||||
|
self.dead_peers_queue = dead_peers_queue
|
||||||
|
|
||||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -41,7 +49,14 @@ class PubsubNotifee(INotifee):
|
||||||
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
|
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
|
||||||
|
|
||||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||||
pass
|
"""
|
||||||
|
Add peer_id to dead_peers_queue, so that pubsub and its router can
|
||||||
|
remove this peer_id and close the stream inbetween.
|
||||||
|
|
||||||
|
:param network: network the connection was opened on
|
||||||
|
:param conn: connection that was opened
|
||||||
|
"""
|
||||||
|
await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
|
||||||
|
|
||||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue
Block a user