Fix pubsub tests
This commit is contained in:
parent
961e51fa2e
commit
677531db76
|
@ -156,7 +156,6 @@ class Pubsub:
|
||||||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||||
rpc_incoming.ParseFromString(incoming)
|
rpc_incoming.ParseFromString(incoming)
|
||||||
|
|
||||||
if rpc_incoming.publish:
|
if rpc_incoming.publish:
|
||||||
# deal with RPC.publish
|
# deal with RPC.publish
|
||||||
for msg in rpc_incoming.publish:
|
for msg in rpc_incoming.publish:
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -7,6 +6,7 @@ import pytest
|
||||||
from libp2p.exceptions import ValidationError
|
from libp2p.exceptions import ValidationError
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.pubsub.pb import rpc_pb2
|
from libp2p.pubsub.pb import rpc_pb2
|
||||||
|
from libp2p.utils import encode_varint_prefixed
|
||||||
from tests.utils import connect
|
from tests.utils import connect
|
||||||
|
|
||||||
from .utils import make_pubsub_msg
|
from .utils import make_pubsub_msg
|
||||||
|
@ -238,11 +238,19 @@ class FakeNetStream:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._queue = asyncio.Queue()
|
self._queue = asyncio.Queue()
|
||||||
|
|
||||||
async def read(self) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
buf = io.BytesIO()
|
buf = bytearray()
|
||||||
while not self._queue.empty():
|
# Force to blocking wait if no data available now.
|
||||||
buf.write(await self._queue.get())
|
if self._queue.empty():
|
||||||
return buf.getvalue()
|
first_byte = await self._queue.get()
|
||||||
|
buf.extend(first_byte)
|
||||||
|
# If `n == -1`, read until no data is in the buffer(_queue).
|
||||||
|
# Else, read until no data is in the buffer(_queue) or we have read `n` bytes.
|
||||||
|
while (n == -1) or (len(buf) < n):
|
||||||
|
if self._queue.empty():
|
||||||
|
break
|
||||||
|
buf.extend(await self._queue.get())
|
||||||
|
return bytes(buf)
|
||||||
|
|
||||||
async def write(self, data: bytes) -> int:
|
async def write(self, data: bytes) -> int:
|
||||||
for i in data:
|
for i in data:
|
||||||
|
@ -278,7 +286,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
||||||
|
|
||||||
async def wait_for_event_occurring(event):
|
async def wait_for_event_occurring(event):
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(event.wait(), timeout=0.01)
|
await asyncio.wait_for(event.wait(), timeout=1)
|
||||||
except asyncio.TimeoutError as error:
|
except asyncio.TimeoutError as error:
|
||||||
event.clear()
|
event.clear()
|
||||||
raise asyncio.TimeoutError(
|
raise asyncio.TimeoutError(
|
||||||
|
@ -295,7 +303,9 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
||||||
publish_subscribed_topic = rpc_pb2.RPC(
|
publish_subscribed_topic = rpc_pb2.RPC(
|
||||||
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
|
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
|
||||||
)
|
)
|
||||||
await stream.write(publish_subscribed_topic.SerializeToString())
|
await stream.write(
|
||||||
|
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
|
||||||
|
)
|
||||||
await wait_for_event_occurring(event_push_msg)
|
await wait_for_event_occurring(event_push_msg)
|
||||||
# Make sure the other events are not emitted.
|
# Make sure the other events are not emitted.
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
@ -307,13 +317,15 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
||||||
publish_not_subscribed_topic = rpc_pb2.RPC(
|
publish_not_subscribed_topic = rpc_pb2.RPC(
|
||||||
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
|
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
|
||||||
)
|
)
|
||||||
await stream.write(publish_not_subscribed_topic.SerializeToString())
|
await stream.write(
|
||||||
|
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
|
||||||
|
)
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
await wait_for_event_occurring(event_push_msg)
|
await wait_for_event_occurring(event_push_msg)
|
||||||
|
|
||||||
# Test: `handle_subscription` is called when a subscription message is received.
|
# Test: `handle_subscription` is called when a subscription message is received.
|
||||||
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
|
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
|
||||||
await stream.write(subscription_msg.SerializeToString())
|
await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString()))
|
||||||
await wait_for_event_occurring(event_handle_subscription)
|
await wait_for_event_occurring(event_handle_subscription)
|
||||||
# Make sure the other events are not emitted.
|
# Make sure the other events are not emitted.
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
@ -323,7 +335,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
||||||
|
|
||||||
# Test: `handle_rpc` is called when a control message is received.
|
# Test: `handle_rpc` is called when a control message is received.
|
||||||
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
||||||
await stream.write(control_msg.SerializeToString())
|
await stream.write(encode_varint_prefixed(control_msg.SerializeToString()))
|
||||||
await wait_for_event_occurring(event_handle_rpc)
|
await wait_for_event_occurring(event_handle_rpc)
|
||||||
# Make sure the other events are not emitted.
|
# Make sure the other events are not emitted.
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
@ -405,9 +417,11 @@ async def test_message_all_peers(pubsubs_fsub, monkeypatch):
|
||||||
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
||||||
|
|
||||||
empty_rpc = rpc_pb2.RPC()
|
empty_rpc = rpc_pb2.RPC()
|
||||||
await pubsubs_fsub[0].message_all_peers(empty_rpc.SerializeToString())
|
empty_rpc_bytes = empty_rpc.SerializeToString()
|
||||||
|
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
|
||||||
|
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
|
||||||
for stream in mock_peers.values():
|
for stream in mock_peers.values():
|
||||||
assert (await stream.read()) == empty_rpc.SerializeToString()
|
assert (await stream.read()) == empty_rpc_bytes_len_prefixed
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.parametrize("num_hosts", (1,))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user