diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 31ce4c7..671df9b 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -161,7 +161,7 @@ class Pubsub: # Force context switch await asyncio.sleep(0) - def add_topic_validator( + def set_topic_validator( self, topic: str, validator: ValidatorFn, is_async_validator: bool ) -> None: self.topic_validators[topic] = TopicValidator(validator, is_async_validator) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 530677b..438c48b 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -84,6 +84,54 @@ async def test_get_hello_packet(pubsubs_fsub): assert topic in topic_ids_in_hello +@pytest.mark.parametrize("num_hosts", (1,)) +@pytest.mark.asyncio +async def test_add_topic_validator(pubsubs_fsub): + + is_sync_validator_called = False + + def sync_validator(peer_id, msg): + nonlocal is_sync_validator_called + is_sync_validator_called = True + + is_async_validator_called = False + + async def async_validator(peer_id, msg): + nonlocal is_async_validator_called + is_async_validator_called = True + + topic = "TEST_VALIDATOR" + + assert topic not in pubsubs_fsub[0].topic_validators + + # Register sync validator + pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False) + + assert topic in pubsubs_fsub[0].topic_validators + topic_validator = pubsubs_fsub[0].topic_validators[topic] + assert not topic_validator.is_async + + # Validate with sync validator + topic_validator.validator(peer_id=ID(b"peer"), msg="msg") + + assert is_sync_validator_called + assert not is_async_validator_called + + # Register with async validator + pubsubs_fsub[0].set_topic_validator(topic, async_validator, True) + + is_sync_validator_called = False + assert topic in pubsubs_fsub[0].topic_validators + topic_validator = pubsubs_fsub[0].topic_validators[topic] + assert topic_validator.is_async + + # Validate with async validator + await topic_validator.validator(peer_id=ID(b"peer"), msg="msg") + + assert is_async_validator_called + assert not is_sync_validator_called + + class FakeNetStream: _queue: asyncio.Queue