diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index dd51be2..37eb932 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, NamedTup from lru import LRU +from libp2p.exceptions import ValidationError from libp2p.host.host_interface import IHost from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID @@ -364,7 +365,7 @@ class Pubsub: await self.push_msg(self.host.get_id(), msg) - async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> bool: + async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: """ Validate the received message :param msg_forwarder: the peer who forward us the message. @@ -380,15 +381,14 @@ class Pubsub: for validator in sync_topic_validators: if not validator(msg_forwarder, msg): - return False + raise ValidationError(f"Validation failed for msg={msg}") # TODO: Implement throttle on async validators if len(async_topic_validator_futures) > 0: results = await asyncio.gather(*async_topic_validator_futures) - return all(results) - else: - return True + if not all(results): + raise ValidationError(f"Validation failed for msg={msg}") async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: """ @@ -414,8 +414,10 @@ class Pubsub: return # Validate the message with registered topic validators. # If the validation failed, return(i.e., don't further process the message). - is_validation_passed = await self.validate_msg(msg_forwarder, msg) - if not is_validation_passed: + try: + await self.validate_msg(msg_forwarder, msg) + except ValidationError: + log.debug(f"Topic validation failed for msg={msg}") return self._mark_msg_seen(msg) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index ad9dd43..0799c34 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -4,6 +4,7 @@ from typing import NamedTuple import pytest +from libp2p.exceptions import ValidationError from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 from tests.utils import connect @@ -191,13 +192,13 @@ async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_ return True def failed_sync_validator(peer_id, msg): - return False + raise ValidationError() async def passed_async_validator(peer_id, msg): return True async def failed_async_validator(peer_id, msg): - return False + raise ValidationError() topic_1 = "TEST_SYNC_VALIDATOR" topic_2 = "TEST_ASYNC_VALIDATOR" @@ -219,12 +220,11 @@ async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_ seqno=b"\x00" * 8, ) - is_validation_passed = await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - if is_topic_1_val_passed and is_topic_2_val_passed: - assert is_validation_passed + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) else: - assert not is_validation_passed + with pytest.raises(ValidationError): + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) class FakeNetStream: