Merge pull request #226 from NIC619/add_msg_validator
Add topic message validator
This commit is contained in:
commit
0d709364f8
6
libp2p/exceptions.py
Normal file
6
libp2p/exceptions.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
class ValidationError(Exception):
|
||||||
|
"""
|
||||||
|
Raised when something does not pass a validation check.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
|
@ -1,25 +1,39 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, NamedTuple, Tuple, Union
|
||||||
|
|
||||||
from lru import LRU
|
from lru import LRU
|
||||||
|
|
||||||
|
from libp2p.exceptions import ValidationError
|
||||||
from libp2p.host.host_interface import IHost
|
from libp2p.host.host_interface import IHost
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
|
|
||||||
from .pb import rpc_pb2
|
from .pb import rpc_pb2
|
||||||
from .pubsub_notifee import PubsubNotifee
|
from .pubsub_notifee import PubsubNotifee
|
||||||
|
from .validators import signature_validator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .pubsub_router_interface import IPubsubRouter
|
from .pubsub_router_interface import IPubsubRouter
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
|
def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
|
||||||
# NOTE: `string(from, seqno)` in Go
|
# NOTE: `string(from, seqno)` in Go
|
||||||
return (msg.seqno, msg.from_id)
|
return (msg.seqno, msg.from_id)
|
||||||
|
|
||||||
|
|
||||||
|
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
||||||
|
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||||
|
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||||
|
|
||||||
|
|
||||||
|
TopicValidator = NamedTuple("TopicValidator", (("validator", ValidatorFn), ("is_async", bool)))
|
||||||
|
|
||||||
|
|
||||||
class Pubsub:
|
class Pubsub:
|
||||||
|
|
||||||
host: IHost
|
host: IHost
|
||||||
|
@ -41,6 +55,8 @@ class Pubsub:
|
||||||
peer_topics: Dict[str, List[ID]]
|
peer_topics: Dict[str, List[ID]]
|
||||||
peers: Dict[ID, INetStream]
|
peers: Dict[ID, INetStream]
|
||||||
|
|
||||||
|
topic_validators: Dict[str, TopicValidator]
|
||||||
|
|
||||||
# NOTE: Be sure it is increased atomically everytime.
|
# NOTE: Be sure it is increased atomically everytime.
|
||||||
counter: int # uint64
|
counter: int # uint64
|
||||||
|
|
||||||
|
@ -93,6 +109,9 @@ class Pubsub:
|
||||||
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
|
# Create peers map, which maps peer_id (as string) to stream (to a given peer)
|
||||||
self.peers = {}
|
self.peers = {}
|
||||||
|
|
||||||
|
# Map of topic to topic validator
|
||||||
|
self.topic_validators = {}
|
||||||
|
|
||||||
self.counter = time.time_ns()
|
self.counter = time.time_ns()
|
||||||
|
|
||||||
# Call handle peer to keep waiting for updates to peer queue
|
# Call handle peer to keep waiting for updates to peer queue
|
||||||
|
@ -128,7 +147,7 @@ class Pubsub:
|
||||||
continue
|
continue
|
||||||
# TODO(mhchia): This will block this read_stream loop until all data are pushed.
|
# TODO(mhchia): This will block this read_stream loop until all data are pushed.
|
||||||
# Should investigate further if this is an issue.
|
# Should investigate further if this is an issue.
|
||||||
await self.push_msg(msg_forwarder=peer_id, msg=msg)
|
asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))
|
||||||
|
|
||||||
if rpc_incoming.subscriptions:
|
if rpc_incoming.subscriptions:
|
||||||
# deal with RPC.subscriptions
|
# deal with RPC.subscriptions
|
||||||
|
@ -149,6 +168,34 @@ class Pubsub:
|
||||||
# Force context switch
|
# Force context switch
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
def set_topic_validator(
|
||||||
|
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a validator under the given topic. One topic can only have one validtor.
|
||||||
|
:param topic: the topic to register validator under
|
||||||
|
:param validator: the validator used to validate messages published to the topic
|
||||||
|
:param is_async_validator: indicate if the validator is an asynchronous validator
|
||||||
|
"""
|
||||||
|
self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
|
||||||
|
|
||||||
|
def remove_topic_validator(self, topic: str) -> None:
|
||||||
|
"""
|
||||||
|
Remove the validator from the given topic.
|
||||||
|
:param topic: the topic to remove validator from
|
||||||
|
"""
|
||||||
|
if topic in self.topic_validators:
|
||||||
|
del self.topic_validators[topic]
|
||||||
|
|
||||||
|
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
|
||||||
|
"""
|
||||||
|
Get all validators corresponding to the topics in the message.
|
||||||
|
:param msg: the message published to the topic
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self.topic_validators[topic] for topic in msg.topicIDs if topic in self.topic_validators
|
||||||
|
)
|
||||||
|
|
||||||
async def stream_handler(self, stream: INetStream) -> None:
|
async def stream_handler(self, stream: INetStream) -> None:
|
||||||
"""
|
"""
|
||||||
Stream handler for pubsub. Gets invoked whenever a new stream is created
|
Stream handler for pubsub. Gets invoked whenever a new stream is created
|
||||||
|
@ -320,6 +367,31 @@ class Pubsub:
|
||||||
|
|
||||||
await self.push_msg(self.host.get_id(), msg)
|
await self.push_msg(self.host.get_id(), msg)
|
||||||
|
|
||||||
|
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
|
"""
|
||||||
|
Validate the received message
|
||||||
|
:param msg_forwarder: the peer who forward us the message.
|
||||||
|
:param msg: the message.
|
||||||
|
"""
|
||||||
|
sync_topic_validators = []
|
||||||
|
async_topic_validator_futures = []
|
||||||
|
for topic_validator in self.get_msg_validators(msg):
|
||||||
|
if topic_validator.is_async:
|
||||||
|
async_topic_validator_futures.append(topic_validator.validator(msg_forwarder, msg))
|
||||||
|
else:
|
||||||
|
sync_topic_validators.append(topic_validator.validator)
|
||||||
|
|
||||||
|
for validator in sync_topic_validators:
|
||||||
|
if not validator(msg_forwarder, msg):
|
||||||
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
|
|
||||||
|
# TODO: Implement throttle on async validators
|
||||||
|
|
||||||
|
if len(async_topic_validator_futures) > 0:
|
||||||
|
results = await asyncio.gather(*async_topic_validator_futures)
|
||||||
|
if not all(results):
|
||||||
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
|
|
||||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
"""
|
"""
|
||||||
Push a pubsub message to others.
|
Push a pubsub message to others.
|
||||||
|
@ -332,10 +404,23 @@ class Pubsub:
|
||||||
|
|
||||||
# TODO: Check if signing is required and if so signature should be attached.
|
# TODO: Check if signing is required and if so signature should be attached.
|
||||||
|
|
||||||
|
# If the message is processed before, return(i.e., don't further process the message).
|
||||||
if self._is_msg_seen(msg):
|
if self._is_msg_seen(msg):
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO: - Validate the message. If failed, reject it.
|
# TODO: - Validate the message. If failed, reject it.
|
||||||
|
# Validate the signature of the message
|
||||||
|
# FIXME: `signature_validator` is currently a stub.
|
||||||
|
if not signature_validator(msg.key, msg.SerializeToString()):
|
||||||
|
log.debug(f"Signature validation failed for msg={msg}")
|
||||||
|
return
|
||||||
|
# Validate the message with registered topic validators.
|
||||||
|
# If the validation failed, return(i.e., don't further process the message).
|
||||||
|
try:
|
||||||
|
await self.validate_msg(msg_forwarder, msg)
|
||||||
|
except ValidationError:
|
||||||
|
log.debug(f"Topic validation failed for msg={msg}")
|
||||||
|
return
|
||||||
|
|
||||||
self._mark_msg_seen(msg)
|
self._mark_msg_seen(msg)
|
||||||
await self.handle_talk(msg)
|
await self.handle_talk(msg)
|
||||||
|
@ -361,4 +446,4 @@ class Pubsub:
|
||||||
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
||||||
if not self.my_topics:
|
if not self.my_topics:
|
||||||
return False
|
return False
|
||||||
return all([topic in self.my_topics for topic in msg.topicIDs])
|
return any(topic in self.my_topics for topic in msg.topicIDs)
|
||||||
|
|
9
libp2p/pubsub/validators.py
Normal file
9
libp2p/pubsub/validators.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
# FIXME: Replace the type of `pubkey` with a custom type `Pubkey`
|
||||||
|
def signature_validator(pubkey: bytes, msg: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
Verify the message against the given public key.
|
||||||
|
:param pubkey: the public key which signs the message.
|
||||||
|
:param msg: the message signed.
|
||||||
|
"""
|
||||||
|
# TODO: Implement the signature validation
|
||||||
|
return True
|
|
@ -208,7 +208,7 @@ async def perform_test_from_obj(obj, router_factory):
|
||||||
tasks_topic.append(asyncio.sleep(2))
|
tasks_topic.append(asyncio.sleep(2))
|
||||||
|
|
||||||
# Gather is like Promise.all
|
# Gather is like Promise.all
|
||||||
responses = await asyncio.gather(*tasks_topic, return_exceptions=True)
|
responses = await asyncio.gather(*tasks_topic)
|
||||||
for i in range(len(responses) - 1):
|
for i in range(len(responses) - 1):
|
||||||
node_id, topic = tasks_topic_data[i]
|
node_id, topic = tasks_topic_data[i]
|
||||||
if node_id not in queues_map:
|
if node_id not in queues_map:
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.exceptions import ValidationError
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.pubsub.pb import rpc_pb2
|
from libp2p.pubsub.pb import rpc_pb2
|
||||||
from tests.utils import connect
|
from tests.utils import connect
|
||||||
|
@ -84,6 +85,148 @@ async def test_get_hello_packet(pubsubs_fsub):
|
||||||
assert topic in topic_ids_in_hello
|
assert topic in topic_ids_in_hello
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_hosts", (1,))
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_and_remove_topic_validator(pubsubs_fsub):
|
||||||
|
|
||||||
|
is_sync_validator_called = False
|
||||||
|
|
||||||
|
def sync_validator(peer_id, msg):
|
||||||
|
nonlocal is_sync_validator_called
|
||||||
|
is_sync_validator_called = True
|
||||||
|
|
||||||
|
is_async_validator_called = False
|
||||||
|
|
||||||
|
async def async_validator(peer_id, msg):
|
||||||
|
nonlocal is_async_validator_called
|
||||||
|
is_async_validator_called = True
|
||||||
|
|
||||||
|
topic = "TEST_VALIDATOR"
|
||||||
|
|
||||||
|
assert topic not in pubsubs_fsub[0].topic_validators
|
||||||
|
|
||||||
|
# Register sync validator
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)
|
||||||
|
|
||||||
|
assert topic in pubsubs_fsub[0].topic_validators
|
||||||
|
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
||||||
|
assert not topic_validator.is_async
|
||||||
|
|
||||||
|
# Validate with sync validator
|
||||||
|
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||||
|
|
||||||
|
assert is_sync_validator_called
|
||||||
|
assert not is_async_validator_called
|
||||||
|
|
||||||
|
# Register with async validator
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)
|
||||||
|
|
||||||
|
is_sync_validator_called = False
|
||||||
|
assert topic in pubsubs_fsub[0].topic_validators
|
||||||
|
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
||||||
|
assert topic_validator.is_async
|
||||||
|
|
||||||
|
# Validate with async validator
|
||||||
|
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||||
|
|
||||||
|
assert is_async_validator_called
|
||||||
|
assert not is_sync_validator_called
|
||||||
|
|
||||||
|
# Remove validator
|
||||||
|
pubsubs_fsub[0].remove_topic_validator(topic)
|
||||||
|
assert topic not in pubsubs_fsub[0].topic_validators
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_hosts", (1,))
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_msg_validators(pubsubs_fsub):
|
||||||
|
|
||||||
|
times_sync_validator_called = 0
|
||||||
|
|
||||||
|
def sync_validator(peer_id, msg):
|
||||||
|
nonlocal times_sync_validator_called
|
||||||
|
times_sync_validator_called += 1
|
||||||
|
|
||||||
|
times_async_validator_called = 0
|
||||||
|
|
||||||
|
async def async_validator(peer_id, msg):
|
||||||
|
nonlocal times_async_validator_called
|
||||||
|
times_async_validator_called += 1
|
||||||
|
|
||||||
|
topic_1 = "TEST_VALIDATOR_1"
|
||||||
|
topic_2 = "TEST_VALIDATOR_2"
|
||||||
|
topic_3 = "TEST_VALIDATOR_3"
|
||||||
|
|
||||||
|
# Register sync validator for topic 1 and 2
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
|
||||||
|
|
||||||
|
# Register async validator for topic 3
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
|
||||||
|
|
||||||
|
msg = make_pubsub_msg(
|
||||||
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
|
topic_ids=[topic_1, topic_2, topic_3],
|
||||||
|
data=b"1234",
|
||||||
|
seqno=b"\x00" * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
|
||||||
|
for topic_validator in topic_validators:
|
||||||
|
if topic_validator.is_async:
|
||||||
|
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||||
|
else:
|
||||||
|
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||||
|
|
||||||
|
assert times_sync_validator_called == 2
|
||||||
|
assert times_async_validator_called == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_hosts", (1,))
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"is_topic_1_val_passed, is_topic_2_val_passed", ((False, True), (True, False), (True, True))
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed):
|
||||||
|
def passed_sync_validator(peer_id, msg):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def failed_sync_validator(peer_id, msg):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def passed_async_validator(peer_id, msg):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def failed_async_validator(peer_id, msg):
|
||||||
|
return False
|
||||||
|
|
||||||
|
topic_1 = "TEST_SYNC_VALIDATOR"
|
||||||
|
topic_2 = "TEST_ASYNC_VALIDATOR"
|
||||||
|
|
||||||
|
if is_topic_1_val_passed:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
|
||||||
|
else:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)
|
||||||
|
|
||||||
|
if is_topic_2_val_passed:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
|
||||||
|
else:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)
|
||||||
|
|
||||||
|
msg = make_pubsub_msg(
|
||||||
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
|
topic_ids=[topic_1, topic_2],
|
||||||
|
data=b"1234",
|
||||||
|
seqno=b"\x00" * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||||
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
|
|
||||||
|
|
||||||
class FakeNetStream:
|
class FakeNetStream:
|
||||||
_queue: asyncio.Queue
|
_queue: asyncio.Queue
|
||||||
|
|
||||||
|
@ -319,3 +462,23 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
|
||||||
await asyncio.wait_for(event.wait(), timeout=0.1)
|
await asyncio.wait_for(event.wait(), timeout=0.1)
|
||||||
# Test: Subscribers are notified when `push_msg` new messages.
|
# Test: Subscribers are notified when `push_msg` new messages.
|
||||||
assert (await sub.get()) == msg_1
|
assert (await sub.get()) == msg_1
|
||||||
|
|
||||||
|
# Test: add a topic validator and `push_msg` the message that
|
||||||
|
# does not pass the validation.
|
||||||
|
# `router_publish` is not called then.
|
||||||
|
def failed_sync_validator(peer_id, msg):
|
||||||
|
return False
|
||||||
|
|
||||||
|
pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False)
|
||||||
|
|
||||||
|
msg_2 = make_pubsub_msg(
|
||||||
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
|
topic_ids=[TESTING_TOPIC],
|
||||||
|
data=TESTING_DATA,
|
||||||
|
seqno=b"\x22" * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
event.clear()
|
||||||
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert not event.is_set()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user