Update peer_id to type peer.ID in pubsub folder

This commit is contained in:
NIC619 2019-08-01 12:05:28 +08:00
parent 9562cb2a46
commit cd684aad9e
No known key found for this signature in database
GPG Key ID: 570C35F5C2D51B17
6 changed files with 77 additions and 110 deletions

View File

@ -73,7 +73,7 @@ class FloodSub(IPubsubRouter):
) )
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg]) rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
for peer_id in peers_gen: for peer_id in peers_gen:
stream = self.pubsub.peers[str(peer_id)] stream = self.pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages. # FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107 # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
await stream.write(rpc_msg.SerializeToString()) await stream.write(rpc_msg.SerializeToString())
@ -105,11 +105,9 @@ class FloodSub(IPubsubRouter):
for topic in topic_ids: for topic in topic_ids:
if topic not in self.pubsub.peer_topics: if topic not in self.pubsub.peer_topics:
continue continue
for peer_id_str in self.pubsub.peer_topics[topic]: for peer_id in self.pubsub.peer_topics[topic]:
peer_id = ID.from_base58(peer_id_str)
if peer_id in (msg_forwarder, origin): if peer_id in (msg_forwarder, origin):
continue continue
# FIXME: Should change `self.pubsub.peers` to Dict[PeerID, ...] if peer_id not in self.pubsub.peers:
if str(peer_id) not in self.pubsub.peers:
continue continue
yield peer_id yield peer_id

View File

