Noise: complete handshake process
TODO - Figure out why `state.rs` is erased at some moment(even handshake is not done). - Refactor - Add tests
This commit is contained in:
parent
8a4ebd4cbb
commit
d0290d2b5a
@ -7,3 +7,16 @@ class NoiseFailure(HandshakeFailure):
|
|||||||
|
|
||||||
class HandshakeHasNotFinished(NoiseFailure):
|
class HandshakeHasNotFinished(NoiseFailure):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSignature(NoiseFailure):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseStateError(NoiseFailure):
|
||||||
|
"""Raised when anything goes wrong in the noise state in `noiseprotocol`
|
||||||
|
package."""
|
||||||
|
|
||||||
|
|
||||||
|
class PeerIDMismatchesPubkey(NoiseFailure):
|
||||||
|
pass
|
||||||
|
56
libp2p/security/noise/messages.py
Normal file
56
libp2p/security/noise/messages.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||||
|
from libp2p.crypto.serialization import deserialize_public_key
|
||||||
|
|
||||||
|
from .pb import noise_pb2 as noise_pb
|
||||||
|
|
||||||
|
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NoiseHandshakePayload:
|
||||||
|
id_pubkey: PublicKey
|
||||||
|
id_sig: bytes
|
||||||
|
early_data: bytes = None
|
||||||
|
|
||||||
|
def serialize(self) -> bytes:
|
||||||
|
msg = noise_pb.NoiseHandshakePayload(
|
||||||
|
identity_key=self.id_pubkey.serialize(), identity_sig=self.id_sig
|
||||||
|
)
|
||||||
|
if self.early_data is not None:
|
||||||
|
msg.data = self.early_data
|
||||||
|
return msg.SerializeToString()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, protobuf_bytes: bytes) -> "NoiseHandshakePayload":
|
||||||
|
msg = noise_pb.NoiseHandshakePayload.FromString(protobuf_bytes)
|
||||||
|
return cls(
|
||||||
|
id_pubkey=deserialize_public_key(msg.identity_key),
|
||||||
|
id_sig=msg.identity_sig,
|
||||||
|
early_data=msg.data if msg.data != b"" else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_data_to_be_signed(noise_static_pubkey: PublicKey) -> bytes:
|
||||||
|
prefix_bytes = SIGNED_DATA_PREFIX.encode("utf-8")
|
||||||
|
return prefix_bytes + noise_static_pubkey.to_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
def make_handshake_payload_sig(
|
||||||
|
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
||||||
|
) -> bytes:
|
||||||
|
data = make_data_to_be_signed(noise_static_pubkey)
|
||||||
|
return id_privkey.sign(data)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_handshake_payload_sig(
|
||||||
|
payload: NoiseHandshakePayload, noise_static_pubkey: PublicKey
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Verify if the signature
|
||||||
|
1. is composed of the data `SIGNED_DATA_PREFIX`++`noise_static_pubkey` and
|
||||||
|
2. signed by the private key corresponding to `id_pubkey`
|
||||||
|
"""
|
||||||
|
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
||||||
|
return payload.id_pubkey.verify(expected_data, payload.id_sig)
|
@ -3,14 +3,25 @@ from abc import ABC, abstractmethod
|
|||||||
from noise.connection import Keypair as NoiseKeypair
|
from noise.connection import Keypair as NoiseKeypair
|
||||||
from noise.connection import NoiseConnection as NoiseState
|
from noise.connection import NoiseConnection as NoiseState
|
||||||
|
|
||||||
|
from libp2p.crypto.ed25519 import Ed25519PublicKey
|
||||||
from libp2p.crypto.keys import PrivateKey
|
from libp2p.crypto.keys import PrivateKey
|
||||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
|
|
||||||
from .connection import NoiseConnection
|
from .connection import NoiseConnection
|
||||||
from .exceptions import HandshakeHasNotFinished
|
from .exceptions import (
|
||||||
|
HandshakeHasNotFinished,
|
||||||
|
InvalidSignature,
|
||||||
|
NoiseStateError,
|
||||||
|
PeerIDMismatchesPubkey,
|
||||||
|
)
|
||||||
from .io import NoiseHandshakeReadWriter
|
from .io import NoiseHandshakeReadWriter
|
||||||
|
from .messages import (
|
||||||
|
NoiseHandshakePayload,
|
||||||
|
make_handshake_payload_sig,
|
||||||
|
verify_handshake_payload_sig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IPattern(ABC):
|
class IPattern(ABC):
|
||||||
@ -30,6 +41,7 @@ class BasePattern(IPattern):
|
|||||||
noise_static_key: PrivateKey
|
noise_static_key: PrivateKey
|
||||||
local_peer: ID
|
local_peer: ID
|
||||||
libp2p_privkey: PrivateKey
|
libp2p_privkey: PrivateKey
|
||||||
|
early_data: bytes
|
||||||
|
|
||||||
def create_noise_state(self) -> NoiseState:
|
def create_noise_state(self) -> NoiseState:
|
||||||
noise_state = NoiseState.from_name(self.protocol_name)
|
noise_state = NoiseState.from_name(self.protocol_name)
|
||||||
@ -38,59 +50,102 @@ class BasePattern(IPattern):
|
|||||||
)
|
)
|
||||||
return noise_state
|
return noise_state
|
||||||
|
|
||||||
|
def make_handshake_payload(self) -> NoiseHandshakePayload:
|
||||||
|
signature = make_handshake_payload_sig(
|
||||||
|
self.libp2p_privkey, self.noise_static_key.get_public_key()
|
||||||
|
)
|
||||||
|
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
|
||||||
|
|
||||||
|
|
||||||
class PatternXX(BasePattern):
|
class PatternXX(BasePattern):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, local_peer: ID, libp2p_privkey: PrivateKey, noise_static_key: PrivateKey
|
self,
|
||||||
|
local_peer: ID,
|
||||||
|
libp2p_privkey: PrivateKey,
|
||||||
|
noise_static_key: PrivateKey,
|
||||||
|
early_data: bytes = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
|
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
|
||||||
self.local_peer = local_peer
|
self.local_peer = local_peer
|
||||||
self.libp2p_privkey = libp2p_privkey
|
self.libp2p_privkey = libp2p_privkey
|
||||||
self.noise_static_key = noise_static_key
|
self.noise_static_key = noise_static_key
|
||||||
|
self.early_data = early_data
|
||||||
|
|
||||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||||
noise_state = self.create_noise_state()
|
noise_state = self.create_noise_state()
|
||||||
noise_state.set_as_responder()
|
noise_state.set_as_responder()
|
||||||
noise_state.start_handshake()
|
noise_state.start_handshake()
|
||||||
|
state = noise_state.noise_protocol.handshake_state
|
||||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||||
# TODO: Parse and save the payload from the other side.
|
|
||||||
_ = await read_writer.read_msg()
|
|
||||||
|
|
||||||
# TODO: Send our payload.
|
# Consume msg#1
|
||||||
our_payload = b"server"
|
await read_writer.read_msg()
|
||||||
await read_writer.write_msg(our_payload)
|
|
||||||
|
|
||||||
# TODO: Parse and save another payload from the other side.
|
# Send msg#2, which should include our handshake payload.
|
||||||
_ = await read_writer.read_msg()
|
our_payload = self.make_handshake_payload()
|
||||||
|
msg_2 = our_payload.serialize()
|
||||||
|
await read_writer.write_msg(msg_2)
|
||||||
|
|
||||||
|
# Receive msg#3
|
||||||
|
msg_3 = await read_writer.read_msg()
|
||||||
|
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
||||||
|
|
||||||
|
if state.rs is None:
|
||||||
|
raise NoiseStateError
|
||||||
|
remote_pubkey = Ed25519PublicKey.from_bytes(state.rs.public_bytes)
|
||||||
|
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||||
|
raise InvalidSignature
|
||||||
|
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||||
|
|
||||||
# TODO: Add a specific exception
|
|
||||||
if not noise_state.handshake_finished:
|
if not noise_state.handshake_finished:
|
||||||
raise HandshakeHasNotFinished(
|
raise HandshakeHasNotFinished(
|
||||||
"handshake done but it is not marked as finished in `noise_state`"
|
"handshake is done but it is not marked as finished in `noise_state`"
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: `remote_peer` should be derived from the messages.
|
return NoiseConnection(
|
||||||
return NoiseConnection(self.local_peer, self.libp2p_privkey, None, conn, False)
|
self.local_peer,
|
||||||
|
self.libp2p_privkey,
|
||||||
|
remote_peer_id_from_pubkey,
|
||||||
|
conn,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
async def handshake_outbound(
|
async def handshake_outbound(
|
||||||
self, conn: IRawConnection, remote_peer: ID
|
self, conn: IRawConnection, remote_peer: ID
|
||||||
) -> ISecureConn:
|
) -> ISecureConn:
|
||||||
noise_state = self.create_noise_state()
|
noise_state = self.create_noise_state()
|
||||||
|
|
||||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||||
noise_state.set_as_initiator()
|
noise_state.set_as_initiator()
|
||||||
noise_state.start_handshake()
|
noise_state.start_handshake()
|
||||||
await read_writer.write_msg(b"")
|
state = noise_state.noise_protocol.handshake_state
|
||||||
|
|
||||||
# TODO: Parse and save the payload from the other side.
|
msg_1 = b""
|
||||||
_ = await read_writer.read_msg()
|
await read_writer.write_msg(msg_1)
|
||||||
|
|
||||||
# TODO: Send our payload.
|
msg_2 = await read_writer.read_msg()
|
||||||
our_payload = b"client"
|
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
||||||
await read_writer.write_msg(our_payload)
|
if state.rs is None:
|
||||||
|
raise NoiseStateError
|
||||||
|
remote_pubkey = Ed25519PublicKey.from_bytes(state.rs.public_bytes)
|
||||||
|
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||||
|
raise InvalidSignature
|
||||||
|
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||||
|
if remote_peer_id_from_pubkey != remote_peer:
|
||||||
|
raise PeerIDMismatchesPubkey(
|
||||||
|
"peer id does not correspond to the received pubkey: "
|
||||||
|
f"remote_peer={remote_peer}, "
|
||||||
|
f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
|
||||||
|
)
|
||||||
|
|
||||||
|
our_payload = self.make_handshake_payload()
|
||||||
|
msg_3 = our_payload.serialize()
|
||||||
|
await read_writer.write_msg(msg_3)
|
||||||
|
|
||||||
# TODO: Add a specific exception
|
|
||||||
if not noise_state.handshake_finished:
|
if not noise_state.handshake_finished:
|
||||||
raise Exception
|
raise HandshakeHasNotFinished(
|
||||||
|
"handshake is done but it is not marked as finished in `noise_state`"
|
||||||
|
)
|
||||||
|
|
||||||
return NoiseConnection(
|
return NoiseConnection(
|
||||||
self.local_peer, self.libp2p_privkey, remote_peer, conn, False
|
self.local_peer, self.libp2p_privkey, remote_peer, conn, False
|
||||||
|
@ -38,7 +38,12 @@ class Transport(ISecureTransport):
|
|||||||
if self.with_noise_pipes:
|
if self.with_noise_pipes:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
return PatternXX(self.local_peer, self.libp2p_privkey, self.noise_privkey)
|
return PatternXX(
|
||||||
|
self.local_peer,
|
||||||
|
self.libp2p_privkey,
|
||||||
|
self.noise_privkey,
|
||||||
|
self.early_data,
|
||||||
|
)
|
||||||
|
|
||||||
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
|
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||||
# TODO: SecureInbound attempts to complete a noise-libp2p handshake initiated
|
# TODO: SecureInbound attempts to complete a noise-libp2p handshake initiated
|
||||||
|
@ -29,6 +29,10 @@ from libp2p.pubsub.gossipsub import GossipSub
|
|||||||
from libp2p.pubsub.pubsub import Pubsub
|
from libp2p.pubsub.pubsub import Pubsub
|
||||||
from libp2p.routing.interfaces import IPeerRouting
|
from libp2p.routing.interfaces import IPeerRouting
|
||||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
|
from libp2p.security.noise.messages import (
|
||||||
|
NoiseHandshakePayload,
|
||||||
|
make_handshake_payload_sig,
|
||||||
|
)
|
||||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||||
import libp2p.security.secio.transport as secio
|
import libp2p.security.secio.transport as secio
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
@ -73,6 +77,17 @@ def noise_static_key_factory() -> PrivateKey:
|
|||||||
return create_ed25519_key_pair().private_key
|
return create_ed25519_key_pair().private_key
|
||||||
|
|
||||||
|
|
||||||
|
def noise_handshake_payload_factory() -> NoiseHandshakePayload:
|
||||||
|
libp2p_keypair = create_secp256k1_key_pair()
|
||||||
|
noise_static_privkey = noise_static_key_factory()
|
||||||
|
return NoiseHandshakePayload(
|
||||||
|
libp2p_keypair.public_key,
|
||||||
|
make_handshake_payload_sig(
|
||||||
|
libp2p_keypair.private_key, noise_static_privkey.get_public_key()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def noise_transport_factory() -> NoiseTransport:
|
def noise_transport_factory() -> NoiseTransport:
|
||||||
return NoiseTransport(
|
return NoiseTransport(
|
||||||
libp2p_keypair=create_secp256k1_key_pair(),
|
libp2p_keypair=create_secp256k1_key_pair(),
|
||||||
@ -118,7 +133,7 @@ async def noise_conn_factory(
|
|||||||
async def upgrade_local_conn() -> None:
|
async def upgrade_local_conn() -> None:
|
||||||
nonlocal local_secure_conn
|
nonlocal local_secure_conn
|
||||||
local_secure_conn = await local_transport.secure_outbound(
|
local_secure_conn = await local_transport.secure_outbound(
|
||||||
local_conn, local_transport.local_peer
|
local_conn, remote_transport.local_peer
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upgrade_remote_conn() -> None:
|
async def upgrade_remote_conn() -> None:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from libp2p.tools.factories import noise_conn_factory
|
from libp2p.security.noise.messages import NoiseHandshakePayload
|
||||||
|
from libp2p.tools.factories import noise_conn_factory, noise_handshake_payload_factory
|
||||||
|
|
||||||
DATA = b"testing_123"
|
DATA = b"testing_123"
|
||||||
|
|
||||||
@ -18,3 +19,10 @@ async def test_noise_connection(nursery):
|
|||||||
await local_conn.write(DATA)
|
await local_conn.write(DATA)
|
||||||
read_data = await remote_conn.read(len(DATA))
|
read_data = await remote_conn.read(len(DATA))
|
||||||
assert read_data == DATA
|
assert read_data == DATA
|
||||||
|
|
||||||
|
|
||||||
|
def test_noise_handshake_payload():
|
||||||
|
payload = noise_handshake_payload_factory()
|
||||||
|
payload_serialized = payload.serialize()
|
||||||
|
payload_deserialized = NoiseHandshakePayload.deserialize(payload_serialized)
|
||||||
|
assert payload == payload_deserialized
|
||||||
|
Loading…
x
Reference in New Issue
Block a user