diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index b0833ee..7b725d0 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -156,7 +156,6 @@ class Pubsub: incoming: bytes = await read_varint_prefixed_bytes(stream) rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming.ParseFromString(incoming) - if rpc_incoming.publish: # deal with RPC.publish for msg in rpc_incoming.publish: diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 59df7d7..3413949 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -1,5 +1,4 @@ import asyncio -import io from typing import NamedTuple import pytest @@ -7,6 +6,7 @@ import pytest from libp2p.exceptions import ValidationError from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 +from libp2p.utils import encode_varint_prefixed from tests.utils import connect from .utils import make_pubsub_msg @@ -238,11 +238,19 @@ class FakeNetStream: def __init__(self) -> None: self._queue = asyncio.Queue() - async def read(self) -> bytes: - buf = io.BytesIO() - while not self._queue.empty(): - buf.write(await self._queue.get()) - return buf.getvalue() + async def read(self, n: int = -1) -> bytes: + buf = bytearray() + # Force to blocking wait if no data available now. + if self._queue.empty(): + first_byte = await self._queue.get() + buf.extend(first_byte) + # If `n == -1`, read until no data is in the buffer(_queue). + # Else, read until no data is in the buffer(_queue) or we have read `n` bytes. + while (n == -1) or (len(buf) < n): + if self._queue.empty(): + break + buf.extend(await self._queue.get()) + return bytes(buf) async def write(self, data: bytes) -> int: for i in data: @@ -278,7 +286,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): async def wait_for_event_occurring(event): try: - await asyncio.wait_for(event.wait(), timeout=0.01) + await asyncio.wait_for(event.wait(), timeout=1) except asyncio.TimeoutError as error: event.clear() raise asyncio.TimeoutError( @@ -295,7 +303,9 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): publish_subscribed_topic = rpc_pb2.RPC( publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])] ) - await stream.write(publish_subscribed_topic.SerializeToString()) + await stream.write( + encode_varint_prefixed(publish_subscribed_topic.SerializeToString()) + ) await wait_for_event_occurring(event_push_msg) # Make sure the other events are not emitted. with pytest.raises(asyncio.TimeoutError): @@ -307,13 +317,15 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): publish_not_subscribed_topic = rpc_pb2.RPC( publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])] ) - await stream.write(publish_not_subscribed_topic.SerializeToString()) + await stream.write( + encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString()) + ) with pytest.raises(asyncio.TimeoutError): await wait_for_event_occurring(event_push_msg) # Test: `handle_subscription` is called when a subscription message is received. subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()]) - await stream.write(subscription_msg.SerializeToString()) + await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString())) await wait_for_event_occurring(event_handle_subscription) # Make sure the other events are not emitted. with pytest.raises(asyncio.TimeoutError): @@ -323,7 +335,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): # Test: `handle_rpc` is called when a control message is received. control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) - await stream.write(control_msg.SerializeToString()) + await stream.write(encode_varint_prefixed(control_msg.SerializeToString())) await wait_for_event_occurring(event_handle_rpc) # Make sure the other events are not emitted. with pytest.raises(asyncio.TimeoutError): @@ -405,9 +417,11 @@ async def test_message_all_peers(pubsubs_fsub, monkeypatch): monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers) empty_rpc = rpc_pb2.RPC() - await pubsubs_fsub[0].message_all_peers(empty_rpc.SerializeToString()) + empty_rpc_bytes = empty_rpc.SerializeToString() + empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes) + await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes) for stream in mock_peers.values(): - assert (await stream.read()) == empty_rpc.SerializeToString() + assert (await stream.read()) == empty_rpc_bytes_len_prefixed @pytest.mark.parametrize("num_hosts", (1,))