Fix pubsub tests

This commit is contained in:
mhchia 2019-09-04 15:33:07 +08:00
parent 961e51fa2e
commit 677531db76
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
2 changed files with 27 additions and 14 deletions

View File

@ -156,7 +156,6 @@ class Pubsub:
incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
if rpc_incoming.publish:
# deal with RPC.publish
for msg in rpc_incoming.publish:

View File

@ -1,5 +1,4 @@
import asyncio
import io
from typing import NamedTuple
import pytest
@ -7,6 +6,7 @@ import pytest
from libp2p.exceptions import ValidationError
from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2
from libp2p.utils import encode_varint_prefixed
from tests.utils import connect
from .utils import make_pubsub_msg
@ -238,11 +238,19 @@ class FakeNetStream:
def __init__(self) -> None:
self._queue = asyncio.Queue()
async def read(self) -> bytes:
buf = io.BytesIO()
while not self._queue.empty():
buf.write(await self._queue.get())
return buf.getvalue()
async def read(self, n: int = -1) -> bytes:
buf = bytearray()
# Force to blocking wait if no data available now.
if self._queue.empty():
first_byte = await self._queue.get()
buf.extend(first_byte)
# If `n == -1`, read until no data is in the buffer(_queue).
# Else, read until no data is in the buffer(_queue) or we have read `n` bytes.
while (n == -1) or (len(buf) < n):
if self._queue.empty():
break
buf.extend(await self._queue.get())
return bytes(buf)
async def write(self, data: bytes) -> int:
for i in data:
@ -278,7 +286,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
async def wait_for_event_occurring(event):
try:
await asyncio.wait_for(event.wait(), timeout=0.01)
await asyncio.wait_for(event.wait(), timeout=1)
except asyncio.TimeoutError as error:
event.clear()
raise asyncio.TimeoutError(
@ -295,7 +303,9 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
publish_subscribed_topic = rpc_pb2.RPC(
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
)
await stream.write(publish_subscribed_topic.SerializeToString())
await stream.write(
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
)
await wait_for_event_occurring(event_push_msg)
# Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError):
@ -307,13 +317,15 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
publish_not_subscribed_topic = rpc_pb2.RPC(
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
)
await stream.write(publish_not_subscribed_topic.SerializeToString())
await stream.write(
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
)
with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_push_msg)
# Test: `handle_subscription` is called when a subscription message is received.
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
await stream.write(subscription_msg.SerializeToString())
await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString()))
await wait_for_event_occurring(event_handle_subscription)
# Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError):
@ -323,7 +335,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
# Test: `handle_rpc` is called when a control message is received.
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
await stream.write(control_msg.SerializeToString())
await stream.write(encode_varint_prefixed(control_msg.SerializeToString()))
await wait_for_event_occurring(event_handle_rpc)
# Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError):
@ -405,9 +417,11 @@ async def test_message_all_peers(pubsubs_fsub, monkeypatch):
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
empty_rpc = rpc_pb2.RPC()
await pubsubs_fsub[0].message_all_peers(empty_rpc.SerializeToString())
empty_rpc_bytes = empty_rpc.SerializeToString()
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
for stream in mock_peers.values():
assert (await stream.read()) == empty_rpc.SerializeToString()
assert (await stream.read()) == empty_rpc_bytes_len_prefixed
@pytest.mark.parametrize("num_hosts", (1,))