py-libp2p/tests/pubsub/test_mcache.py
mhchia 550289a439
Combine test_subscription.py to test_pubsub.py
And add a bunch of tests for pubsub
2019-07-31 14:15:10 +08:00

132 lines
2.6 KiB
Python

import pytest
from libp2p.pubsub.mcache import MessageCache
# pylint: disable=too-few-public-methods
class Msg:
def __init__(self, topicIDs, seqno, from_id):
# pylint: disable=invalid-name
self.topicIDs = topicIDs
self.seqno = seqno
self.from_id = from_id
@pytest.mark.asyncio
async def test_mcache():
# Ported from:
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
mcache = MessageCache(3, 5)
msgs = []
for i in range(60):
msgs.append(Msg(["test"], i, "test"))
for i in range(10):
mcache.put(msgs[i])
for i in range(10):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
get_msg = mcache.get(mid)
# successful read
assert get_msg == msg
gids = mcache.window('test')
assert len(gids) == 10
for i in range(10):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i]
mcache.shift()
for i in range(10, 20):
mcache.put(msgs[i])
for i in range(20):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
get_msg = mcache.get(mid)
assert get_msg == msg
gids = mcache.window('test')
assert len(gids) == 20
for i in range(10):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[10 + i]
for i in range(10, 20):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i - 10]
mcache.shift()
for i in range(20, 30):
mcache.put(msgs[i])
mcache.shift()
for i in range(30, 40):
mcache.put(msgs[i])
mcache.shift()
for i in range(40, 50):
mcache.put(msgs[i])
mcache.shift()
for i in range(50, 60):
mcache.put(msgs[i])
assert len(mcache.msgs) == 50
for i in range(10):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
get_msg = mcache.get(mid)
# Should be evicted from cache
assert not get_msg
for i in range(10, 60):
msg = msgs[i]
mid = (msg.seqno, msg.from_id)
get_msg = mcache.get(mid)
assert get_msg == msg
gids = mcache.window('test')
assert len(gids) == 30
for i in range(10):
msg = msgs[50 + i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i]
for i in range(10, 20):
msg = msgs[30 + i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i]
for i in range(20, 30):
msg = msgs[10 + i]
mid = (msg.seqno, msg.from_id)
assert mid == gids[i]