@ -32,20 +32,15 @@ class GossipSub(IPubsubRouter):
time_to_live: int time_to_live: int
# FIXME: Should be changed to `Dict[str, List[ID]]` mesh: Dict[str, List[ID]]
mesh: Dict[str, List[str]] fanout: Dict[str, List[ID]]
# FIXME: Should be changed to `Dict[str, List[ID]]`
fanout: Dict[str, List[str]]
# FIXME: Should be changed to `Dict[ID, str]` peers_to_protocol: Dict[ID, str]
peers_to_protocol: Dict[str, str]
time_since_last_publish: Dict[str, int] time_since_last_publish: Dict[str, int]
# FIXME: Should be changed to List[ID] peers_gossipsub: List[ID]
peers_gossipsub: List[str] peers_floodsub: List[ID]
# FIXME: Should be changed to List[ID]
peers_floodsub: List[str]
mcache: MessageCache mcache: MessageCache
@ -122,27 +117,25 @@ class GossipSub(IPubsubRouter):
# Add peer to the correct peer list # Add peer to the correct peer list
peer_type = GossipSub.get_peer_type(protocol_id) peer_type = GossipSub.get_peer_type(protocol_id)
peer_id_str = str(peer_id)
self.peers_to_protocol[peer_id_str] = protocol_id self.peers_to_protocol[peer_id] = protocol_id
if peer_type == "gossip": if peer_type == "gossip":
self.peers_gossipsub.append(peer_id_str) self.peers_gossipsub.append(peer_id)
elif peer_type == "flood": elif peer_type == "flood":
self.peers_floodsub.append(peer_id_str) self.peers_floodsub.append(peer_id)
def remove_peer(self, peer_id: ID) -> None: def remove_peer(self, peer_id: ID) -> None:
""" """
Notifies the router that a peer has been disconnected Notifies the router that a peer has been disconnected
:param peer_id: id of peer to remove :param peer_id: id of peer to remove
""" """
peer_id_str = str(peer_id) del self.peers_to_protocol[peer_id]
del self.peers_to_protocol[peer_id_str]
if peer_id_str in self.peers_gossipsub: if peer_id in self.peers_gossipsub:
self.peers_gossipsub.remove(peer_id_str) self.peers_gossipsub.remove(peer_id)
if peer_id_str in self.peers_gossipsub: if peer_id in self.peers_gossipsub:
self.peers_floodsub.remove(peer_id_str) self.peers_floodsub.remove(peer_id)
async def handle_rpc(self, rpc: rpc_pb2.Message, sender_peer_id: ID) -> None: async def handle_rpc(self, rpc: rpc_pb2.Message, sender_peer_id: ID) -> None:
""" """
@ -152,21 +145,20 @@ class GossipSub(IPubsubRouter):
:param sender_peer_id: id of the peer who sent the message :param sender_peer_id: id of the peer who sent the message
""" """
control_message = rpc.control control_message = rpc.control
sender_peer_id_str = str(sender_peer_id)
# Relay each rpc control message to the appropriate handler # Relay each rpc control message to the appropriate handler
if control_message.ihave: if control_message.ihave:
for ihave in control_message.ihave: for ihave in control_message.ihave:
await self.handle_ihave(ihave, sender_peer_id_str) await self.handle_ihave(ihave, sender_peer_id)
if control_message.iwant: if control_message.iwant:
for iwant in control_message.iwant: for iwant in control_message.iwant:
await self.handle_iwant(iwant, sender_peer_id_str) await self.handle_iwant(iwant, sender_peer_id)
if control_message.graft: if control_message.graft:
for graft in control_message.graft: for graft in control_message.graft:
await self.handle_graft(graft, sender_peer_id_str) await self.handle_graft(graft, sender_peer_id)
if control_message.prune: if control_message.prune:
for prune in control_message.prune: for prune in control_message.prune:
await self.handle_prune(prune, sender_peer_id_str) await self.handle_prune(prune, sender_peer_id)
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
@ -182,7 +174,7 @@ class GossipSub(IPubsubRouter):
) )
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg]) rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
for peer_id in peers_gen: for peer_id in peers_gen:
stream = self.pubsub.peers[str(peer_id)] stream = self.pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages. # FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107 # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
# TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages. # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
@ -204,16 +196,14 @@ class GossipSub(IPubsubRouter):
continue continue
# floodsub peers # floodsub peers
for peer_id_str in self.pubsub.peer_topics[topic]: for peer_id in self.pubsub.peer_topics[topic]:
# FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go. # FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go.
# This will improve the efficiency when searching for a peer's protocol id. # This will improve the efficiency when searching for a peer's protocol id.
if peer_id_str in self.peers_floodsub: if peer_id in self.peers_floodsub:
peer_id = ID.from_base58(peer_id_str)
send_to.add(peer_id) send_to.add(peer_id)
# gossipsub peers # gossipsub peers
# FIXME: Change `str` to `ID` in_topic_gossipsub_peers: List[ID] = None
in_topic_gossipsub_peers: List[str] = None
# TODO: Do we need to check `topic in self.pubsub.my_topics`? # TODO: Do we need to check `topic in self.pubsub.my_topics`?
if topic in self.mesh: if topic in self.mesh:
in_topic_gossipsub_peers = self.mesh[topic] in_topic_gossipsub_peers = self.mesh[topic]
@ -229,8 +219,8 @@ class GossipSub(IPubsubRouter):
topic, self.degree, [] topic, self.degree, []
) )
in_topic_gossipsub_peers = self.fanout[topic] in_topic_gossipsub_peers = self.fanout[topic]
for peer_id_str in in_topic_gossipsub_peers: for peer_id in in_topic_gossipsub_peers:
send_to.add(ID.from_base58(peer_id_str)) send_to.add(peer_id)
# Excludes `msg_forwarder` and `origin` # Excludes `msg_forwarder` and `origin`
yield from send_to.difference([msg_forwarder, origin]) yield from send_to.difference([msg_forwarder, origin])
@ -248,8 +238,7 @@ class GossipSub(IPubsubRouter):
self.mesh[topic] = [] self.mesh[topic] = []
topic_in_fanout: bool = topic in self.fanout topic_in_fanout: bool = topic in self.fanout
# FIXME: Should be changed to `List[ID]` fanout_peers: List[ID] = self.fanout[topic] if topic_in_fanout else []
fanout_peers: List[str] = self.fanout[topic] if topic_in_fanout else []
fanout_size = len(fanout_peers) fanout_size = len(fanout_peers)
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree): if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
# There are less than D peers (let this number be x) # There are less than D peers (let this number be x)
@ -297,13 +286,11 @@ class GossipSub(IPubsubRouter):
return "flood" return "flood"
return "unknown" return "unknown"
# FIXME: type of `peers` should be changed to `List[ID]`
# FIXME: type of `msg_sender` and `origin_id` should be changed to `ID`
async def deliver_messages_to_peers( async def deliver_messages_to_peers(
self, self,
peers: List[str], peers: List[ID],
msg_sender: str, msg_sender: ID,
origin_id: str, origin_id: ID,
serialized_packet: bytes, serialized_packet: bytes,
) -> None: ) -> None:
for peer_id_in_topic in peers: for peer_id_in_topic in peers:
@ -345,8 +332,7 @@ class GossipSub(IPubsubRouter):
topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic] topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic]
) )
# FIXME: Should be changed to `List[ID]` fanout_peers_not_in_mesh: List[ID] = [
fanout_peers_not_in_mesh: List[str] = [
peer for peer in selected_peers if peer not in self.mesh[topic] peer for peer in selected_peers if peer not in self.mesh[topic]
] ]
for peer in fanout_peers_not_in_mesh: for peer in fanout_peers_not_in_mesh:
@ -358,7 +344,6 @@ class GossipSub(IPubsubRouter):
if num_mesh_peers_in_topic > self.degree_high: if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic] # Select |mesh[topic]| - D peers from mesh[topic]
# FIXME: Should be changed to `List[ID]`
selected_peers = GossipSub.select_from_minus( selected_peers = GossipSub.select_from_minus(
num_mesh_peers_in_topic - self.degree, self.mesh[topic], [] num_mesh_peers_in_topic - self.degree, self.mesh[topic], []
) )
@ -468,15 +453,13 @@ class GossipSub(IPubsubRouter):
return selection return selection
# FIXME: type of `minus` should be changed to type `Sequence[ID]`
# FIXME: return type should be changed to type `List[ID]`
def _get_in_topic_gossipsub_peers_from_minus( def _get_in_topic_gossipsub_peers_from_minus(
self, topic: str, num_to_select: int, minus: Sequence[str] self, topic: str, num_to_select: int, minus: Sequence[ID]
) -> List[str]: ) -> List[ID]:
gossipsub_peers_in_topic = [ gossipsub_peers_in_topic = [
peer_str peer_id
for peer_str in self.pubsub.peer_topics[topic] for peer_id in self.pubsub.peer_topics[topic]
if peer_str in self.peers_gossipsub if peer_id in self.peers_gossipsub
] ]
return self.select_from_minus( return self.select_from_minus(
num_to_select, gossipsub_peers_in_topic, list(minus) num_to_select, gossipsub_peers_in_topic, list(minus)
@ -485,15 +468,13 @@ class GossipSub(IPubsubRouter):
# RPC handlers # RPC handlers
async def handle_ihave( async def handle_ihave(
self, ihave_msg: rpc_pb2.Message, sender_peer_id: str self, ihave_msg: rpc_pb2.Message, sender_peer_id: ID
) -> None: ) -> None:
""" """
Checks the seen set and requests unknown messages with an IWANT message. Checks the seen set and requests unknown messages with an IWANT message.
""" """
# from_id_bytes = ihave_msg.from_id # from_id_bytes = ihave_msg.from_id
from_id_str = sender_peer_id
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in seen_messages cache # Get list of all seen (seqnos, from) from the (seqno, from) tuples in seen_messages cache
seen_seqnos_and_peers = [ seen_seqnos_and_peers = [
seqno_and_from for seqno_and_from in self.pubsub.seen_messages.keys() seqno_and_from for seqno_and_from in self.pubsub.seen_messages.keys()
@ -510,16 +491,14 @@ class GossipSub(IPubsubRouter):
# Request messages with IWANT message # Request messages with IWANT message
if msg_ids_wanted: if msg_ids_wanted:
await self.emit_iwant(msg_ids_wanted, from_id_str) await self.emit_iwant(msg_ids_wanted, sender_peer_id)
async def handle_iwant( async def handle_iwant(
self, iwant_msg: rpc_pb2.Message, sender_peer_id: str self, iwant_msg: rpc_pb2.Message, sender_peer_id: ID
) -> None: ) -> None:
""" """
Forwards all request messages that are present in mcache to the requesting peer. Forwards all request messages that are present in mcache to the requesting peer.
""" """
from_id_str = sender_peer_id
# FIXME: Update type of message ID # FIXME: Update type of message ID
# FIXME: Find a better way to parse the msg ids # FIXME: Find a better way to parse the msg ids
msg_ids: List[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs] msg_ids: List[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
@ -546,41 +525,36 @@ class GossipSub(IPubsubRouter):
rpc_msg: bytes = packet.SerializeToString() rpc_msg: bytes = packet.SerializeToString()
# 3) Get the stream to this peer # 3) Get the stream to this peer
# TODO: Should we pass in from_id or from_id_str here? peer_stream = self.pubsub.peers[sender_peer_id]
peer_stream = self.pubsub.peers[from_id_str]
# 4) And write the packet to the stream # 4) And write the packet to the stream
await peer_stream.write(rpc_msg) await peer_stream.write(rpc_msg)
async def handle_graft( async def handle_graft(
self, graft_msg: rpc_pb2.Message, sender_peer_id: str self, graft_msg: rpc_pb2.Message, sender_peer_id: ID
) -> None: ) -> None:
topic: str = graft_msg.topicID topic: str = graft_msg.topicID
from_id_str = sender_peer_id
# Add peer to mesh for topic # Add peer to mesh for topic
if topic in self.mesh: if topic in self.mesh:
if from_id_str not in self.mesh[topic]: if sender_peer_id not in self.mesh[topic]:
self.mesh[topic].append(from_id_str) self.mesh[topic].append(sender_peer_id)
else: else:
# Respond with PRUNE if not subscribed to the topic # Respond with PRUNE if not subscribed to the topic
await self.emit_prune(topic, sender_peer_id) await self.emit_prune(topic, sender_peer_id)
async def handle_prune( async def handle_prune(
self, prune_msg: rpc_pb2.Message, sender_peer_id: str self, prune_msg: rpc_pb2.Message, sender_peer_id: ID
) -> None: ) -> None:
topic: str = prune_msg.topicID topic: str = prune_msg.topicID
from_id_str = sender_peer_id
# Remove peer from mesh for topic, if peer is in topic # Remove peer from mesh for topic, if peer is in topic
if topic in self.mesh and from_id_str in self.mesh[topic]: if topic in self.mesh and sender_peer_id in self.mesh[topic]:
self.mesh[topic].remove(from_id_str) self.mesh[topic].remove(sender_peer_id)
# RPC emitters # RPC emitters
async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: str) -> None: async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: ID) -> None:
""" """
Emit ihave message, sent to to_peer, for topic and msg_ids Emit ihave message, sent to to_peer, for topic and msg_ids
""" """
@ -594,7 +568,7 @@ class GossipSub(IPubsubRouter):
await self.emit_control_message(control_msg, to_peer) await self.emit_control_message(control_msg, to_peer)
async def emit_iwant(self, msg_ids: Any, to_peer: str) -> None: async def emit_iwant(self, msg_ids: Any, to_peer: ID) -> None:
""" """
Emit iwant message, sent to to_peer, for msg_ids Emit iwant message, sent to to_peer, for msg_ids
""" """
@ -607,7 +581,7 @@ class GossipSub(IPubsubRouter):
await self.emit_control_message(control_msg, to_peer) await self.emit_control_message(control_msg, to_peer)
async def emit_graft(self, topic: str, to_peer: str) -> None: async def emit_graft(self, topic: str, to_peer: ID) -> None:
""" """
Emit graft message, sent to to_peer, for topic Emit graft message, sent to to_peer, for topic
""" """
@ -620,7 +594,7 @@ class GossipSub(IPubsubRouter):
await self.emit_control_message(control_msg, to_peer) await self.emit_control_message(control_msg, to_peer)
async def emit_prune(self, topic: str, to_peer: str) -> None: async def emit_prune(self, topic: str, to_peer: ID) -> None:
""" """
Emit graft message, sent to to_peer, for topic Emit graft message, sent to to_peer, for topic
""" """
@ -634,7 +608,7 @@ class GossipSub(IPubsubRouter):
await self.emit_control_message(control_msg, to_peer) await self.emit_control_message(control_msg, to_peer)
async def emit_control_message( async def emit_control_message(
self, control_msg: rpc_pb2.ControlMessage, to_peer: str self, control_msg: rpc_pb2.ControlMessage, to_peer: ID
) -> None: ) -> None:
# Add control message to packet # Add control message to packet
packet: rpc_pb2.RPC = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()

View File

@ -41,10 +41,8 @@ class Pubsub:
my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"] my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
# FIXME: Should be changed to `Dict[str, List[ID]]` peer_topics: Dict[str, List[ID]]
peer_topics: Dict[str, List[str]] peers: Dict[ID, INetStream]
# FIXME: Should be changed to `Dict[ID, INetStream]`
peers: Dict[str, INetStream]
# NOTE: Be sure it is increased atomically everytime. # NOTE: Be sure it is increased atomically everytime.
counter: int # uint64 counter: int # uint64
@ -93,11 +91,9 @@ class Pubsub:
self.my_topics = {} self.my_topics = {}
# Map of topic to peers to keep track of what peers are subscribed to # Map of topic to peers to keep track of what peers are subscribed to
# FIXME: Should be changed to `Dict[str, ID]`
self.peer_topics = {} self.peer_topics = {}
# Create peers map, which maps peer_id (as string) to stream (to a given peer) # Create peers map, which maps peer_id (as string) to stream (to a given peer)
# FIXME: Should be changed to `Dict[ID, INetStream]`
self.peers = {} self.peers = {}
self.counter = time.time_ns() self.counter = time.time_ns()
@ -168,7 +164,7 @@ class Pubsub:
# Add peer # Add peer
# Map peer to stream # Map peer to stream
peer_id: ID = stream.mplex_conn.peer_id peer_id: ID = stream.mplex_conn.peer_id
self.peers[str(peer_id)] = stream self.peers[peer_id] = stream
self.router.add_peer(peer_id, stream.get_protocol()) self.router.add_peer(peer_id, stream.get_protocol())
# Send hello packet # Send hello packet
@ -198,7 +194,7 @@ class Pubsub:
# Add Peer # Add Peer
# Map peer to stream # Map peer to stream
self.peers[str(peer_id)] = stream self.peers[peer_id] = stream
self.router.add_peer(peer_id, stream.get_protocol()) self.router.add_peer(peer_id, stream.get_protocol())
# Send hello packet # Send hello packet
@ -223,17 +219,16 @@ class Pubsub:
:param origin_id: id of the peer who subscribe to the message :param origin_id: id of the peer who subscribe to the message
:param sub_message: RPC.SubOpts :param sub_message: RPC.SubOpts
""" """
origin_id_str = str(origin_id)
if sub_message.subscribe: if sub_message.subscribe:
if sub_message.topicid not in self.peer_topics: if sub_message.topicid not in self.peer_topics:
self.peer_topics[sub_message.topicid] = [origin_id_str] self.peer_topics[sub_message.topicid] = [origin_id]
elif origin_id_str not in self.peer_topics[sub_message.topicid]: elif origin_id not in self.peer_topics[sub_message.topicid]:
# Add peer to topic # Add peer to topic
self.peer_topics[sub_message.topicid].append(origin_id_str) self.peer_topics[sub_message.topicid].append(origin_id)
else: else:
if sub_message.topicid in self.peer_topics: if sub_message.topicid in self.peer_topics:
if origin_id_str in self.peer_topics[sub_message.topicid]: if origin_id in self.peer_topics[sub_message.topicid]:
self.peer_topics[sub_message.topicid].remove(origin_id_str) self.peer_topics[sub_message.topicid].remove(origin_id)
# FIXME(mhchia): Change the function name? # FIXME(mhchia): Change the function name?
# FIXME(mhchia): `publish_message` can be further type hinted with mypy_protobuf # FIXME(mhchia): `publish_message` can be further type hinted with mypy_protobuf

View File

@ -179,12 +179,12 @@ async def perform_test_from_obj(obj, router_factory):
node_map = {} node_map = {}
pubsub_map = {} pubsub_map = {}
async def add_node(node_id: str) -> None: async def add_node(node_id_str: str) -> None:
pubsub_router = router_factory(protocols=obj["supported_protocols"]) pubsub_router = router_factory(protocols=obj["supported_protocols"])
pubsub = PubsubFactory(router=pubsub_router) pubsub = PubsubFactory(router=pubsub_router)
await pubsub.host.get_network().listen(LISTEN_MADDR) await pubsub.host.get_network().listen(LISTEN_MADDR)
node_map[node_id] = pubsub.host node_map[node_id_str] = pubsub.host
pubsub_map[node_id] = pubsub pubsub_map[node_id_str] = pubsub
tasks_connect = [] tasks_connect = []
for start_node_id in adj_list: for start_node_id in adj_list:

View File

@ -54,11 +54,11 @@ async def test_join(num_hosts, hosts, gossipsubs, pubsubs_gsub):
for i in hosts_indices: for i in hosts_indices:
if i in subscribed_peer_indices: if i in subscribed_peer_indices:
assert str(hosts[i].get_id()) in gossipsubs[central_node_index].mesh[topic] assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert str(hosts[central_node_index].get_id()) in gossipsubs[i].mesh[topic] assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else: else:
assert ( assert (
str(hosts[i].get_id()) not in gossipsubs[central_node_index].mesh[topic] hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
) )
assert topic not in gossipsubs[i].mesh assert topic not in gossipsubs[i].mesh
@ -89,9 +89,9 @@ async def test_leave(pubsubs_gsub):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_graft(pubsubs_gsub, hosts, gossipsubs, event_loop, monkeypatch): async def test_handle_graft(pubsubs_gsub, hosts, gossipsubs, event_loop, monkeypatch):
index_alice = 0 index_alice = 0
id_alice = str(hosts[index_alice].get_id()) id_alice = hosts[index_alice].get_id()
index_bob = 1 index_bob = 1
id_bob = str(hosts[index_bob].get_id()) id_bob = hosts[index_bob].get_id()
await connect(hosts[index_alice], hosts[index_bob]) await connect(hosts[index_alice], hosts[index_bob])
# Wait 2 seconds for heartbeat to allow mesh to connect # Wait 2 seconds for heartbeat to allow mesh to connect
@ -141,9 +141,9 @@ async def test_handle_graft(pubsubs_gsub, hosts, gossipsubs, event_loop, monkeyp
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_prune(pubsubs_gsub, hosts, gossipsubs): async def test_handle_prune(pubsubs_gsub, hosts, gossipsubs):
index_alice = 0 index_alice = 0
id_alice = str(hosts[index_alice].get_id()) id_alice = hosts[index_alice].get_id()
index_bob = 1 index_bob = 1
id_bob = str(hosts[index_bob].get_id()) id_bob = hosts[index_bob].get_id()
topic = "test_handle_prune" topic = "test_handle_prune"
for pubsub in pubsubs_gsub: for pubsub in pubsubs_gsub:

View File

@ -60,11 +60,11 @@ async def test_peers_subscribe(pubsubs_fsub):
await pubsubs_fsub[0].subscribe(TESTING_TOPIC) await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yield to let 0 notify 1 # Yield to let 0 notify 1
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
assert str(pubsubs_fsub[0].my_id) in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
# Yield to let 0 notify 1 # Yield to let 0 notify 1
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
assert str(pubsubs_fsub[0].my_id) not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.parametrize("num_hosts", (1,))
@ -212,23 +212,23 @@ def test_handle_subscription(pubsubs_fsub):
and TESTING_TOPIC in pubsubs_fsub[0].peer_topics and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
) )
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
# Test: Another peer is subscribed # Test: Another peer is subscribed
pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
assert len(pubsubs_fsub[0].peer_topics) == 1 assert len(pubsubs_fsub[0].peer_topics) == 1
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2 assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
assert str(peer_ids[1]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
# Test: Subscribe to another topic # Test: Subscribe to another topic
another_topic = "ANOTHER_TOPIC" another_topic = "ANOTHER_TOPIC"
sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic) sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic)
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
assert len(pubsubs_fsub[0].peer_topics) == 2 assert len(pubsubs_fsub[0].peer_topics) == 2
assert another_topic in pubsubs_fsub[0].peer_topics assert another_topic in pubsubs_fsub[0].peer_topics
assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[another_topic] assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic]
# Test: unsubscribe # Test: unsubscribe
unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC) unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC)
pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
assert str(peer_ids[0]) not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.parametrize("num_hosts", (1,))
@ -261,7 +261,7 @@ async def test_handle_talk(pubsubs_fsub):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_all_peers(pubsubs_fsub, monkeypatch): async def test_message_all_peers(pubsubs_fsub, monkeypatch):
peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)] peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)]
mock_peers = {str(peer_id): FakeNetStream() for peer_id in peer_ids} mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids}
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers) monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
empty_rpc = rpc_pb2.RPC() empty_rpc = rpc_pb2.RPC()