Raise exception when topic validation failed

This commit is contained in:
NIC619 2019-08-06 12:38:31 +08:00
parent 9a1e5fe813
commit 1cea1264a4
No known key found for this signature in database
GPG Key ID: 570C35F5C2D51B17
2 changed files with 15 additions and 13 deletions

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, NamedTup
from lru import LRU from lru import LRU
from libp2p.exceptions import ValidationError
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -364,7 +365,7 @@ class Pubsub:
await self.push_msg(self.host.get_id(), msg) await self.push_msg(self.host.get_id(), msg)
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> bool: async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
""" """
Validate the received message Validate the received message
:param msg_forwarder: the peer who forward us the message. :param msg_forwarder: the peer who forward us the message.
@ -380,15 +381,14 @@ class Pubsub:
for validator in sync_topic_validators: for validator in sync_topic_validators:
if not validator(msg_forwarder, msg): if not validator(msg_forwarder, msg):
return False raise ValidationError(f"Validation failed for msg={msg}")
# TODO: Implement throttle on async validators # TODO: Implement throttle on async validators
if len(async_topic_validator_futures) > 0: if len(async_topic_validator_futures) > 0:
results = await asyncio.gather(*async_topic_validator_futures) results = await asyncio.gather(*async_topic_validator_futures)
return all(results) if not all(results):
else: raise ValidationError(f"Validation failed for msg={msg}")
return True
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
""" """
@ -414,8 +414,10 @@ class Pubsub:
return return
# Validate the message with registered topic validators. # Validate the message with registered topic validators.
# If the validation failed, return(i.e., don't further process the message). # If the validation failed, return(i.e., don't further process the message).
is_validation_passed = await self.validate_msg(msg_forwarder, msg) try:
if not is_validation_passed: await self.validate_msg(msg_forwarder, msg)
except ValidationError:
log.debug(f"Topic validation failed for msg={msg}")
return return
self._mark_msg_seen(msg) self._mark_msg_seen(msg)

View File

@ -4,6 +4,7 @@ from typing import NamedTuple
import pytest import pytest
from libp2p.exceptions import ValidationError
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2 from libp2p.pubsub.pb import rpc_pb2
from tests.utils import connect from tests.utils import connect
@ -191,13 +192,13 @@ async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_
return True return True
def failed_sync_validator(peer_id, msg): def failed_sync_validator(peer_id, msg):
return False raise ValidationError()
async def passed_async_validator(peer_id, msg): async def passed_async_validator(peer_id, msg):
return True return True
async def failed_async_validator(peer_id, msg): async def failed_async_validator(peer_id, msg):
return False raise ValidationError()
topic_1 = "TEST_SYNC_VALIDATOR" topic_1 = "TEST_SYNC_VALIDATOR"
topic_2 = "TEST_ASYNC_VALIDATOR" topic_2 = "TEST_ASYNC_VALIDATOR"
@ -219,12 +220,11 @@ async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_
seqno=b"\x00" * 8, seqno=b"\x00" * 8,
) )
is_validation_passed = await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
if is_topic_1_val_passed and is_topic_2_val_passed: if is_topic_1_val_passed and is_topic_2_val_passed:
assert is_validation_passed await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
else: else:
assert not is_validation_passed with pytest.raises(ValidationError):
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
class FakeNetStream: class FakeNetStream: