Add tests for Pubsub

- `test_handle_subscription`
- `test_handle_talk`
- `test_message_all_peers`
This commit is contained in:
mhchia 2019-07-30 13:33:48 +08:00
parent 3a42d72cd9
commit 9683d5e8ac
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
2 changed files with 105 additions and 8 deletions

View File

@ -323,7 +323,7 @@ class Pubsub:
""" """
# Broadcast message # Broadcast message
for _, stream in self.peers.items(): for stream in self.peers.values():
# Write message to stream # Write message to stream
await stream.write(raw_msg) await stream.write(raw_msg)

View File

@ -138,7 +138,7 @@ class FakeNetStream:
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
s = FakeNetStream() stream = FakeNetStream()
await pubsubs_fsub[0].subscribe(TESTING_TOPIC) await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
@ -172,7 +172,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
event.clear() event.clear()
# Kick off the task `continuously_read_stream` # Kick off the task `continuously_read_stream`
task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(s)) task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream))
# Test: `push_msg` is called when publishing to a subscribed topic. # Test: `push_msg` is called when publishing to a subscribed topic.
publish_subscribed_topic = rpc_pb2.RPC( publish_subscribed_topic = rpc_pb2.RPC(
@ -180,7 +180,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
topicIDs=[TESTING_TOPIC] topicIDs=[TESTING_TOPIC]
)], )],
) )
await s.write(publish_subscribed_topic.SerializeToString()) await stream.write(publish_subscribed_topic.SerializeToString())
await wait_for_event_occurring(event_push_msg) await wait_for_event_occurring(event_push_msg)
# Make sure the other events are not emitted. # Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
@ -194,7 +194,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
topicIDs=["NOT_SUBSCRIBED"] topicIDs=["NOT_SUBSCRIBED"]
)], )],
) )
await s.write(publish_not_subscribed_topic.SerializeToString()) await stream.write(publish_not_subscribed_topic.SerializeToString())
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_push_msg) await wait_for_event_occurring(event_push_msg)
@ -202,7 +202,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
subscription_msg = rpc_pb2.RPC( subscription_msg = rpc_pb2.RPC(
subscriptions=[rpc_pb2.RPC.SubOpts()], subscriptions=[rpc_pb2.RPC.SubOpts()],
) )
await s.write(subscription_msg.SerializeToString()) await stream.write(subscription_msg.SerializeToString())
await wait_for_event_occurring(event_handle_subscription) await wait_for_event_occurring(event_handle_subscription)
# Make sure the other events are not emitted. # Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
@ -212,7 +212,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
# Test: `handle_rpc` is called when a control message is received. # Test: `handle_rpc` is called when a control message is received.
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
await s.write(control_msg.SerializeToString()) await stream.write(control_msg.SerializeToString())
await wait_for_event_occurring(event_handle_rpc) await wait_for_event_occurring(event_handle_rpc)
# Make sure the other events are not emitted. # Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
@ -223,9 +223,106 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
task.cancel() task.cancel()
# TODO: Add the following tests after they are aligned with Go.
# (Issue #191: https://github.com/libp2p/py-libp2p/issues/191)
# - `test_stream_handler`
# - `test_handle_peer_queue`
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_hosts", "num_hosts",
(2,), (1,),
)
def test_handle_subscription(pubsubs_fsub):
assert len(pubsubs_fsub[0].peer_topics) == 0
sub_msg_0 = rpc_pb2.RPC.SubOpts(
subscribe=True,
topicid=TESTING_TOPIC,
)
peer_ids = [
ID(b"\x12\x20" + i.to_bytes(32, "big"))
for i in range(2)
]
# Test: One peer is subscribed
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
assert len(pubsubs_fsub[0].peer_topics) == 1 and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
# Test: Another peer is subscribed
pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
assert len(pubsubs_fsub[0].peer_topics) == 1
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
assert str(peer_ids[1]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
# Test: Subscribe to another topic
another_topic = "ANOTHER_TOPIC"
sub_msg_1 = rpc_pb2.RPC.SubOpts(
subscribe=True,
topicid=another_topic,
)
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
assert len(pubsubs_fsub[0].peer_topics) == 2
assert another_topic in pubsubs_fsub[0].peer_topics
assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[another_topic]
# Test: unsubscribe
unsub_msg = rpc_pb2.RPC.SubOpts(
subscribe=False,
topicid=TESTING_TOPIC,
)
pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
assert str(peer_ids[0]) not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_handle_talk(pubsubs_fsub):
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
msg_0 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[TESTING_TOPIC],
data=b"1234",
seqno=b"\x00" * 8,
)
await pubsubs_fsub[0].handle_talk(msg_0)
msg_1 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=["NOT_SUBSCRIBED"],
data=b"1234",
seqno=b"\x11" * 8,
)
await pubsubs_fsub[0].handle_talk(msg_1)
assert len(pubsubs_fsub[0].my_topics) == 1 and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC]
assert sub.qsize() == 1
assert (await sub.get()) == msg_0
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_message_all_peers(pubsubs_fsub, monkeypatch):
peer_ids = [
ID(b"\x12\x20" + i.to_bytes(32, "big"))
for i in range(10)
]
mock_peers = {
str(peer_id): FakeNetStream()
for peer_id in peer_ids
}
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
empty_rpc = rpc_pb2.RPC()
await pubsubs_fsub[0].message_all_peers(empty_rpc.SerializeToString())
for stream in mock_peers.values():
assert (await stream.read()) == empty_rpc.SerializeToString()
@pytest.mark.parametrize(
"num_hosts",
(1,),
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_publish(pubsubs_fsub, monkeypatch): async def test_publish(pubsubs_fsub, monkeypatch):