Add get_msg_validators
and test
This commit is contained in:
parent
1ed14d0cc8
commit
f8ca4fa1ef
|
@ -170,6 +170,13 @@ class Pubsub:
|
|||
if topic in self.topic_validators:
|
||||
del self.topic_validators[topic]
|
||||
|
||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
|
||||
return (
|
||||
self.topic_validators[topic]
|
||||
for topic in msg.topicIDs
|
||||
if topic in self.topic_validators
|
||||
)
|
||||
|
||||
async def stream_handler(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Stream handler for pubsub. Gets invoked whenever a new stream is created
|
||||
|
|
|
@ -136,6 +136,56 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
|
|||
assert topic not in pubsubs_fsub[0].topic_validators
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_msg_validators(pubsubs_fsub):
|
||||
|
||||
times_sync_validator_called = 0
|
||||
|
||||
def sync_validator(peer_id, msg):
|
||||
nonlocal times_sync_validator_called
|
||||
times_sync_validator_called += 1
|
||||
|
||||
times_async_validator_called = 0
|
||||
|
||||
async def async_validator(peer_id, msg):
|
||||
nonlocal times_async_validator_called
|
||||
times_async_validator_called += 1
|
||||
|
||||
topic_1 = "TEST_VALIDATOR_1"
|
||||
topic_2 = "TEST_VALIDATOR_2"
|
||||
topic_3 = "TEST_VALIDATOR_3"
|
||||
|
||||
# Register sync validator for topic 1 and 2
|
||||
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
|
||||
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
|
||||
|
||||
assert topic_1 in pubsubs_fsub[0].topic_validators
|
||||
assert topic_2 in pubsubs_fsub[0].topic_validators
|
||||
|
||||
# Register async validator for topic 3
|
||||
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
|
||||
|
||||
assert topic_3 in pubsubs_fsub[0].topic_validators
|
||||
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[topic_1, topic_2, topic_3],
|
||||
data=b"1234",
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
|
||||
for topic_validator in topic_validators:
|
||||
if topic_validator.is_async:
|
||||
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||
else:
|
||||
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||
|
||||
assert times_sync_validator_called == 2
|
||||
assert times_async_validator_called == 1
|
||||
|
||||
|
||||
class FakeNetStream:
|
||||
_queue: asyncio.Queue
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user