diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 3e8b0d9..64a56d1 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Sequence, Set from libp2p.peer.id import ID from libp2p.typing import TProtocol +from libp2p.utils import encode_varint_prefixed from .mcache import MessageCache from .pb import rpc_pb2 @@ -169,7 +170,7 @@ class GossipSub(IPubsubRouter): # FIXME: We should add a `WriteMsg` similar to write delimited messages. # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107 # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages. - await stream.write(rpc_msg.SerializeToString()) + await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString())) def _get_peers_to_send( self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID @@ -276,19 +277,6 @@ class GossipSub(IPubsubRouter): return "flood" return "unknown" - async def deliver_messages_to_peers( - self, peers: List[ID], msg_sender: ID, origin_id: ID, serialized_packet: bytes - ) -> None: - for peer_id_in_topic in peers: - # Forward to all peers that are not the - # message sender and are not the message origin - - if peer_id_in_topic not in (msg_sender, origin_id): - stream = self.pubsub.peers[peer_id_in_topic] - - # Publish the packet - await stream.write(serialized_packet) - # Heartbeat async def heartbeat(self) -> None: """ @@ -511,7 +499,7 @@ class GossipSub(IPubsubRouter): peer_stream = self.pubsub.peers[sender_peer_id] # 4) And write the packet to the stream - await peer_stream.write(rpc_msg) + await peer_stream.write(encode_varint_prefixed(rpc_msg)) async def handle_graft( self, graft_msg: rpc_pb2.ControlGraft, sender_peer_id: ID @@ -603,4 +591,4 @@ class GossipSub(IPubsubRouter): peer_stream = self.pubsub.peers[to_peer] # Write rpc to stream - await peer_stream.write(rpc_msg) + await peer_stream.write(encode_varint_prefixed(rpc_msg)) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index a02f5c3..c26bc54 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -72,7 +72,7 @@ class Pubsub: topic_validators: Dict[str, TopicValidator] - # NOTE: Be sure it is increased atomically everytime. + # TODO: Be sure it is increased atomically everytime. counter: int # uint64 def __init__( @@ -165,8 +165,6 @@ class Pubsub: for msg in rpc_incoming.publish: if not self._is_subscribed_to_msg(msg): continue - # TODO(mhchia): This will block this read_stream loop until all data are pushed. - # Should investigate further if this is an issue. asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg)) if rpc_incoming.subscriptions: diff --git a/tests/interop/test_pubsub.py b/tests/interop/test_pubsub.py index 5c7f70d..53aee07 100644 --- a/tests/interop/test_pubsub.py +++ b/tests/interop/test_pubsub.py @@ -1,32 +1,164 @@ import asyncio +import functools import pytest from libp2p.peer.id import ID +from libp2p.utils import read_varint_prefixed_bytes +from libp2p.pubsub.pb import rpc_pb2 + +from p2pclient.pb import p2pd_pb2 from .utils import connect -TOPIC = "TOPIC_0123" -DATA = b"DATA_0123" +TOPIC_0 = "ABALA" +TOPIC_1 = "YOOOO" -@pytest.mark.parametrize("num_hosts", (1,)) +async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]": + reader, writer = await p2pd.control.pubsub_subscribe(topic) + + queue = asyncio.Queue() + + async def _read_pubsub_msg() -> None: + writer_closed_task = asyncio.ensure_future(writer.wait_closed()) + + while True: + done, pending = await asyncio.wait( + [read_varint_prefixed_bytes(reader), writer_closed_task], + return_when=asyncio.FIRST_COMPLETED, + ) + done_tasks = tuple(done) + if writer.is_closing(): + return + read_task = done_tasks[0] + # Sanity check + assert read_task._coro.__name__ == "read_varint_prefixed_bytes" + msg_bytes = read_task.result() + ps_msg = p2pd_pb2.PSMessage() + ps_msg.ParseFromString(msg_bytes) + # Fill in the message used in py-libp2p + msg = rpc_pb2.Message( + from_id=ps_msg.from_field, + data=ps_msg.data, + seqno=ps_msg.seqno, + topicIDs=ps_msg.topicIDs, + signature=ps_msg.signature, + key=ps_msg.key, + ) + queue.put_nowait(msg) + + asyncio.ensure_future(_read_pubsub_msg()) + await asyncio.sleep(0) + return queue + + +def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> None: + assert msg.data == data and msg.from_id == from_peer_id + + +@pytest.mark.parametrize("num_hosts, num_p2pds", ((1, 2),)) @pytest.mark.asyncio -async def test_pubsub_subscribe(pubsubs_gsub, p2pds): - # await connect(pubsubs_gsub[0].host, p2pds[0]) - await connect(p2pds[0], pubsubs_gsub[0].host) - peers = await p2pds[0].control.pubsub_list_peers("") - assert pubsubs_gsub[0].host.get_id() in peers - # FIXME: - assert p2pds[0].peer_id in pubsubs_gsub[0].peers +async def test_pubsub(pubsubs_gsub, p2pds): + # + # Test: Recognize pubsub peers on connection. + # + py_pubsub = pubsubs_gsub[0] + # go0 <-> py <-> go1 + await connect(p2pds[0], py_pubsub.host) + await connect(py_pubsub.host, p2pds[1]) + py_peer_id = py_pubsub.host.get_id() + # Check pubsub peers + pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("") + assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id + pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("") + assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id + assert ( + len(py_pubsub.peers) == 2 + and p2pds[0].peer_id in py_pubsub.peers + and p2pds[1].peer_id in py_pubsub.peers + ) - sub = await pubsubs_gsub[0].subscribe(TOPIC) - peers_topic = await p2pds[0].control.pubsub_list_peers(TOPIC) + # + # Test: `subscribe`. + # + # (name, topics) + # (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1]) + sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0) + sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1) + sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0) + sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1) + sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1) + # Check topic peers await asyncio.sleep(0.1) - assert pubsubs_gsub[0].host.get_id() in peers_topic + # go_0 + go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0) + assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0] + go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1) + assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0] + # py + py_topic_0_peers = py_pubsub.peer_topics[TOPIC_0] + assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0] + # go_1 + go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1) + assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0] - await p2pds[0].control.pubsub_publish(TOPIC, DATA) - msg = await sub.get() - assert ID(msg.from_id) == p2pds[0].peer_id - assert msg.data == DATA - assert len(msg.topicIDs) == 1 and msg.topicIDs[0] == TOPIC + # + # Test: `publish` + # + # 1. py publishes + # - 1.1. py publishes data_11 to topic_0, py and go_0 receives. + # - 1.2. py publishes data_12 to topic_1, all receive. + # 2. go publishes + # - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. + # - 2.2. go_1 publishes data_22 to topic_1, all receive. + + # 1.1. py publishes data_11 to topic_0, py and go_0 receives. + data_11 = b"data_11" + await py_pubsub.publish(TOPIC_0, data_11) + validate_11 = functools.partial( + validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id + ) + validate_11(await sub_py_topic_0.get()) + validate_11(await sub_go_0_topic_0.get()) + + # 1.2. py publishes data_12 to topic_1, all receive. + data_12 = b"data_12" + validate_12 = functools.partial( + validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id + ) + await py_pubsub.publish(TOPIC_1, data_12) + validate_12(await sub_py_topic_1.get()) + validate_12(await sub_go_0_topic_1.get()) + validate_12(await sub_go_1_topic_1.get()) + + # 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive. + data_21 = b"data_21" + validate_21 = functools.partial( + validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id + ) + await p2pds[0].control.pubsub_publish(TOPIC_0, data_21) + validate_21(await sub_py_topic_0.get()) + validate_21(await sub_go_0_topic_0.get()) + + # 2.2. go_1 publishes data_22 to topic_1, all receive. + data_22 = b"data_22" + validate_22 = functools.partial( + validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id + ) + await p2pds[1].control.pubsub_publish(TOPIC_1, data_22) + validate_22(await sub_py_topic_1.get()) + validate_22(await sub_go_0_topic_1.get()) + validate_22(await sub_go_1_topic_1.get()) + + # + # Test: `unsubscribe` and re`subscribe` + # + await py_pubsub.unsubscribe(TOPIC_0) + await asyncio.sleep(0.1) + assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) + assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0)) + await py_pubsub.subscribe(TOPIC_0) + await asyncio.sleep(0.1) + assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0)) + assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))