diff --git a/libp2p/security/noise/exceptions.py b/libp2p/security/noise/exceptions.py index 5e6040e..85cf3f2 100644 --- a/libp2p/security/noise/exceptions.py +++ b/libp2p/security/noise/exceptions.py @@ -7,3 +7,16 @@ class NoiseFailure(HandshakeFailure): class HandshakeHasNotFinished(NoiseFailure): pass + + +class InvalidSignature(NoiseFailure): + pass + + +class NoiseStateError(NoiseFailure): + """Raised when anything goes wrong in the noise state in `noiseprotocol` + package.""" + + +class PeerIDMismatchesPubkey(NoiseFailure): + pass diff --git a/libp2p/security/noise/messages.py b/libp2p/security/noise/messages.py new file mode 100644 index 0000000..feb2766 --- /dev/null +++ b/libp2p/security/noise/messages.py @@ -0,0 +1,56 @@ +from dataclasses import dataclass + +from libp2p.crypto.keys import PrivateKey, PublicKey +from libp2p.crypto.serialization import deserialize_public_key + +from .pb import noise_pb2 as noise_pb + +SIGNED_DATA_PREFIX = "noise-libp2p-static-key:" + + +@dataclass +class NoiseHandshakePayload: + id_pubkey: PublicKey + id_sig: bytes + early_data: bytes = None + + def serialize(self) -> bytes: + msg = noise_pb.NoiseHandshakePayload( + identity_key=self.id_pubkey.serialize(), identity_sig=self.id_sig + ) + if self.early_data is not None: + msg.data = self.early_data + return msg.SerializeToString() + + @classmethod + def deserialize(cls, protobuf_bytes: bytes) -> "NoiseHandshakePayload": + msg = noise_pb.NoiseHandshakePayload.FromString(protobuf_bytes) + return cls( + id_pubkey=deserialize_public_key(msg.identity_key), + id_sig=msg.identity_sig, + early_data=msg.data if msg.data != b"" else None, + ) + + +def make_data_to_be_signed(noise_static_pubkey: PublicKey) -> bytes: + prefix_bytes = SIGNED_DATA_PREFIX.encode("utf-8") + return prefix_bytes + noise_static_pubkey.to_bytes() + + +def make_handshake_payload_sig( + id_privkey: PrivateKey, noise_static_pubkey: PublicKey +) -> bytes: + data = make_data_to_be_signed(noise_static_pubkey) + return id_privkey.sign(data) + + +def verify_handshake_payload_sig( + payload: NoiseHandshakePayload, noise_static_pubkey: PublicKey +) -> bool: + """ + Verify if the signature + 1. is composed of the data `SIGNED_DATA_PREFIX`++`noise_static_pubkey` and + 2. signed by the private key corresponding to `id_pubkey` + """ + expected_data = make_data_to_be_signed(noise_static_pubkey) + return payload.id_pubkey.verify(expected_data, payload.id_sig) diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index a05e807..4590b73 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -3,14 +3,25 @@ from abc import ABC, abstractmethod from noise.connection import Keypair as NoiseKeypair from noise.connection import NoiseConnection as NoiseState +from libp2p.crypto.ed25519 import Ed25519PublicKey from libp2p.crypto.keys import PrivateKey from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from .connection import NoiseConnection -from .exceptions import HandshakeHasNotFinished +from .exceptions import ( + HandshakeHasNotFinished, + InvalidSignature, + NoiseStateError, + PeerIDMismatchesPubkey, +) from .io import NoiseHandshakeReadWriter +from .messages import ( + NoiseHandshakePayload, + make_handshake_payload_sig, + verify_handshake_payload_sig, +) class IPattern(ABC): @@ -30,6 +41,7 @@ class BasePattern(IPattern): noise_static_key: PrivateKey local_peer: ID libp2p_privkey: PrivateKey + early_data: bytes def create_noise_state(self) -> NoiseState: noise_state = NoiseState.from_name(self.protocol_name) @@ -38,59 +50,102 @@ class BasePattern(IPattern): ) return noise_state + def make_handshake_payload(self) -> NoiseHandshakePayload: + signature = make_handshake_payload_sig( + self.libp2p_privkey, self.noise_static_key.get_public_key() + ) + return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature) + class PatternXX(BasePattern): def __init__( - self, local_peer: ID, libp2p_privkey: PrivateKey, noise_static_key: PrivateKey + self, + local_peer: ID, + libp2p_privkey: PrivateKey, + noise_static_key: PrivateKey, + early_data: bytes = None, ) -> None: self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256" self.local_peer = local_peer self.libp2p_privkey = libp2p_privkey self.noise_static_key = noise_static_key + self.early_data = early_data async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: noise_state = self.create_noise_state() noise_state.set_as_responder() noise_state.start_handshake() + state = noise_state.noise_protocol.handshake_state read_writer = NoiseHandshakeReadWriter(conn, noise_state) - # TODO: Parse and save the payload from the other side. - _ = await read_writer.read_msg() - # TODO: Send our payload. - our_payload = b"server" - await read_writer.write_msg(our_payload) + # Consume msg#1 + await read_writer.read_msg() - # TODO: Parse and save another payload from the other side. - _ = await read_writer.read_msg() + # Send msg#2, which should include our handshake payload. + our_payload = self.make_handshake_payload() + msg_2 = our_payload.serialize() + await read_writer.write_msg(msg_2) + + # Receive msg#3 + msg_3 = await read_writer.read_msg() + peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3) + + if state.rs is None: + raise NoiseStateError + remote_pubkey = Ed25519PublicKey.from_bytes(state.rs.public_bytes) + if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey): + raise InvalidSignature + remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey) - # TODO: Add a specific exception if not noise_state.handshake_finished: raise HandshakeHasNotFinished( - "handshake done but it is not marked as finished in `noise_state`" + "handshake is done but it is not marked as finished in `noise_state`" ) - # FIXME: `remote_peer` should be derived from the messages. - return NoiseConnection(self.local_peer, self.libp2p_privkey, None, conn, False) + return NoiseConnection( + self.local_peer, + self.libp2p_privkey, + remote_peer_id_from_pubkey, + conn, + False, + ) async def handshake_outbound( self, conn: IRawConnection, remote_peer: ID ) -> ISecureConn: noise_state = self.create_noise_state() + read_writer = NoiseHandshakeReadWriter(conn, noise_state) noise_state.set_as_initiator() noise_state.start_handshake() - await read_writer.write_msg(b"") + state = noise_state.noise_protocol.handshake_state - # TODO: Parse and save the payload from the other side. - _ = await read_writer.read_msg() + msg_1 = b"" + await read_writer.write_msg(msg_1) - # TODO: Send our payload. - our_payload = b"client" - await read_writer.write_msg(our_payload) + msg_2 = await read_writer.read_msg() + peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2) + if state.rs is None: + raise NoiseStateError + remote_pubkey = Ed25519PublicKey.from_bytes(state.rs.public_bytes) + if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey): + raise InvalidSignature + remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey) + if remote_peer_id_from_pubkey != remote_peer: + raise PeerIDMismatchesPubkey( + "peer id does not correspond to the received pubkey: " + f"remote_peer={remote_peer}, " + f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}" + ) + + our_payload = self.make_handshake_payload() + msg_3 = our_payload.serialize() + await read_writer.write_msg(msg_3) - # TODO: Add a specific exception if not noise_state.handshake_finished: - raise Exception + raise HandshakeHasNotFinished( + "handshake is done but it is not marked as finished in `noise_state`" + ) return NoiseConnection( self.local_peer, self.libp2p_privkey, remote_peer, conn, False diff --git a/libp2p/security/noise/transport.py b/libp2p/security/noise/transport.py index 37e15e8..5ee3ba5 100644 --- a/libp2p/security/noise/transport.py +++ b/libp2p/security/noise/transport.py @@ -38,7 +38,12 @@ class Transport(ISecureTransport): if self.with_noise_pipes: raise NotImplementedError else: - return PatternXX(self.local_peer, self.libp2p_privkey, self.noise_privkey) + return PatternXX( + self.local_peer, + self.libp2p_privkey, + self.noise_privkey, + self.early_data, + ) async def secure_inbound(self, conn: IRawConnection) -> ISecureConn: # TODO: SecureInbound attempts to complete a noise-libp2p handshake initiated diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 75a37a3..7fad836 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -29,6 +29,10 @@ from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub from libp2p.routing.interfaces import IPeerRouting from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport +from libp2p.security.noise.messages import ( + NoiseHandshakePayload, + make_handshake_payload_sig, +) from libp2p.security.noise.transport import Transport as NoiseTransport import libp2p.security.secio.transport as secio from libp2p.security.secure_conn_interface import ISecureConn @@ -73,6 +77,17 @@ def noise_static_key_factory() -> PrivateKey: return create_ed25519_key_pair().private_key +def noise_handshake_payload_factory() -> NoiseHandshakePayload: + libp2p_keypair = create_secp256k1_key_pair() + noise_static_privkey = noise_static_key_factory() + return NoiseHandshakePayload( + libp2p_keypair.public_key, + make_handshake_payload_sig( + libp2p_keypair.private_key, noise_static_privkey.get_public_key() + ), + ) + + def noise_transport_factory() -> NoiseTransport: return NoiseTransport( libp2p_keypair=create_secp256k1_key_pair(), @@ -118,7 +133,7 @@ async def noise_conn_factory( async def upgrade_local_conn() -> None: nonlocal local_secure_conn local_secure_conn = await local_transport.secure_outbound( - local_conn, local_transport.local_peer + local_conn, remote_transport.local_peer ) async def upgrade_remote_conn() -> None: diff --git a/tests/security/noise/test_noise.py b/tests/security/noise/test_noise.py index c60f83c..1c5eebb 100644 --- a/tests/security/noise/test_noise.py +++ b/tests/security/noise/test_noise.py @@ -1,6 +1,7 @@ import pytest -from libp2p.tools.factories import noise_conn_factory +from libp2p.security.noise.messages import NoiseHandshakePayload +from libp2p.tools.factories import noise_conn_factory, noise_handshake_payload_factory DATA = b"testing_123" @@ -18,3 +19,10 @@ async def test_noise_connection(nursery): await local_conn.write(DATA) read_data = await remote_conn.read(len(DATA)) assert read_data == DATA + + +def test_noise_handshake_payload(): + payload = noise_handshake_payload_factory() + payload_serialized = payload.serialize() + payload_deserialized = NoiseHandshakePayload.deserialize(payload_serialized) + assert payload == payload_deserialized