Fix pubsub tests
This commit is contained in:
parent
961e51fa2e
commit
677531db76
|
@ -156,7 +156,6 @@ class Pubsub:
|
|||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
rpc_incoming.ParseFromString(incoming)
|
||||
|
||||
if rpc_incoming.publish:
|
||||
# deal with RPC.publish
|
||||
for msg in rpc_incoming.publish:
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
import io
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
|
@ -7,6 +6,7 @@ import pytest
|
|||
from libp2p.exceptions import ValidationError
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
from tests.utils import connect
|
||||
|
||||
from .utils import make_pubsub_msg
|
||||
|
@ -238,11 +238,19 @@ class FakeNetStream:
|
|||
def __init__(self) -> None:
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
async def read(self) -> bytes:
|
||||
buf = io.BytesIO()
|
||||
while not self._queue.empty():
|
||||
buf.write(await self._queue.get())
|
||||
return buf.getvalue()
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
buf = bytearray()
|
||||
# Force to blocking wait if no data available now.
|
||||
if self._queue.empty():
|
||||
first_byte = await self._queue.get()
|
||||
buf.extend(first_byte)
|
||||
# If `n == -1`, read until no data is in the buffer(_queue).
|
||||
# Else, read until no data is in the buffer(_queue) or we have read `n` bytes.
|
||||
while (n == -1) or (len(buf) < n):
|
||||
if self._queue.empty():
|
||||
break
|
||||
buf.extend(await self._queue.get())
|
||||
return bytes(buf)
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
for i in data:
|
||||
|
@ -278,7 +286,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
|||
|
||||
async def wait_for_event_occurring(event):
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=0.01)
|
||||
await asyncio.wait_for(event.wait(), timeout=1)
|
||||
except asyncio.TimeoutError as error:
|
||||
event.clear()
|
||||
raise asyncio.TimeoutError(
|
||||
|
@ -295,7 +303,9 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
|||
publish_subscribed_topic = rpc_pb2.RPC(
|
||||
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
|
||||
)
|
||||
await stream.write(publish_subscribed_topic.SerializeToString())
|
||||
await stream.write(
|
||||
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
|
||||
)
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
|
@ -307,13 +317,15 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
|||
publish_not_subscribed_topic = rpc_pb2.RPC(
|
||||
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
|
||||
)
|
||||
await stream.write(publish_not_subscribed_topic.SerializeToString())
|
||||
await stream.write(
|
||||
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
|
||||
)
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
|
||||
# Test: `handle_subscription` is called when a subscription message is received.
|
||||
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
|
||||
await stream.write(subscription_msg.SerializeToString())
|
||||
await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString()))
|
||||
await wait_for_event_occurring(event_handle_subscription)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
|
@ -323,7 +335,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
|||
|
||||
# Test: `handle_rpc` is called when a control message is received.
|
||||
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
||||
await stream.write(control_msg.SerializeToString())
|
||||
await stream.write(encode_varint_prefixed(control_msg.SerializeToString()))
|
||||
await wait_for_event_occurring(event_handle_rpc)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
|
@ -405,9 +417,11 @@ async def test_message_all_peers(pubsubs_fsub, monkeypatch):
|
|||
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
||||
|
||||
empty_rpc = rpc_pb2.RPC()
|
||||
await pubsubs_fsub[0].message_all_peers(empty_rpc.SerializeToString())
|
||||
empty_rpc_bytes = empty_rpc.SerializeToString()
|
||||
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
|
||||
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
|
||||
for stream in mock_peers.values():
|
||||
assert (await stream.read()) == empty_rpc.SerializeToString()
|
||||
assert (await stream.read()) == empty_rpc_bytes_len_prefixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
|
|
Loading…
Reference in New Issue
Block a user