Add validate_msg
and test
This commit is contained in:
parent
ec2c566e5a
commit
e1b86904e3
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import time
|
import time
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union, TYPE_CHECKING
|
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Tuple, Union, TYPE_CHECKING
|
||||||
|
|
||||||
from lru import LRU
|
from lru import LRU
|
||||||
|
|
||||||
|
@ -176,15 +176,13 @@ class Pubsub:
|
||||||
if topic in self.topic_validators:
|
if topic in self.topic_validators:
|
||||||
del self.topic_validators[topic]
|
del self.topic_validators[topic]
|
||||||
|
|
||||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
|
def get_msg_validators(self, msg: rpc_pb2.Message) -> Iterable[TopicValidator]:
|
||||||
"""
|
"""
|
||||||
Get all validators corresponding to the topics in the message.
|
Get all validators corresponding to the topics in the message.
|
||||||
"""
|
"""
|
||||||
return (
|
for topic in msg.topicIDs:
|
||||||
self.topic_validators[topic]
|
if topic in self.topic_validators:
|
||||||
for topic in msg.topicIDs
|
yield self.topic_validators[topic]
|
||||||
if topic in self.topic_validators
|
|
||||||
)
|
|
||||||
|
|
||||||
async def stream_handler(self, stream: INetStream) -> None:
|
async def stream_handler(self, stream: INetStream) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -357,6 +355,26 @@ class Pubsub:
|
||||||
|
|
||||||
await self.push_msg(self.host.get_id(), msg)
|
await self.push_msg(self.host.get_id(), msg)
|
||||||
|
|
||||||
|
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> bool:
|
||||||
|
sync_topic_validators = []
|
||||||
|
async_topic_validator_futures = []
|
||||||
|
for topic_validator in self.get_msg_validators(msg):
|
||||||
|
if topic_validator.is_async:
|
||||||
|
async_topic_validator_futures.append(
|
||||||
|
topic_validator.validator(msg_forwarder, msg)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sync_topic_validators.append(topic_validator.validator)
|
||||||
|
|
||||||
|
for validator in sync_topic_validators:
|
||||||
|
if not validator(msg_forwarder, msg):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# TODO: Implement throttle on async validators
|
||||||
|
|
||||||
|
results = await asyncio.gather(*async_topic_validator_futures)
|
||||||
|
return all(results)
|
||||||
|
|
||||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
"""
|
"""
|
||||||
Push a pubsub message to others.
|
Push a pubsub message to others.
|
||||||
|
|
9
libp2p/pubsub/validators.py
Normal file
9
libp2p/pubsub/validators.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
# FIXME: Replace the type of `pubkey` with a custom type `Pubkey`
|
||||||
|
def signature_validator(pubkey: bytes, msg: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
Verify the message against the given public key.
|
||||||
|
:param pubkey: the public key which signs the message.
|
||||||
|
:param msg: the message signed.
|
||||||
|
"""
|
||||||
|
# TODO: Implement the signature validation
|
||||||
|
return True
|
|
@ -208,7 +208,7 @@ async def perform_test_from_obj(obj, router_factory):
|
||||||
tasks_topic.append(asyncio.sleep(2))
|
tasks_topic.append(asyncio.sleep(2))
|
||||||
|
|
||||||
# Gather is like Promise.all
|
# Gather is like Promise.all
|
||||||
responses = await asyncio.gather(*tasks_topic, return_exceptions=True)
|
responses = await asyncio.gather(*tasks_topic)
|
||||||
for i in range(len(responses) - 1):
|
for i in range(len(responses) - 1):
|
||||||
node_id, topic = tasks_topic_data[i]
|
node_id, topic = tasks_topic_data[i]
|
||||||
if node_id not in queues_map:
|
if node_id not in queues_map:
|
||||||
|
|
|
@ -160,14 +160,9 @@ async def test_get_msg_validators(pubsubs_fsub):
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
|
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
|
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
|
||||||
|
|
||||||
assert topic_1 in pubsubs_fsub[0].topic_validators
|
|
||||||
assert topic_2 in pubsubs_fsub[0].topic_validators
|
|
||||||
|
|
||||||
# Register async validator for topic 3
|
# Register async validator for topic 3
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
|
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
|
||||||
|
|
||||||
assert topic_3 in pubsubs_fsub[0].topic_validators
|
|
||||||
|
|
||||||
msg = make_pubsub_msg(
|
msg = make_pubsub_msg(
|
||||||
origin_id=pubsubs_fsub[0].my_id,
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
topic_ids=[topic_1, topic_2, topic_3],
|
topic_ids=[topic_1, topic_2, topic_3],
|
||||||
|
@ -186,6 +181,58 @@ async def test_get_msg_validators(pubsubs_fsub):
|
||||||
assert times_async_validator_called == 1
|
assert times_async_validator_called == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_hosts", (1,))
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"is_topic_1_val_passed, is_topic_2_val_passed",
|
||||||
|
(
|
||||||
|
(False, True),
|
||||||
|
(True, False),
|
||||||
|
(True, True),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed):
|
||||||
|
|
||||||
|
def passed_sync_validator(peer_id, msg):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def failed_sync_validator(peer_id, msg):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def passed_async_validator(peer_id, msg):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def failed_async_validator(peer_id, msg):
|
||||||
|
return False
|
||||||
|
|
||||||
|
topic_1 = "TEST_SYNC_VALIDATOR"
|
||||||
|
topic_2 = "TEST_ASYNC_VALIDATOR"
|
||||||
|
|
||||||
|
if is_topic_1_val_passed:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
|
||||||
|
else:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)
|
||||||
|
|
||||||
|
if is_topic_2_val_passed:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
|
||||||
|
else:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)
|
||||||
|
|
||||||
|
msg = make_pubsub_msg(
|
||||||
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
|
topic_ids=[topic_1, topic_2],
|
||||||
|
data=b"1234",
|
||||||
|
seqno=b"\x00" * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
is_validation_passed = await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
|
|
||||||
|
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||||
|
assert is_validation_passed
|
||||||
|
else:
|
||||||
|
assert not is_validation_passed
|
||||||
|
|
||||||
|
|
||||||
class FakeNetStream:
|
class FakeNetStream:
|
||||||
_queue: asyncio.Queue
|
_queue: asyncio.Queue
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user