Rename to set_topic_validator
and add test
This commit is contained in:
parent
b1f4813195
commit
cf69f7e800
|
@ -161,7 +161,7 @@ class Pubsub:
|
|||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def add_topic_validator(
|
||||
def set_topic_validator(
|
||||
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
||||
) -> None:
|
||||
self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
|
||||
|
|
|
@ -84,6 +84,54 @@ async def test_get_hello_packet(pubsubs_fsub):
|
|||
assert topic in topic_ids_in_hello
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_topic_validator(pubsubs_fsub):
|
||||
|
||||
is_sync_validator_called = False
|
||||
|
||||
def sync_validator(peer_id, msg):
|
||||
nonlocal is_sync_validator_called
|
||||
is_sync_validator_called = True
|
||||
|
||||
is_async_validator_called = False
|
||||
|
||||
async def async_validator(peer_id, msg):
|
||||
nonlocal is_async_validator_called
|
||||
is_async_validator_called = True
|
||||
|
||||
topic = "TEST_VALIDATOR"
|
||||
|
||||
assert topic not in pubsubs_fsub[0].topic_validators
|
||||
|
||||
# Register sync validator
|
||||
pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)
|
||||
|
||||
assert topic in pubsubs_fsub[0].topic_validators
|
||||
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
||||
assert not topic_validator.is_async
|
||||
|
||||
# Validate with sync validator
|
||||
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||
|
||||
assert is_sync_validator_called
|
||||
assert not is_async_validator_called
|
||||
|
||||
# Register with async validator
|
||||
pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)
|
||||
|
||||
is_sync_validator_called = False
|
||||
assert topic in pubsubs_fsub[0].topic_validators
|
||||
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
||||
assert topic_validator.is_async
|
||||
|
||||
# Validate with async validator
|
||||
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||
|
||||
assert is_async_validator_called
|
||||
assert not is_sync_validator_called
|
||||
|
||||
|
||||
class FakeNetStream:
|
||||
_queue: asyncio.Queue
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user