Raise exception when topic validation failed
This commit is contained in:
parent
9a1e5fe813
commit
1cea1264a4
|
@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, NamedTup
|
||||||
|
|
||||||
from lru import LRU
|
from lru import LRU
|
||||||
|
|
||||||
|
from libp2p.exceptions import ValidationError
|
||||||
from libp2p.host.host_interface import IHost
|
from libp2p.host.host_interface import IHost
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
|
@ -364,7 +365,7 @@ class Pubsub:
|
||||||
|
|
||||||
await self.push_msg(self.host.get_id(), msg)
|
await self.push_msg(self.host.get_id(), msg)
|
||||||
|
|
||||||
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> bool:
|
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
"""
|
"""
|
||||||
Validate the received message
|
Validate the received message
|
||||||
:param msg_forwarder: the peer who forward us the message.
|
:param msg_forwarder: the peer who forward us the message.
|
||||||
|
@ -380,15 +381,14 @@ class Pubsub:
|
||||||
|
|
||||||
for validator in sync_topic_validators:
|
for validator in sync_topic_validators:
|
||||||
if not validator(msg_forwarder, msg):
|
if not validator(msg_forwarder, msg):
|
||||||
return False
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
|
|
||||||
# TODO: Implement throttle on async validators
|
# TODO: Implement throttle on async validators
|
||||||
|
|
||||||
if len(async_topic_validator_futures) > 0:
|
if len(async_topic_validator_futures) > 0:
|
||||||
results = await asyncio.gather(*async_topic_validator_futures)
|
results = await asyncio.gather(*async_topic_validator_futures)
|
||||||
return all(results)
|
if not all(results):
|
||||||
else:
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
return True
|
|
||||||
|
|
||||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -414,8 +414,10 @@ class Pubsub:
|
||||||
return
|
return
|
||||||
# Validate the message with registered topic validators.
|
# Validate the message with registered topic validators.
|
||||||
# If the validation failed, return(i.e., don't further process the message).
|
# If the validation failed, return(i.e., don't further process the message).
|
||||||
is_validation_passed = await self.validate_msg(msg_forwarder, msg)
|
try:
|
||||||
if not is_validation_passed:
|
await self.validate_msg(msg_forwarder, msg)
|
||||||
|
except ValidationError:
|
||||||
|
log.debug(f"Topic validation failed for msg={msg}")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._mark_msg_seen(msg)
|
self._mark_msg_seen(msg)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.exceptions import ValidationError
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.pubsub.pb import rpc_pb2
|
from libp2p.pubsub.pb import rpc_pb2
|
||||||
from tests.utils import connect
|
from tests.utils import connect
|
||||||
|
@ -191,13 +192,13 @@ async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def failed_sync_validator(peer_id, msg):
|
def failed_sync_validator(peer_id, msg):
|
||||||
return False
|
raise ValidationError()
|
||||||
|
|
||||||
async def passed_async_validator(peer_id, msg):
|
async def passed_async_validator(peer_id, msg):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def failed_async_validator(peer_id, msg):
|
async def failed_async_validator(peer_id, msg):
|
||||||
return False
|
raise ValidationError()
|
||||||
|
|
||||||
topic_1 = "TEST_SYNC_VALIDATOR"
|
topic_1 = "TEST_SYNC_VALIDATOR"
|
||||||
topic_2 = "TEST_ASYNC_VALIDATOR"
|
topic_2 = "TEST_ASYNC_VALIDATOR"
|
||||||
|
@ -219,12 +220,11 @@ async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_
|
||||||
seqno=b"\x00" * 8,
|
seqno=b"\x00" * 8,
|
||||||
)
|
)
|
||||||
|
|
||||||
is_validation_passed = await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
|
||||||
|
|
||||||
if is_topic_1_val_passed and is_topic_2_val_passed:
|
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||||
assert is_validation_passed
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
else:
|
else:
|
||||||
assert not is_validation_passed
|
with pytest.raises(ValidationError):
|
||||||
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
|
|
||||||
|
|
||||||
class FakeNetStream:
|
class FakeNetStream:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user