diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index d09a7de..757a4fa 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -73,7 +73,7 @@ class FloodSub(IPubsubRouter): ) rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg]) for peer_id in peers_gen: - stream = self.pubsub.peers[str(peer_id)] + stream = self.pubsub.peers[peer_id] # FIXME: We should add a `WriteMsg` similar to write delimited messages. # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107 await stream.write(rpc_msg.SerializeToString()) @@ -105,11 +105,9 @@ class FloodSub(IPubsubRouter): for topic in topic_ids: if topic not in self.pubsub.peer_topics: continue - for peer_id_str in self.pubsub.peer_topics[topic]: - peer_id = ID.from_base58(peer_id_str) + for peer_id in self.pubsub.peer_topics[topic]: if peer_id in (msg_forwarder, origin): continue - # FIXME: Should change `self.pubsub.peers` to Dict[PeerID, ...] - if str(peer_id) not in self.pubsub.peers: + if peer_id not in self.pubsub.peers: continue yield peer_id diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index d7db309..5f68a7b 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -32,20 +32,15 @@ class GossipSub(IPubsubRouter): time_to_live: int - # FIXME: Should be changed to `Dict[str, List[ID]]` - mesh: Dict[str, List[str]] - # FIXME: Should be changed to `Dict[str, List[ID]]` - fanout: Dict[str, List[str]] + mesh: Dict[str, List[ID]] + fanout: Dict[str, List[ID]] - # FIXME: Should be changed to `Dict[ID, str]` - peers_to_protocol: Dict[str, str] + peers_to_protocol: Dict[ID, str] time_since_last_publish: Dict[str, int] - # FIXME: Should be changed to List[ID] - peers_gossipsub: List[str] - # FIXME: Should be changed to List[ID] - peers_floodsub: List[str] + peers_gossipsub: List[ID] + peers_floodsub: List[ID] mcache: MessageCache @@ -122,27 +117,25 @@ class GossipSub(IPubsubRouter): # Add peer to the correct peer list peer_type = GossipSub.get_peer_type(protocol_id) - peer_id_str = str(peer_id) - self.peers_to_protocol[peer_id_str] = protocol_id + self.peers_to_protocol[peer_id] = protocol_id if peer_type == "gossip": - self.peers_gossipsub.append(peer_id_str) + self.peers_gossipsub.append(peer_id) elif peer_type == "flood": - self.peers_floodsub.append(peer_id_str) + self.peers_floodsub.append(peer_id) def remove_peer(self, peer_id: ID) -> None: """ Notifies the router that a peer has been disconnected :param peer_id: id of peer to remove """ - peer_id_str = str(peer_id) - del self.peers_to_protocol[peer_id_str] + del self.peers_to_protocol[peer_id] - if peer_id_str in self.peers_gossipsub: - self.peers_gossipsub.remove(peer_id_str) - if peer_id_str in self.peers_gossipsub: - self.peers_floodsub.remove(peer_id_str) + if peer_id in self.peers_gossipsub: + self.peers_gossipsub.remove(peer_id) + if peer_id in self.peers_gossipsub: + self.peers_floodsub.remove(peer_id) async def handle_rpc(self, rpc: rpc_pb2.Message, sender_peer_id: ID) -> None: """ @@ -152,21 +145,20 @@ class GossipSub(IPubsubRouter): :param sender_peer_id: id of the peer who sent the message """ control_message = rpc.control - sender_peer_id_str = str(sender_peer_id) # Relay each rpc control message to the appropriate handler if control_message.ihave: for ihave in control_message.ihave: - await self.handle_ihave(ihave, sender_peer_id_str) + await self.handle_ihave(ihave, sender_peer_id) if control_message.iwant: for iwant in control_message.iwant: - await self.handle_iwant(iwant, sender_peer_id_str) + await self.handle_iwant(iwant, sender_peer_id) if control_message.graft: for graft in control_message.graft: - await self.handle_graft(graft, sender_peer_id_str) + await self.handle_graft(graft, sender_peer_id) if control_message.prune: for prune in control_message.prune: - await self.handle_prune(prune, sender_peer_id_str) + await self.handle_prune(prune, sender_peer_id) async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: # pylint: disable=too-many-locals @@ -182,7 +174,7 @@ class GossipSub(IPubsubRouter): ) rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg]) for peer_id in peers_gen: - stream = self.pubsub.peers[str(peer_id)] + stream = self.pubsub.peers[peer_id] # FIXME: We should add a `WriteMsg` similar to write delimited messages. # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107 # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages. @@ -204,16 +196,14 @@ class GossipSub(IPubsubRouter): continue # floodsub peers - for peer_id_str in self.pubsub.peer_topics[topic]: + for peer_id in self.pubsub.peer_topics[topic]: # FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go. # This will improve the efficiency when searching for a peer's protocol id. - if peer_id_str in self.peers_floodsub: - peer_id = ID.from_base58(peer_id_str) + if peer_id in self.peers_floodsub: send_to.add(peer_id) # gossipsub peers - # FIXME: Change `str` to `ID` - in_topic_gossipsub_peers: List[str] = None + in_topic_gossipsub_peers: List[ID] = None # TODO: Do we need to check `topic in self.pubsub.my_topics`? if topic in self.mesh: in_topic_gossipsub_peers = self.mesh[topic] @@ -229,8 +219,8 @@ class GossipSub(IPubsubRouter): topic, self.degree, [] ) in_topic_gossipsub_peers = self.fanout[topic] - for peer_id_str in in_topic_gossipsub_peers: - send_to.add(ID.from_base58(peer_id_str)) + for peer_id in in_topic_gossipsub_peers: + send_to.add(peer_id) # Excludes `msg_forwarder` and `origin` yield from send_to.difference([msg_forwarder, origin]) @@ -248,8 +238,7 @@ class GossipSub(IPubsubRouter): self.mesh[topic] = [] topic_in_fanout: bool = topic in self.fanout - # FIXME: Should be changed to `List[ID]` - fanout_peers: List[str] = self.fanout[topic] if topic_in_fanout else [] + fanout_peers: List[ID] = self.fanout[topic] if topic_in_fanout else [] fanout_size = len(fanout_peers) if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree): # There are less than D peers (let this number be x) @@ -297,13 +286,11 @@ class GossipSub(IPubsubRouter): return "flood" return "unknown" - # FIXME: type of `peers` should be changed to `List[ID]` - # FIXME: type of `msg_sender` and `origin_id` should be changed to `ID` async def deliver_messages_to_peers( self, - peers: List[str], - msg_sender: str, - origin_id: str, + peers: List[ID], + msg_sender: ID, + origin_id: ID, serialized_packet: bytes, ) -> None: for peer_id_in_topic in peers: @@ -345,8 +332,7 @@ class GossipSub(IPubsubRouter): topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic] ) - # FIXME: Should be changed to `List[ID]` - fanout_peers_not_in_mesh: List[str] = [ + fanout_peers_not_in_mesh: List[ID] = [ peer for peer in selected_peers if peer not in self.mesh[topic] ] for peer in fanout_peers_not_in_mesh: @@ -358,7 +344,6 @@ class GossipSub(IPubsubRouter): if num_mesh_peers_in_topic > self.degree_high: # Select |mesh[topic]| - D peers from mesh[topic] - # FIXME: Should be changed to `List[ID]` selected_peers = GossipSub.select_from_minus( num_mesh_peers_in_topic - self.degree, self.mesh[topic], [] ) @@ -468,15 +453,13 @@ class GossipSub(IPubsubRouter): return selection - # FIXME: type of `minus` should be changed to type `Sequence[ID]` - # FIXME: return type should be changed to type `List[ID]` def _get_in_topic_gossipsub_peers_from_minus( - self, topic: str, num_to_select: int, minus: Sequence[str] - ) -> List[str]: + self, topic: str, num_to_select: int, minus: Sequence[ID] + ) -> List[ID]: gossipsub_peers_in_topic = [ - peer_str - for peer_str in self.pubsub.peer_topics[topic] - if peer_str in self.peers_gossipsub + peer_id + for peer_id in self.pubsub.peer_topics[topic] + if peer_id in self.peers_gossipsub ] return self.select_from_minus( num_to_select, gossipsub_peers_in_topic, list(minus) @@ -485,15 +468,13 @@ class GossipSub(IPubsubRouter): # RPC handlers async def handle_ihave( - self, ihave_msg: rpc_pb2.Message, sender_peer_id: str + self, ihave_msg: rpc_pb2.Message, sender_peer_id: ID ) -> None: """ Checks the seen set and requests unknown messages with an IWANT message. """ # from_id_bytes = ihave_msg.from_id - from_id_str = sender_peer_id - # Get list of all seen (seqnos, from) from the (seqno, from) tuples in seen_messages cache seen_seqnos_and_peers = [ seqno_and_from for seqno_and_from in self.pubsub.seen_messages.keys() @@ -510,16 +491,14 @@ class GossipSub(IPubsubRouter): # Request messages with IWANT message if msg_ids_wanted: - await self.emit_iwant(msg_ids_wanted, from_id_str) + await self.emit_iwant(msg_ids_wanted, sender_peer_id) async def handle_iwant( - self, iwant_msg: rpc_pb2.Message, sender_peer_id: str + self, iwant_msg: rpc_pb2.Message, sender_peer_id: ID ) -> None: """ Forwards all request messages that are present in mcache to the requesting peer. """ - from_id_str = sender_peer_id - # FIXME: Update type of message ID # FIXME: Find a better way to parse the msg ids msg_ids: List[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs] @@ -546,41 +525,36 @@ class GossipSub(IPubsubRouter): rpc_msg: bytes = packet.SerializeToString() # 3) Get the stream to this peer - # TODO: Should we pass in from_id or from_id_str here? - peer_stream = self.pubsub.peers[from_id_str] + peer_stream = self.pubsub.peers[sender_peer_id] # 4) And write the packet to the stream await peer_stream.write(rpc_msg) async def handle_graft( - self, graft_msg: rpc_pb2.Message, sender_peer_id: str + self, graft_msg: rpc_pb2.Message, sender_peer_id: ID ) -> None: topic: str = graft_msg.topicID - from_id_str = sender_peer_id - # Add peer to mesh for topic if topic in self.mesh: - if from_id_str not in self.mesh[topic]: - self.mesh[topic].append(from_id_str) + if sender_peer_id not in self.mesh[topic]: + self.mesh[topic].append(sender_peer_id) else: # Respond with PRUNE if not subscribed to the topic await self.emit_prune(topic, sender_peer_id) async def handle_prune( - self, prune_msg: rpc_pb2.Message, sender_peer_id: str + self, prune_msg: rpc_pb2.Message, sender_peer_id: ID ) -> None: topic: str = prune_msg.topicID - from_id_str = sender_peer_id - # Remove peer from mesh for topic, if peer is in topic - if topic in self.mesh and from_id_str in self.mesh[topic]: - self.mesh[topic].remove(from_id_str) + if topic in self.mesh and sender_peer_id in self.mesh[topic]: + self.mesh[topic].remove(sender_peer_id) # RPC emitters - async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: str) -> None: + async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: ID) -> None: """ Emit ihave message, sent to to_peer, for topic and msg_ids """ @@ -594,7 +568,7 @@ class GossipSub(IPubsubRouter): await self.emit_control_message(control_msg, to_peer) - async def emit_iwant(self, msg_ids: Any, to_peer: str) -> None: + async def emit_iwant(self, msg_ids: Any, to_peer: ID) -> None: """ Emit iwant message, sent to to_peer, for msg_ids """ @@ -607,7 +581,7 @@ class GossipSub(IPubsubRouter): await self.emit_control_message(control_msg, to_peer) - async def emit_graft(self, topic: str, to_peer: str) -> None: + async def emit_graft(self, topic: str, to_peer: ID) -> None: """ Emit graft message, sent to to_peer, for topic """ @@ -620,7 +594,7 @@ class GossipSub(IPubsubRouter): await self.emit_control_message(control_msg, to_peer) - async def emit_prune(self, topic: str, to_peer: str) -> None: + async def emit_prune(self, topic: str, to_peer: ID) -> None: """ Emit graft message, sent to to_peer, for topic """ @@ -634,7 +608,7 @@ class GossipSub(IPubsubRouter): await self.emit_control_message(control_msg, to_peer) async def emit_control_message( - self, control_msg: rpc_pb2.ControlMessage, to_peer: str + self, control_msg: rpc_pb2.ControlMessage, to_peer: ID ) -> None: # Add control message to packet packet: rpc_pb2.RPC = rpc_pb2.RPC() diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 6872aa3..71102f5 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -41,10 +41,8 @@ class Pubsub: my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"] - # FIXME: Should be changed to `Dict[str, List[ID]]` - peer_topics: Dict[str, List[str]] - # FIXME: Should be changed to `Dict[ID, INetStream]` - peers: Dict[str, INetStream] + peer_topics: Dict[str, List[ID]] + peers: Dict[ID, INetStream] # NOTE: Be sure it is increased atomically everytime. counter: int # uint64 @@ -93,11 +91,9 @@ class Pubsub: self.my_topics = {} # Map of topic to peers to keep track of what peers are subscribed to - # FIXME: Should be changed to `Dict[str, ID]` self.peer_topics = {} # Create peers map, which maps peer_id (as string) to stream (to a given peer) - # FIXME: Should be changed to `Dict[ID, INetStream]` self.peers = {} self.counter = time.time_ns() @@ -168,7 +164,7 @@ class Pubsub: # Add peer # Map peer to stream peer_id: ID = stream.mplex_conn.peer_id - self.peers[str(peer_id)] = stream + self.peers[peer_id] = stream self.router.add_peer(peer_id, stream.get_protocol()) # Send hello packet @@ -198,7 +194,7 @@ class Pubsub: # Add Peer # Map peer to stream - self.peers[str(peer_id)] = stream + self.peers[peer_id] = stream self.router.add_peer(peer_id, stream.get_protocol()) # Send hello packet @@ -223,17 +219,16 @@ class Pubsub: :param origin_id: id of the peer who subscribe to the message :param sub_message: RPC.SubOpts """ - origin_id_str = str(origin_id) if sub_message.subscribe: if sub_message.topicid not in self.peer_topics: - self.peer_topics[sub_message.topicid] = [origin_id_str] - elif origin_id_str not in self.peer_topics[sub_message.topicid]: + self.peer_topics[sub_message.topicid] = [origin_id] + elif origin_id not in self.peer_topics[sub_message.topicid]: # Add peer to topic - self.peer_topics[sub_message.topicid].append(origin_id_str) + self.peer_topics[sub_message.topicid].append(origin_id) else: if sub_message.topicid in self.peer_topics: - if origin_id_str in self.peer_topics[sub_message.topicid]: - self.peer_topics[sub_message.topicid].remove(origin_id_str) + if origin_id in self.peer_topics[sub_message.topicid]: + self.peer_topics[sub_message.topicid].remove(origin_id) # FIXME(mhchia): Change the function name? # FIXME(mhchia): `publish_message` can be further type hinted with mypy_protobuf diff --git a/tests/pubsub/floodsub_integration_test_settings.py b/tests/pubsub/floodsub_integration_test_settings.py index 950d3c2..7b1c20f 100644 --- a/tests/pubsub/floodsub_integration_test_settings.py +++ b/tests/pubsub/floodsub_integration_test_settings.py @@ -179,12 +179,12 @@ async def perform_test_from_obj(obj, router_factory): node_map = {} pubsub_map = {} - async def add_node(node_id: str) -> None: + async def add_node(node_id_str: str) -> None: pubsub_router = router_factory(protocols=obj["supported_protocols"]) pubsub = PubsubFactory(router=pubsub_router) await pubsub.host.get_network().listen(LISTEN_MADDR) - node_map[node_id] = pubsub.host - pubsub_map[node_id] = pubsub + node_map[node_id_str] = pubsub.host + pubsub_map[node_id_str] = pubsub tasks_connect = [] for start_node_id in adj_list: diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 5dc72af..f821462 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -54,11 +54,11 @@ async def test_join(num_hosts, hosts, gossipsubs, pubsubs_gsub): for i in hosts_indices: if i in subscribed_peer_indices: - assert str(hosts[i].get_id()) in gossipsubs[central_node_index].mesh[topic] - assert str(hosts[central_node_index].get_id()) in gossipsubs[i].mesh[topic] + assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic] + assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic] else: assert ( - str(hosts[i].get_id()) not in gossipsubs[central_node_index].mesh[topic] + hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] ) assert topic not in gossipsubs[i].mesh @@ -89,9 +89,9 @@ async def test_leave(pubsubs_gsub): @pytest.mark.asyncio async def test_handle_graft(pubsubs_gsub, hosts, gossipsubs, event_loop, monkeypatch): index_alice = 0 - id_alice = str(hosts[index_alice].get_id()) + id_alice = hosts[index_alice].get_id() index_bob = 1 - id_bob = str(hosts[index_bob].get_id()) + id_bob = hosts[index_bob].get_id() await connect(hosts[index_alice], hosts[index_bob]) # Wait 2 seconds for heartbeat to allow mesh to connect @@ -141,9 +141,9 @@ async def test_handle_graft(pubsubs_gsub, hosts, gossipsubs, event_loop, monkeyp @pytest.mark.asyncio async def test_handle_prune(pubsubs_gsub, hosts, gossipsubs): index_alice = 0 - id_alice = str(hosts[index_alice].get_id()) + id_alice = hosts[index_alice].get_id() index_bob = 1 - id_bob = str(hosts[index_bob].get_id()) + id_bob = hosts[index_bob].get_id() topic = "test_handle_prune" for pubsub in pubsubs_gsub: diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index f7023ad..b02445e 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -60,11 +60,11 @@ async def test_peers_subscribe(pubsubs_fsub): await pubsubs_fsub[0].subscribe(TESTING_TOPIC) # Yield to let 0 notify 1 await asyncio.sleep(0.1) - assert str(pubsubs_fsub[0].my_id) in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) # Yield to let 0 notify 1 await asyncio.sleep(0.1) - assert str(pubsubs_fsub[0].my_id) not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] @pytest.mark.parametrize("num_hosts", (1,)) @@ -212,23 +212,23 @@ def test_handle_subscription(pubsubs_fsub): and TESTING_TOPIC in pubsubs_fsub[0].peer_topics ) assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 - assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] # Test: Another peer is subscribed pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) assert len(pubsubs_fsub[0].peer_topics) == 1 assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2 - assert str(peer_ids[1]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] # Test: Subscribe to another topic another_topic = "ANOTHER_TOPIC" sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic) pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) assert len(pubsubs_fsub[0].peer_topics) == 2 assert another_topic in pubsubs_fsub[0].peer_topics - assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[another_topic] + assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic] # Test: unsubscribe unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC) pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) - assert str(peer_ids[0]) not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] @pytest.mark.parametrize("num_hosts", (1,)) @@ -261,7 +261,7 @@ async def test_handle_talk(pubsubs_fsub): @pytest.mark.asyncio async def test_message_all_peers(pubsubs_fsub, monkeypatch): peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)] - mock_peers = {str(peer_id): FakeNetStream() for peer_id in peer_ids} + mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids} monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers) empty_rpc = rpc_pb2.RPC()