Add get_msg_validators and test

This commit is contained in:
NIC619 2019-08-04 11:23:20 +08:00
parent 1ed14d0cc8
commit f8ca4fa1ef
No known key found for this signature in database
GPG Key ID: 570C35F5C2D51B17
2 changed files with 57 additions and 0 deletions

View File

@ -170,6 +170,13 @@ class Pubsub:
if topic in self.topic_validators: if topic in self.topic_validators:
del self.topic_validators[topic] del self.topic_validators[topic]
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
return (
self.topic_validators[topic]
for topic in msg.topicIDs
if topic in self.topic_validators
)
async def stream_handler(self, stream: INetStream) -> None: async def stream_handler(self, stream: INetStream) -> None:
""" """
Stream handler for pubsub. Gets invoked whenever a new stream is created Stream handler for pubsub. Gets invoked whenever a new stream is created

View File

@ -136,6 +136,56 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
assert topic not in pubsubs_fsub[0].topic_validators assert topic not in pubsubs_fsub[0].topic_validators
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_get_msg_validators(pubsubs_fsub):
times_sync_validator_called = 0
def sync_validator(peer_id, msg):
nonlocal times_sync_validator_called
times_sync_validator_called += 1
times_async_validator_called = 0
async def async_validator(peer_id, msg):
nonlocal times_async_validator_called
times_async_validator_called += 1
topic_1 = "TEST_VALIDATOR_1"
topic_2 = "TEST_VALIDATOR_2"
topic_3 = "TEST_VALIDATOR_3"
# Register sync validator for topic 1 and 2
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
assert topic_1 in pubsubs_fsub[0].topic_validators
assert topic_2 in pubsubs_fsub[0].topic_validators
# Register async validator for topic 3
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
assert topic_3 in pubsubs_fsub[0].topic_validators
msg = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[topic_1, topic_2, topic_3],
data=b"1234",
seqno=b"\x00" * 8,
)
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
for topic_validator in topic_validators:
if topic_validator.is_async:
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
else:
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
assert times_sync_validator_called == 2
assert times_async_validator_called == 1
class FakeNetStream: class FakeNetStream:
_queue: asyncio.Queue _queue: asyncio.Queue