Add get_msg_validators
and test
This commit is contained in:
parent
1ed14d0cc8
commit
f8ca4fa1ef
|
@ -170,6 +170,13 @@ class Pubsub:
|
||||||
if topic in self.topic_validators:
|
if topic in self.topic_validators:
|
||||||
del self.topic_validators[topic]
|
del self.topic_validators[topic]
|
||||||
|
|
||||||
|
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
|
||||||
|
return (
|
||||||
|
self.topic_validators[topic]
|
||||||
|
for topic in msg.topicIDs
|
||||||
|
if topic in self.topic_validators
|
||||||
|
)
|
||||||
|
|
||||||
async def stream_handler(self, stream: INetStream) -> None:
|
async def stream_handler(self, stream: INetStream) -> None:
|
||||||
"""
|
"""
|
||||||
Stream handler for pubsub. Gets invoked whenever a new stream is created
|
Stream handler for pubsub. Gets invoked whenever a new stream is created
|
||||||
|
|
|
@ -136,6 +136,56 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
|
||||||
assert topic not in pubsubs_fsub[0].topic_validators
|
assert topic not in pubsubs_fsub[0].topic_validators
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_hosts", (1,))
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_msg_validators(pubsubs_fsub):
|
||||||
|
|
||||||
|
times_sync_validator_called = 0
|
||||||
|
|
||||||
|
def sync_validator(peer_id, msg):
|
||||||
|
nonlocal times_sync_validator_called
|
||||||
|
times_sync_validator_called += 1
|
||||||
|
|
||||||
|
times_async_validator_called = 0
|
||||||
|
|
||||||
|
async def async_validator(peer_id, msg):
|
||||||
|
nonlocal times_async_validator_called
|
||||||
|
times_async_validator_called += 1
|
||||||
|
|
||||||
|
topic_1 = "TEST_VALIDATOR_1"
|
||||||
|
topic_2 = "TEST_VALIDATOR_2"
|
||||||
|
topic_3 = "TEST_VALIDATOR_3"
|
||||||
|
|
||||||
|
# Register sync validator for topic 1 and 2
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
|
||||||
|
|
||||||
|
assert topic_1 in pubsubs_fsub[0].topic_validators
|
||||||
|
assert topic_2 in pubsubs_fsub[0].topic_validators
|
||||||
|
|
||||||
|
# Register async validator for topic 3
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
|
||||||
|
|
||||||
|
assert topic_3 in pubsubs_fsub[0].topic_validators
|
||||||
|
|
||||||
|
msg = make_pubsub_msg(
|
||||||
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
|
topic_ids=[topic_1, topic_2, topic_3],
|
||||||
|
data=b"1234",
|
||||||
|
seqno=b"\x00" * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
|
||||||
|
for topic_validator in topic_validators:
|
||||||
|
if topic_validator.is_async:
|
||||||
|
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||||
|
else:
|
||||||
|
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||||
|
|
||||||
|
assert times_sync_validator_called == 2
|
||||||
|
assert times_async_validator_called == 1
|
||||||
|
|
||||||
|
|
||||||
class FakeNetStream:
|
class FakeNetStream:
|
||||||
_queue: asyncio.Queue
|
_queue: asyncio.Queue
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user