Fix pubsub tests

This commit is contained in:
mhchia 2019-09-04 15:33:07 +08:00
parent 961e51fa2e
commit 677531db76
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
2 changed files with 27 additions and 14 deletions

View File

@ -156,7 +156,6 @@ class Pubsub:
incoming: bytes = await read_varint_prefixed_bytes(stream) incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming) rpc_incoming.ParseFromString(incoming)
if rpc_incoming.publish: if rpc_incoming.publish:
# deal with RPC.publish # deal with RPC.publish
for msg in rpc_incoming.publish: for msg in rpc_incoming.publish:

View File

@ -1,5 +1,4 @@
import asyncio import asyncio
import io
from typing import NamedTuple from typing import NamedTuple
import pytest import pytest
@ -7,6 +6,7 @@ import pytest
from libp2p.exceptions import ValidationError from libp2p.exceptions import ValidationError
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2 from libp2p.pubsub.pb import rpc_pb2
from libp2p.utils import encode_varint_prefixed
from tests.utils import connect from tests.utils import connect
from .utils import make_pubsub_msg from .utils import make_pubsub_msg
@ -238,11 +238,19 @@ class FakeNetStream:
def __init__(self) -> None: def __init__(self) -> None:
self._queue = asyncio.Queue() self._queue = asyncio.Queue()
async def read(self) -> bytes: async def read(self, n: int = -1) -> bytes:
buf = io.BytesIO() buf = bytearray()
while not self._queue.empty(): # Force to blocking wait if no data available now.
buf.write(await self._queue.get()) if self._queue.empty():
return buf.getvalue() first_byte = await self._queue.get()
buf.extend(first_byte)
# If `n == -1`, read until no data is in the buffer(_queue).
# Else, read until no data is in the buffer(_queue) or we have read `n` bytes.
while (n == -1) or (len(buf) < n):
if self._queue.empty():
break
buf.extend(await self._queue.get())
return bytes(buf)
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
for i in data: for i in data:
@ -278,7 +286,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
async def wait_for_event_occurring(event): async def wait_for_event_occurring(event):
try: try:
await asyncio.wait_for(event.wait(), timeout=0.01) await asyncio.wait_for(event.wait(), timeout=1)
except asyncio.TimeoutError as error: except asyncio.TimeoutError as error:
event.clear() event.clear()
raise asyncio.TimeoutError( raise asyncio.TimeoutError(
@ -295,7 +303,9 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
publish_subscribed_topic = rpc_pb2.RPC( publish_subscribed_topic = rpc_pb2.RPC(
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])] publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
) )
await stream.write(publish_subscribed_topic.SerializeToString()) await stream.write(
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
)
await wait_for_event_occurring(event_push_msg) await wait_for_event_occurring(event_push_msg)
# Make sure the other events are not emitted. # Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
@ -307,13 +317,15 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
publish_not_subscribed_topic = rpc_pb2.RPC( publish_not_subscribed_topic = rpc_pb2.RPC(
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])] publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
) )
await stream.write(publish_not_subscribed_topic.SerializeToString()) await stream.write(
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
)
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_push_msg) await wait_for_event_occurring(event_push_msg)
# Test: `handle_subscription` is called when a subscription message is received. # Test: `handle_subscription` is called when a subscription message is received.
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()]) subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
await stream.write(subscription_msg.SerializeToString()) await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString()))
await wait_for_event_occurring(event_handle_subscription) await wait_for_event_occurring(event_handle_subscription)
# Make sure the other events are not emitted. # Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
@ -323,7 +335,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
# Test: `handle_rpc` is called when a control message is received. # Test: `handle_rpc` is called when a control message is received.
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
await stream.write(control_msg.SerializeToString()) await stream.write(encode_varint_prefixed(control_msg.SerializeToString()))
await wait_for_event_occurring(event_handle_rpc) await wait_for_event_occurring(event_handle_rpc)
# Make sure the other events are not emitted. # Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
@ -405,9 +417,11 @@ async def test_message_all_peers(pubsubs_fsub, monkeypatch):
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers) monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
empty_rpc = rpc_pb2.RPC() empty_rpc = rpc_pb2.RPC()
await pubsubs_fsub[0].message_all_peers(empty_rpc.SerializeToString()) empty_rpc_bytes = empty_rpc.SerializeToString()
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
for stream in mock_peers.values(): for stream in mock_peers.values():
assert (await stream.read()) == empty_rpc.SerializeToString() assert (await stream.read()) == empty_rpc_bytes_len_prefixed
@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.parametrize("num_hosts", (1,))