Merge pull request #362 from NIC619/add_signing_and_verification_to_pubsub
Add signing and verification to pubsub
This commit is contained in:
commit
dfdcf524b7
|
@ -16,6 +16,7 @@ from typing import (
|
|||
import base58
|
||||
from lru import LRU
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.exceptions import ParseError, ValidationError
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.io.exceptions import IncompleteReadError
|
||||
|
@ -28,7 +29,7 @@ from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
|
|||
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub_notifee import PubsubNotifee
|
||||
from .validators import signature_validator
|
||||
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pubsub_router_interface import IPubsubRouter # noqa: F401
|
||||
|
@ -82,8 +83,17 @@ class Pubsub:
|
|||
|
||||
_tasks: List["asyncio.Future[Any]"]
|
||||
|
||||
# Indicate if we should enforce signature verification
|
||||
strict_signing: bool
|
||||
sign_key: PrivateKey
|
||||
|
||||
def __init__(
|
||||
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
|
||||
self,
|
||||
host: IHost,
|
||||
router: "IPubsubRouter",
|
||||
my_id: ID,
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
|
@ -147,6 +157,12 @@ class Pubsub:
|
|||
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
|
||||
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
|
||||
|
||||
self.strict_signing = strict_signing
|
||||
if strict_signing:
|
||||
self.sign_key = self.host.get_private_key()
|
||||
else:
|
||||
self.sign_key = None
|
||||
|
||||
def get_hello_packet(self) -> rpc_pb2.RPC:
|
||||
"""Generate subscription message with all topics we are subscribed to
|
||||
only send hello packet if we have subscribed topics."""
|
||||
|
@ -456,7 +472,13 @@ class Pubsub:
|
|||
seqno=self._next_seqno(),
|
||||
)
|
||||
|
||||
# TODO: Sign with our signing key
|
||||
if self.strict_signing:
|
||||
priv_key = self.sign_key
|
||||
signature = priv_key.sign(
|
||||
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
|
||||
)
|
||||
msg.key = self.host.get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
|
||||
await self.push_msg(self.host.get_id(), msg)
|
||||
|
||||
|
@ -505,18 +527,17 @@ class Pubsub:
|
|||
|
||||
# TODO: Check if the `from` is in the blacklist. If yes, reject.
|
||||
|
||||
# TODO: Check if signing is required and if so signature should be attached.
|
||||
|
||||
# If the message is processed before, return(i.e., don't further process the message).
|
||||
if self._is_msg_seen(msg):
|
||||
return
|
||||
|
||||
# TODO: - Validate the message. If failed, reject it.
|
||||
# Validate the signature of the message
|
||||
# FIXME: `signature_validator` is currently a stub.
|
||||
if not signature_validator(msg.key, msg.SerializeToString()):
|
||||
logger.debug("Signature validation failed for msg: %s", msg)
|
||||
return
|
||||
# Check if signing is required and if so validate the signature
|
||||
if self.strict_signing:
|
||||
# Validate the signature of the message
|
||||
if not signature_validator(msg):
|
||||
logger.debug("Signature validation failed for msg: %s", msg)
|
||||
return
|
||||
|
||||
# Validate the message with registered topic validators.
|
||||
# If the validation failed, return(i.e., don't further process the message).
|
||||
try:
|
||||
|
|
|
@ -1,10 +1,41 @@
|
|||
# FIXME: Replace the type of `pubkey` with a custom type `Pubkey`
|
||||
def signature_validator(pubkey: bytes, msg: bytes) -> bool:
|
||||
import logging
|
||||
|
||||
from libp2p.crypto.serialization import deserialize_public_key
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
from .pb import rpc_pb2
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub")
|
||||
|
||||
PUBSUB_SIGNING_PREFIX = "libp2p-pubsub:"
|
||||
|
||||
|
||||
def signature_validator(msg: rpc_pb2.Message) -> bool:
|
||||
"""
|
||||
Verify the message against the given public key.
|
||||
|
||||
:param pubkey: the public key which signs the message.
|
||||
:param msg: the message signed.
|
||||
"""
|
||||
# TODO: Implement the signature validation
|
||||
return True
|
||||
# Check if signature is attached
|
||||
if msg.signature == b"":
|
||||
logger.debug("Reject because no signature attached for msg: %s", msg)
|
||||
return False
|
||||
|
||||
# Validate if message sender matches message signer,
|
||||
# i.e., check if `msg.key` matches `msg.from_id`
|
||||
msg_pubkey = deserialize_public_key(msg.key)
|
||||
if ID.from_pubkey(msg_pubkey) != msg.from_id:
|
||||
logger.debug(
|
||||
"Reject because signing key does not match sender ID for msg: %s", msg
|
||||
)
|
||||
return False
|
||||
# First, construct the original payload that's signed by 'msg.key'
|
||||
msg_without_key_sig = rpc_pb2.Message(
|
||||
data=msg.data, topicIDs=msg.topicIDs, from_id=msg.from_id, seqno=msg.seqno
|
||||
)
|
||||
payload = PUBSUB_SIGNING_PREFIX.encode() + msg_without_key_sig.SerializeToString()
|
||||
try:
|
||||
return msg_pubkey.verify(payload, msg.signature)
|
||||
except Exception:
|
||||
return False
|
||||
|
|
|
@ -153,6 +153,7 @@ class PubsubFactory(factory.Factory):
|
|||
router = None
|
||||
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
|
||||
cache_size = None
|
||||
strict_signing = False
|
||||
|
||||
|
||||
async def swarm_pair_factory(
|
||||
|
|
|
@ -4,14 +4,24 @@ from libp2p.tools.constants import GOSSIPSUB_PARAMS
|
|||
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
|
||||
|
||||
|
||||
def _make_pubsubs(hosts, pubsub_routers, cache_size):
|
||||
@pytest.fixture
|
||||
def is_strict_signing():
|
||||
return False
|
||||
|
||||
|
||||
def _make_pubsubs(hosts, pubsub_routers, cache_size, is_strict_signing):
|
||||
if len(pubsub_routers) != len(hosts):
|
||||
raise ValueError(
|
||||
f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
|
||||
f"length of hosts={len(hosts)}"
|
||||
)
|
||||
return tuple(
|
||||
PubsubFactory(host=host, router=router, cache_size=cache_size)
|
||||
PubsubFactory(
|
||||
host=host,
|
||||
router=router,
|
||||
cache_size=cache_size,
|
||||
strict_signing=is_strict_signing,
|
||||
)
|
||||
for host, router in zip(hosts, pubsub_routers)
|
||||
)
|
||||
|
||||
|
@ -27,16 +37,22 @@ def gossipsub_params():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size):
|
||||
def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size, is_strict_signing):
|
||||
floodsubs = FloodsubFactory.create_batch(num_hosts)
|
||||
_pubsubs_fsub = _make_pubsubs(hosts, floodsubs, pubsub_cache_size)
|
||||
_pubsubs_fsub = _make_pubsubs(
|
||||
hosts, floodsubs, pubsub_cache_size, is_strict_signing
|
||||
)
|
||||
yield _pubsubs_fsub
|
||||
# TODO: Clean up
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
|
||||
def pubsubs_gsub(
|
||||
num_hosts, hosts, pubsub_cache_size, gossipsub_params, is_strict_signing
|
||||
):
|
||||
gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
|
||||
_pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
|
||||
_pubsubs_gsub = _make_pubsubs(
|
||||
hosts, gossipsubs, pubsub_cache_size, is_strict_signing
|
||||
)
|
||||
yield _pubsubs_gsub
|
||||
# TODO: Clean up
|
||||
|
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from libp2p.exceptions import ValidationError
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
from libp2p.pubsub.pubsub import PUBSUB_SIGNING_PREFIX
|
||||
from libp2p.tools.pubsub.utils import make_pubsub_msg
|
||||
from libp2p.tools.utils import connect
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
|
@ -510,3 +511,70 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
|
|||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
|
||||
await asyncio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts, is_strict_signing", ((2, True),))
|
||||
@pytest.mark.asyncio
|
||||
async def test_strict_signing(pubsubs_fsub, hosts):
|
||||
await connect(hosts[0], hosts[1])
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
await pubsubs_fsub[1].subscribe(TESTING_TOPIC)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert len(pubsubs_fsub[0].seen_messages) == 1
|
||||
assert len(pubsubs_fsub[1].seen_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts, is_strict_signing", ((2, True),))
|
||||
@pytest.mark.asyncio
|
||||
async def test_strict_signing_failed_validation(pubsubs_fsub, hosts, monkeypatch):
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
priv_key = pubsubs_fsub[0].sign_key
|
||||
signature = priv_key.sign(PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString())
|
||||
|
||||
event = asyncio.Event()
|
||||
|
||||
def _is_msg_seen(msg):
|
||||
return False
|
||||
|
||||
# Use router publish to check if `push_msg` succeed.
|
||||
async def router_publish(*args, **kwargs):
|
||||
# The event will only be set if `push_msg` succeed.
|
||||
event.set()
|
||||
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "_is_msg_seen", _is_msg_seen)
|
||||
monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
||||
|
||||
# Test: no signature attached in `msg`
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await asyncio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
# Test: `msg.key` does not match `msg.from_id`
|
||||
msg.key = hosts[1].get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await asyncio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
# Test: invalid signature
|
||||
msg.key = hosts[0].get_public_key().serialize()
|
||||
msg.signature = b"\x12" * 100
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await asyncio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
# Finally, assert the signature indeed will pass validation
|
||||
msg.key = hosts[0].get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await asyncio.sleep(0.01)
|
||||
assert event.is_set()
|
||||
|
|
|
@ -76,7 +76,24 @@ def is_gossipsub():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def p2pds(num_p2pds, is_host_secure, is_gossipsub, unused_tcp_port_factory):
|
||||
def is_pubsub_signing():
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_pubsub_signing_strict():
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def p2pds(
|
||||
num_p2pds,
|
||||
is_host_secure,
|
||||
is_gossipsub,
|
||||
unused_tcp_port_factory,
|
||||
is_pubsub_signing,
|
||||
is_pubsub_signing_strict,
|
||||
):
|
||||
p2pds: Union[Daemon, Exception] = await asyncio.gather(
|
||||
*[
|
||||
make_p2pd(
|
||||
|
@ -84,6 +101,8 @@ async def p2pds(num_p2pds, is_host_secure, is_gossipsub, unused_tcp_port_factory
|
|||
unused_tcp_port_factory(),
|
||||
is_host_secure,
|
||||
is_gossipsub=is_gossipsub,
|
||||
is_pubsub_signing=is_pubsub_signing,
|
||||
is_pubsub_signing_strict=is_pubsub_signing_strict,
|
||||
)
|
||||
for _ in range(num_p2pds)
|
||||
],
|
||||
|
@ -102,13 +121,14 @@ async def p2pds(num_p2pds, is_host_secure, is_gossipsub, unused_tcp_port_factory
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs(num_hosts, hosts, is_gossipsub):
|
||||
def pubsubs(num_hosts, hosts, is_gossipsub, is_pubsub_signing_strict):
|
||||
if is_gossipsub:
|
||||
routers = GossipsubFactory.create_batch(num_hosts, **GOSSIPSUB_PARAMS._asdict())
|
||||
else:
|
||||
routers = FloodsubFactory.create_batch(num_hosts)
|
||||
_pubsubs = tuple(
|
||||
PubsubFactory(host=host, router=router) for host, router in zip(hosts, routers)
|
||||
PubsubFactory(host=host, router=router, strict_signing=is_pubsub_signing_strict)
|
||||
for host, router in zip(hosts, routers)
|
||||
)
|
||||
yield _pubsubs
|
||||
# TODO: Clean up
|
||||
|
|
|
@ -55,6 +55,9 @@ def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) ->
|
|||
assert msg.data == data and msg.from_id == from_peer_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_pubsub_signing, is_pubsub_signing_strict", ((True, True), (False, False))
|
||||
)
|
||||
@pytest.mark.parametrize("is_gossipsub", (True, False))
|
||||
@pytest.mark.parametrize("num_hosts, num_p2pds", ((1, 2),))
|
||||
@pytest.mark.asyncio
|
||||
|
|
Loading…
Reference in New Issue
Block a user