Add type hints to gossipsub.py
This commit is contained in:
parent
8eb6a230ff
commit
b920955db6
|
@ -1,6 +1,8 @@
|
|||
import asyncio
|
||||
import random
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
MutableSet,
|
||||
|
@ -16,6 +18,7 @@ from libp2p.peer.id import (
|
|||
|
||||
from .mcache import MessageCache
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub import Pubsub
|
||||
from .pubsub_router_interface import IPubsubRouter
|
||||
|
||||
|
||||
|
@ -24,11 +27,43 @@ class GossipSub(IPubsubRouter):
|
|||
# pylint: disable=too-many-instance-attributes
|
||||
# pylint: disable=too-many-public-methods
|
||||
|
||||
def __init__(self, protocols, degree, degree_low, degree_high, time_to_live, gossip_window=3,
|
||||
gossip_history=5, heartbeat_interval=120):
|
||||
protocols: Sequence[str]
|
||||
pubsub: Pubsub
|
||||
|
||||
degree: int
|
||||
degree_high: int
|
||||
degree_low: int
|
||||
|
||||
time_to_live: int
|
||||
|
||||
# FIXME: Should be changed to `Dict[str, List[ID]]`
|
||||
mesh: Dict[str, List[str]]
|
||||
# FIXME: Should be changed to `Dict[str, List[ID]]`
|
||||
fanout: Dict[str, List[str]]
|
||||
|
||||
time_since_last_publish: Dict[str, int]
|
||||
|
||||
#FIXME: Should be changed to List[ID]
|
||||
peers_gossipsub: List[str]
|
||||
#FIXME: Should be changed to List[ID]
|
||||
peers_floodsub: List[str]
|
||||
|
||||
mcache: MessageCache
|
||||
|
||||
heartbeat_interval: int
|
||||
|
||||
def __init__(self,
|
||||
protocols: Sequence[str],
|
||||
degree: int,
|
||||
degree_low: int,
|
||||
degree_high: int,
|
||||
time_to_live: int,
|
||||
gossip_window: int=3,
|
||||
gossip_history: int=5,
|
||||
heartbeat_interval: int=120) -> None:
|
||||
# pylint: disable=too-many-arguments
|
||||
self.protocols = protocols
|
||||
self.pubsub = None
|
||||
self.protocols: List[str] = protocols
|
||||
self.pubsub: Pubsub = None
|
||||
|
||||
# Store target degree, upper degree bound, and lower degree bound
|
||||
self.degree = degree
|
||||
|
@ -36,7 +71,7 @@ class GossipSub(IPubsubRouter):
|
|||
self.degree_high = degree_high
|
||||
|
||||
# Store time to live (for topics in fanout)
|
||||
self.time_to_live = time_to_live
|
||||
self.time_to_live: int = time_to_live
|
||||
|
||||
# Create topic --> list of peers mappings
|
||||
self.mesh = {}
|
||||
|
@ -56,13 +91,13 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
# Interface functions
|
||||
|
||||
def get_protocols(self):
|
||||
def get_protocols(self) -> List:
|
||||
"""
|
||||
:return: the list of protocols supported by the router
|
||||
"""
|
||||
return self.protocols
|
||||
|
||||
def attach(self, pubsub):
|
||||
def attach(self, pubsub: Pubsub) -> None:
|
||||
"""
|
||||
Attach is invoked by the PubSub constructor to attach the router to a
|
||||
freshly initialized PubSub instance.
|
||||
|
@ -74,10 +109,11 @@ class GossipSub(IPubsubRouter):
|
|||
# TODO: Start after delay
|
||||
asyncio.ensure_future(self.heartbeat())
|
||||
|
||||
def add_peer(self, peer_id, protocol_id):
|
||||
def add_peer(self, peer_id: ID, protocol_id: str):
|
||||
"""
|
||||
Notifies the router that a new peer has been connected
|
||||
:param peer_id: id of peer to add
|
||||
:param protocol_id: router protocol the peer speaks, e.g., floodsub, gossipsub
|
||||
"""
|
||||
|
||||
# Add peer to the correct peer list
|
||||
|
@ -88,7 +124,7 @@ class GossipSub(IPubsubRouter):
|
|||
elif peer_type == "flood":
|
||||
self.peers_floodsub.append(peer_id_str)
|
||||
|
||||
def remove_peer(self, peer_id):
|
||||
def remove_peer(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Notifies the router that a peer has been disconnected
|
||||
:param peer_id: id of peer to remove
|
||||
|
@ -96,16 +132,18 @@ class GossipSub(IPubsubRouter):
|
|||
peer_id_str = str(peer_id)
|
||||
self.peers_to_protocol.remove(peer_id_str)
|
||||
|
||||
async def handle_rpc(self, rpc, sender_peer_id):
|
||||
# FIXME: type of `sender_peer_id` should be changed to `ID`
|
||||
async def handle_rpc(self, rpc: rpc_pb2.Message, sender_peer_id: str):
|
||||
"""
|
||||
Invoked to process control messages in the RPC envelope.
|
||||
It is invoked after subscriptions and payload messages have been processed
|
||||
:param rpc: rpc message
|
||||
:param rpc: RPC message
|
||||
:param sender_peer_id: id of the peer who sent the message
|
||||
"""
|
||||
control_message = rpc.control
|
||||
sender_peer_id = str(sender_peer_id)
|
||||
|
||||
# Relay each rpc control to the appropriate handler
|
||||
# Relay each rpc control message to the appropriate handler
|
||||
if control_message.ihave:
|
||||
for ihave in control_message.ihave:
|
||||
await self.handle_ihave(ihave, sender_peer_id)
|
||||
|
@ -191,7 +229,7 @@ class GossipSub(IPubsubRouter):
|
|||
# Excludes `msg_forwarder` and `origin`
|
||||
yield from send_to.difference([msg_forwarder, origin])
|
||||
|
||||
async def join(self, topic):
|
||||
async def join(self, topic: str) -> None:
|
||||
# Note: the comments here are the near-exact algorithm description from the spec
|
||||
"""
|
||||
Join notifies the router that we want to receive and
|
||||
|
@ -204,8 +242,9 @@ class GossipSub(IPubsubRouter):
|
|||
# Create mesh[topic] if it does not yet exist
|
||||
self.mesh[topic] = []
|
||||
|
||||
topic_in_fanout = topic in self.fanout
|
||||
fanout_peers = self.fanout[topic] if topic_in_fanout else []
|
||||
topic_in_fanout: bool = topic in self.fanout
|
||||
# FIXME: Should be changed to `List[ID]`
|
||||
fanout_peers: List[str] = self.fanout[topic] if topic_in_fanout else []
|
||||
fanout_size = len(fanout_peers)
|
||||
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
|
||||
# There are less than D peers (let this number be x)
|
||||
|
@ -229,7 +268,7 @@ class GossipSub(IPubsubRouter):
|
|||
if topic_in_fanout:
|
||||
del self.fanout[topic]
|
||||
|
||||
async def leave(self, topic):
|
||||
async def leave(self, topic: str) -> None:
|
||||
# Note: the comments here are the near-exact algorithm description from the spec
|
||||
"""
|
||||
Leave notifies the router that we are no longer interested in a topic.
|
||||
|
@ -247,7 +286,7 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
# Interface Helper Functions
|
||||
@staticmethod
|
||||
def get_peer_type(protocol_id):
|
||||
def get_peer_type(protocol_id: str) -> str:
|
||||
# TODO: Do this in a better, more efficient way
|
||||
if "gossipsub" in protocol_id:
|
||||
return "gossip"
|
||||
|
@ -255,7 +294,13 @@ class GossipSub(IPubsubRouter):
|
|||
return "flood"
|
||||
return "unknown"
|
||||
|
||||
async def deliver_messages_to_peers(self, peers, msg_sender, origin_id, serialized_packet):
|
||||
# FIXME: type of `peers` should be changed to `List[ID]`
|
||||
# FIXME: type of `msg_sender` and `origin_id` should be changed to `ID`
|
||||
async def deliver_messages_to_peers(self,
|
||||
peers: List[str],
|
||||
msg_sender: str,
|
||||
origin_id: str,
|
||||
serialized_packet: bytes):
|
||||
for peer_id_in_topic in peers:
|
||||
# Forward to all peers that are not the
|
||||
# message sender and are not the message origin
|
||||
|
@ -267,7 +312,7 @@ class GossipSub(IPubsubRouter):
|
|||
await stream.write(serialized_packet)
|
||||
|
||||
# Heartbeat
|
||||
async def heartbeat(self):
|
||||
async def heartbeat(self) -> None:
|
||||
"""
|
||||
Call individual heartbeats.
|
||||
Note: the heartbeats are called with awaits because each heartbeat depends on the
|
||||
|
@ -281,7 +326,7 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
|
||||
async def mesh_heartbeat(self):
|
||||
async def mesh_heartbeat(self) -> None:
|
||||
# Note: the comments here are the exact pseudocode from the spec
|
||||
for topic in self.mesh:
|
||||
# Skip if no peers have subscribed to the topic
|
||||
|
@ -297,7 +342,8 @@ class GossipSub(IPubsubRouter):
|
|||
self.mesh[topic],
|
||||
)
|
||||
|
||||
fanout_peers_not_in_mesh = [
|
||||
# FIXME: Should be changed to `List[ID]`
|
||||
fanout_peers_not_in_mesh: List[str] = [
|
||||
peer
|
||||
for peer in selected_peers
|
||||
if peer not in self.mesh[topic]
|
||||
|
@ -311,8 +357,12 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
if num_mesh_peers_in_topic > self.degree_high:
|
||||
# Select |mesh[topic]| - D peers from mesh[topic]
|
||||
selected_peers = GossipSub.select_from_minus(num_mesh_peers_in_topic - self.degree,
|
||||
self.mesh[topic], [])
|
||||
# FIXME: Should be changed to `List[ID]`
|
||||
selected_peers: List[str] = GossipSub.select_from_minus(
|
||||
num_mesh_peers_in_topic - self.degree,
|
||||
self.mesh[topic],
|
||||
[],
|
||||
)
|
||||
for peer in selected_peers:
|
||||
# Remove peer from mesh[topic]
|
||||
self.mesh[topic].remove(peer)
|
||||
|
@ -320,7 +370,7 @@ class GossipSub(IPubsubRouter):
|
|||
# Emit PRUNE(topic) control message to peer
|
||||
await self.emit_prune(topic, peer)
|
||||
|
||||
async def fanout_heartbeat(self):
|
||||
async def fanout_heartbeat(self) -> None:
|
||||
# Note: the comments here are the exact pseudocode from the spec
|
||||
for topic in self.fanout:
|
||||
# If time since last published > ttl
|
||||
|
@ -362,14 +412,14 @@ class GossipSub(IPubsubRouter):
|
|||
# TODO: this line is a monster, can hopefully be simplified
|
||||
if (topic not in self.mesh or (peer not in self.mesh[topic]))\
|
||||
and (topic not in self.fanout or (peer not in self.fanout[topic])):
|
||||
msg_ids = [str(msg) for msg in msg_ids]
|
||||
msg_ids: List[str] = [str(msg) for msg in msg_ids]
|
||||
await self.emit_ihave(topic, msg_ids, peer)
|
||||
|
||||
# TODO: Refactor and Dedup. This section is the roughly the same as the above.
|
||||
# Do the same for fanout, for all topics not already hit in mesh
|
||||
for topic in self.fanout:
|
||||
if topic not in self.mesh:
|
||||
msg_ids = self.mcache.window(topic)
|
||||
msg_ids: List[str] = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# TODO: Make more efficient, possibly using a generator?
|
||||
# Get all pubsub peers in topic and only add if they are gossipsub peers also
|
||||
|
@ -383,13 +433,13 @@ class GossipSub(IPubsubRouter):
|
|||
for peer in peers_to_emit_ihave_to:
|
||||
if peer not in self.mesh[topic] and peer not in self.fanout[topic]:
|
||||
|
||||
msg_ids = [str(msg) for msg in msg_ids]
|
||||
msg_ids: List[str] = [str(msg) for msg in msg_ids]
|
||||
await self.emit_ihave(topic, msg_ids, peer)
|
||||
|
||||
self.mcache.shift()
|
||||
|
||||
@staticmethod
|
||||
def select_from_minus(num_to_select, pool, minus):
|
||||
def select_from_minus(num_to_select: int, pool: Sequence[Any], minus: Sequence[Any]) -> List[Any]:
|
||||
"""
|
||||
Select at most num_to_select subset of elements from the set (pool - minus) randomly.
|
||||
:param num_to_select: number of elements to randomly select
|
||||
|
@ -400,10 +450,10 @@ class GossipSub(IPubsubRouter):
|
|||
# Create selection pool, which is selection_pool = pool - minus
|
||||
if minus:
|
||||
# Create a new selection pool by removing elements of minus
|
||||
selection_pool = [x for x in pool if x not in minus]
|
||||
selection_pool: List[Any] = [x for x in pool if x not in minus]
|
||||
else:
|
||||
# Don't create a new selection_pool if we are not subbing anything
|
||||
selection_pool = pool
|
||||
selection_pool: List[Any] = pool
|
||||
|
||||
# If num_to_select > size(selection_pool), then return selection_pool (which has the most
|
||||
# possible elements s.t. the number of elements is less than num_to_select)
|
||||
|
@ -411,7 +461,7 @@ class GossipSub(IPubsubRouter):
|
|||
return selection_pool
|
||||
|
||||
# Random selection
|
||||
selection = random.sample(selection_pool, num_to_select)
|
||||
selection: List[Any] = random.sample(selection_pool, num_to_select)
|
||||
|
||||
return selection
|
||||
|
||||
|
@ -433,7 +483,7 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
# RPC handlers
|
||||
|
||||
async def handle_ihave(self, ihave_msg, sender_peer_id):
|
||||
async def handle_ihave(self, ihave_msg: rpc_pb2.Message, sender_peer_id: str) -> None:
|
||||
"""
|
||||
Checks the seen set and requests unknown messages with an IWANT message.
|
||||
"""
|
||||
|
@ -442,29 +492,36 @@ class GossipSub(IPubsubRouter):
|
|||
from_id_str = sender_peer_id
|
||||
|
||||
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in seen_messages cache
|
||||
seen_seqnos_and_peers = [seqno_and_from
|
||||
for seqno_and_from in self.pubsub.seen_messages.keys()]
|
||||
seen_seqnos_and_peers = [
|
||||
seqno_and_from
|
||||
for seqno_and_from in self.pubsub.seen_messages.keys()
|
||||
]
|
||||
|
||||
# Add all unknown message ids (ids that appear in ihave_msg but not in seen_seqnos) to list
|
||||
# of messages we want to request
|
||||
msg_ids_wanted = [msg_id for msg_id in ihave_msg.messageIDs
|
||||
if literal_eval(msg_id) not in seen_seqnos_and_peers]
|
||||
# FIXME: Update type of message ID
|
||||
msg_ids_wanted = [
|
||||
msg_id
|
||||
for msg_id in ihave_msg.messageIDs
|
||||
if literal_eval(msg_id) not in seen_seqnos_and_peers
|
||||
]
|
||||
|
||||
# Request messages with IWANT message
|
||||
if msg_ids_wanted:
|
||||
await self.emit_iwant(msg_ids_wanted, from_id_str)
|
||||
|
||||
async def handle_iwant(self, iwant_msg, sender_peer_id):
|
||||
async def handle_iwant(self, iwant_msg: rpc_pb2.Message, sender_peer_id: str) -> None:
|
||||
"""
|
||||
Forwards all request messages that are present in mcache to the requesting peer.
|
||||
"""
|
||||
from_id_str = sender_peer_id
|
||||
|
||||
msg_ids = [literal_eval(msg) for msg in iwant_msg.messageIDs]
|
||||
msgs_to_forward = []
|
||||
# FIXME: Update type of message ID
|
||||
msg_ids: List[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
|
||||
msgs_to_forward: List = []
|
||||
for msg_id_iwant in msg_ids:
|
||||
# Check if the wanted message ID is present in mcache
|
||||
msg = self.mcache.get(msg_id_iwant)
|
||||
msg: rpc_pb2.Message = self.mcache.get(msg_id_iwant)
|
||||
|
||||
# Cache hit
|
||||
if msg:
|
||||
|
@ -476,12 +533,12 @@ class GossipSub(IPubsubRouter):
|
|||
# because then the message will forwarded to peers in the topics contained in the messages.
|
||||
# We should
|
||||
# 1) Package these messages into a single packet
|
||||
packet = rpc_pb2.RPC()
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
|
||||
packet.publish.extend(msgs_to_forward)
|
||||
|
||||
# 2) Serialize that packet
|
||||
rpc_msg = packet.SerializeToString()
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# 3) Get the stream to this peer
|
||||
# TODO: Should we pass in from_id or from_id_str here?
|
||||
|
@ -490,8 +547,8 @@ class GossipSub(IPubsubRouter):
|
|||
# 4) And write the packet to the stream
|
||||
await peer_stream.write(rpc_msg)
|
||||
|
||||
async def handle_graft(self, graft_msg, sender_peer_id):
|
||||
topic = graft_msg.topicID
|
||||
async def handle_graft(self, graft_msg: rpc_pb2.Message, sender_peer_id: str) -> None:
|
||||
topic: str = graft_msg.topicID
|
||||
|
||||
from_id_str = sender_peer_id
|
||||
|
||||
|
@ -503,8 +560,8 @@ class GossipSub(IPubsubRouter):
|
|||
# Respond with PRUNE if not subscribed to the topic
|
||||
await self.emit_prune(topic, sender_peer_id)
|
||||
|
||||
async def handle_prune(self, prune_msg, sender_peer_id):
|
||||
topic = prune_msg.topicID
|
||||
async def handle_prune(self, prune_msg: rpc_pb2.Message, sender_peer_id: str) -> None:
|
||||
topic: str = prune_msg.topicID
|
||||
|
||||
from_id_str = sender_peer_id
|
||||
|
||||
|
@ -514,65 +571,65 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
# RPC emitters
|
||||
|
||||
async def emit_ihave(self, topic, msg_ids, to_peer):
|
||||
async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: str) -> None:
|
||||
"""
|
||||
Emit ihave message, sent to to_peer, for topic and msg_ids
|
||||
"""
|
||||
|
||||
ihave_msg = rpc_pb2.ControlIHave()
|
||||
ihave_msg: rpc_pb2.ControlIHave = rpc_pb2.ControlIHave()
|
||||
ihave_msg.messageIDs.extend(msg_ids)
|
||||
ihave_msg.topicID = topic
|
||||
|
||||
control_msg = rpc_pb2.ControlMessage()
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
control_msg.ihave.extend([ihave_msg])
|
||||
|
||||
await self.emit_control_message(control_msg, to_peer)
|
||||
|
||||
async def emit_iwant(self, msg_ids, to_peer):
|
||||
async def emit_iwant(self, msg_ids: Any, to_peer: str) -> None:
|
||||
"""
|
||||
Emit iwant message, sent to to_peer, for msg_ids
|
||||
"""
|
||||
|
||||
iwant_msg = rpc_pb2.ControlIWant()
|
||||
iwant_msg: rpc_pb2.ControlIWant = rpc_pb2.ControlIWant()
|
||||
iwant_msg.messageIDs.extend(msg_ids)
|
||||
|
||||
control_msg = rpc_pb2.ControlMessage()
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
control_msg.iwant.extend([iwant_msg])
|
||||
|
||||
await self.emit_control_message(control_msg, to_peer)
|
||||
|
||||
async def emit_graft(self, topic, to_peer):
|
||||
async def emit_graft(self, topic: str, to_peer: str) -> None:
|
||||
"""
|
||||
Emit graft message, sent to to_peer, for topic
|
||||
"""
|
||||
|
||||
graft_msg = rpc_pb2.ControlGraft()
|
||||
graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft()
|
||||
graft_msg.topicID = topic
|
||||
|
||||
control_msg = rpc_pb2.ControlMessage()
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
control_msg.graft.extend([graft_msg])
|
||||
|
||||
await self.emit_control_message(control_msg, to_peer)
|
||||
|
||||
async def emit_prune(self, topic, to_peer):
|
||||
async def emit_prune(self, topic: str, to_peer: str) -> None:
|
||||
"""
|
||||
Emit graft message, sent to to_peer, for topic
|
||||
"""
|
||||
|
||||
prune_msg = rpc_pb2.ControlPrune()
|
||||
prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune()
|
||||
prune_msg.topicID = topic
|
||||
|
||||
control_msg = rpc_pb2.ControlMessage()
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
control_msg.prune.extend([prune_msg])
|
||||
|
||||
await self.emit_control_message(control_msg, to_peer)
|
||||
|
||||
async def emit_control_message(self, control_msg, to_peer):
|
||||
async def emit_control_message(self, control_msg: rpc_pb2.ControlMessage, to_peer: str) -> None:
|
||||
# Add control message to packet
|
||||
packet = rpc_pb2.RPC()
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
packet.control.CopyFrom(control_msg)
|
||||
|
||||
rpc_msg = packet.SerializeToString()
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# Get stream for peer from pubsub
|
||||
peer_stream = self.pubsub.peers[to_peer]
|
||||
|
|
Loading…
Reference in New Issue
Block a user