Noise: complete handshake process

TODO
- Figure out why `state.rs` is erased at some moment(even handshake
is not done).
- Refactor
- Add tests
This commit is contained in:
mhchia 2020-02-16 00:42:49 +08:00
parent 8a4ebd4cbb
commit d0290d2b5a
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
6 changed files with 176 additions and 24 deletions

View File

@ -7,3 +7,16 @@ class NoiseFailure(HandshakeFailure):
class HandshakeHasNotFinished(NoiseFailure): class HandshakeHasNotFinished(NoiseFailure):
pass pass
class InvalidSignature(NoiseFailure):
pass
class NoiseStateError(NoiseFailure):
"""Raised when anything goes wrong in the noise state in `noiseprotocol`
package."""
class PeerIDMismatchesPubkey(NoiseFailure):
pass

View File

@ -0,0 +1,56 @@
from dataclasses import dataclass
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.serialization import deserialize_public_key
from .pb import noise_pb2 as noise_pb
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
@dataclass
class NoiseHandshakePayload:
id_pubkey: PublicKey
id_sig: bytes
early_data: bytes = None
def serialize(self) -> bytes:
msg = noise_pb.NoiseHandshakePayload(
identity_key=self.id_pubkey.serialize(), identity_sig=self.id_sig
)
if self.early_data is not None:
msg.data = self.early_data
return msg.SerializeToString()
@classmethod
def deserialize(cls, protobuf_bytes: bytes) -> "NoiseHandshakePayload":
msg = noise_pb.NoiseHandshakePayload.FromString(protobuf_bytes)
return cls(
id_pubkey=deserialize_public_key(msg.identity_key),
id_sig=msg.identity_sig,
early_data=msg.data if msg.data != b"" else None,
)
def make_data_to_be_signed(noise_static_pubkey: PublicKey) -> bytes:
prefix_bytes = SIGNED_DATA_PREFIX.encode("utf-8")
return prefix_bytes + noise_static_pubkey.to_bytes()
def make_handshake_payload_sig(
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
) -> bytes:
data = make_data_to_be_signed(noise_static_pubkey)
return id_privkey.sign(data)
def verify_handshake_payload_sig(
payload: NoiseHandshakePayload, noise_static_pubkey: PublicKey
) -> bool:
"""
Verify if the signature
1. is composed of the data `SIGNED_DATA_PREFIX`++`noise_static_pubkey` and
2. signed by the private key corresponding to `id_pubkey`
"""
expected_data = make_data_to_be_signed(noise_static_pubkey)
return payload.id_pubkey.verify(expected_data, payload.id_sig)

View File

@ -3,14 +3,25 @@ from abc import ABC, abstractmethod
from noise.connection import Keypair as NoiseKeypair from noise.connection import Keypair as NoiseKeypair
from noise.connection import NoiseConnection as NoiseState from noise.connection import NoiseConnection as NoiseState
from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey from libp2p.crypto.keys import PrivateKey
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from .connection import NoiseConnection from .connection import NoiseConnection
from .exceptions import HandshakeHasNotFinished from .exceptions import (
HandshakeHasNotFinished,
InvalidSignature,
NoiseStateError,
PeerIDMismatchesPubkey,
)
from .io import NoiseHandshakeReadWriter from .io import NoiseHandshakeReadWriter
from .messages import (
NoiseHandshakePayload,
make_handshake_payload_sig,
verify_handshake_payload_sig,
)
class IPattern(ABC): class IPattern(ABC):
@ -30,6 +41,7 @@ class BasePattern(IPattern):
noise_static_key: PrivateKey noise_static_key: PrivateKey
local_peer: ID local_peer: ID
libp2p_privkey: PrivateKey libp2p_privkey: PrivateKey
early_data: bytes
def create_noise_state(self) -> NoiseState: def create_noise_state(self) -> NoiseState:
noise_state = NoiseState.from_name(self.protocol_name) noise_state = NoiseState.from_name(self.protocol_name)
@ -38,59 +50,102 @@ class BasePattern(IPattern):
) )
return noise_state return noise_state
def make_handshake_payload(self) -> NoiseHandshakePayload:
signature = make_handshake_payload_sig(
self.libp2p_privkey, self.noise_static_key.get_public_key()
)
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
class PatternXX(BasePattern): class PatternXX(BasePattern):
def __init__( def __init__(
self, local_peer: ID, libp2p_privkey: PrivateKey, noise_static_key: PrivateKey self,
local_peer: ID,
libp2p_privkey: PrivateKey,
noise_static_key: PrivateKey,
early_data: bytes = None,
) -> None: ) -> None:
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256" self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
self.local_peer = local_peer self.local_peer = local_peer
self.libp2p_privkey = libp2p_privkey self.libp2p_privkey = libp2p_privkey
self.noise_static_key = noise_static_key self.noise_static_key = noise_static_key
self.early_data = early_data
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
noise_state = self.create_noise_state() noise_state = self.create_noise_state()
noise_state.set_as_responder() noise_state.set_as_responder()
noise_state.start_handshake() noise_state.start_handshake()
state = noise_state.noise_protocol.handshake_state
read_writer = NoiseHandshakeReadWriter(conn, noise_state) read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# TODO: Parse and save the payload from the other side.
_ = await read_writer.read_msg()
# TODO: Send our payload. # Consume msg#1
our_payload = b"server" await read_writer.read_msg()
await read_writer.write_msg(our_payload)
# TODO: Parse and save another payload from the other side. # Send msg#2, which should include our handshake payload.
_ = await read_writer.read_msg() our_payload = self.make_handshake_payload()
msg_2 = our_payload.serialize()
await read_writer.write_msg(msg_2)
# Receive msg#3
msg_3 = await read_writer.read_msg()
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
if state.rs is None:
raise NoiseStateError
remote_pubkey = Ed25519PublicKey.from_bytes(state.rs.public_bytes)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
raise InvalidSignature
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
# TODO: Add a specific exception
if not noise_state.handshake_finished: if not noise_state.handshake_finished:
raise HandshakeHasNotFinished( raise HandshakeHasNotFinished(
"handshake done but it is not marked as finished in `noise_state`" "handshake is done but it is not marked as finished in `noise_state`"
) )
# FIXME: `remote_peer` should be derived from the messages. return NoiseConnection(
return NoiseConnection(self.local_peer, self.libp2p_privkey, None, conn, False) self.local_peer,
self.libp2p_privkey,
remote_peer_id_from_pubkey,
conn,
False,
)
async def handshake_outbound( async def handshake_outbound(
self, conn: IRawConnection, remote_peer: ID self, conn: IRawConnection, remote_peer: ID
) -> ISecureConn: ) -> ISecureConn:
noise_state = self.create_noise_state() noise_state = self.create_noise_state()
read_writer = NoiseHandshakeReadWriter(conn, noise_state) read_writer = NoiseHandshakeReadWriter(conn, noise_state)
noise_state.set_as_initiator() noise_state.set_as_initiator()
noise_state.start_handshake() noise_state.start_handshake()
await read_writer.write_msg(b"") state = noise_state.noise_protocol.handshake_state
# TODO: Parse and save the payload from the other side. msg_1 = b""
_ = await read_writer.read_msg() await read_writer.write_msg(msg_1)
# TODO: Send our payload. msg_2 = await read_writer.read_msg()
our_payload = b"client" peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
await read_writer.write_msg(our_payload) if state.rs is None:
raise NoiseStateError
remote_pubkey = Ed25519PublicKey.from_bytes(state.rs.public_bytes)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
raise InvalidSignature
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey(
"peer id does not correspond to the received pubkey: "
f"remote_peer={remote_peer}, "
f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
)
our_payload = self.make_handshake_payload()
msg_3 = our_payload.serialize()
await read_writer.write_msg(msg_3)
# TODO: Add a specific exception
if not noise_state.handshake_finished: if not noise_state.handshake_finished:
raise Exception raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
return NoiseConnection( return NoiseConnection(
self.local_peer, self.libp2p_privkey, remote_peer, conn, False self.local_peer, self.libp2p_privkey, remote_peer, conn, False

View File

@ -38,7 +38,12 @@ class Transport(ISecureTransport):
if self.with_noise_pipes: if self.with_noise_pipes:
raise NotImplementedError raise NotImplementedError
else: else:
return PatternXX(self.local_peer, self.libp2p_privkey, self.noise_privkey) return PatternXX(
self.local_peer,
self.libp2p_privkey,
self.noise_privkey,
self.early_data,
)
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn: async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
# TODO: SecureInbound attempts to complete a noise-libp2p handshake initiated # TODO: SecureInbound attempts to complete a noise-libp2p handshake initiated

View File

@ -29,6 +29,10 @@ from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.interfaces import IPeerRouting
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.security.noise.messages import (
NoiseHandshakePayload,
make_handshake_payload_sig,
)
from libp2p.security.noise.transport import Transport as NoiseTransport from libp2p.security.noise.transport import Transport as NoiseTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
@ -73,6 +77,17 @@ def noise_static_key_factory() -> PrivateKey:
return create_ed25519_key_pair().private_key return create_ed25519_key_pair().private_key
def noise_handshake_payload_factory() -> NoiseHandshakePayload:
libp2p_keypair = create_secp256k1_key_pair()
noise_static_privkey = noise_static_key_factory()
return NoiseHandshakePayload(
libp2p_keypair.public_key,
make_handshake_payload_sig(
libp2p_keypair.private_key, noise_static_privkey.get_public_key()
),
)
def noise_transport_factory() -> NoiseTransport: def noise_transport_factory() -> NoiseTransport:
return NoiseTransport( return NoiseTransport(
libp2p_keypair=create_secp256k1_key_pair(), libp2p_keypair=create_secp256k1_key_pair(),
@ -118,7 +133,7 @@ async def noise_conn_factory(
async def upgrade_local_conn() -> None: async def upgrade_local_conn() -> None:
nonlocal local_secure_conn nonlocal local_secure_conn
local_secure_conn = await local_transport.secure_outbound( local_secure_conn = await local_transport.secure_outbound(
local_conn, local_transport.local_peer local_conn, remote_transport.local_peer
) )
async def upgrade_remote_conn() -> None: async def upgrade_remote_conn() -> None:

View File

@ -1,6 +1,7 @@
import pytest import pytest
from libp2p.tools.factories import noise_conn_factory from libp2p.security.noise.messages import NoiseHandshakePayload
from libp2p.tools.factories import noise_conn_factory, noise_handshake_payload_factory
DATA = b"testing_123" DATA = b"testing_123"
@ -18,3 +19,10 @@ async def test_noise_connection(nursery):
await local_conn.write(DATA) await local_conn.write(DATA)
read_data = await remote_conn.read(len(DATA)) read_data = await remote_conn.read(len(DATA))
assert read_data == DATA assert read_data == DATA
def test_noise_handshake_payload():
payload = noise_handshake_payload_factory()
payload_serialized = payload.serialize()
payload_deserialized = NoiseHandshakePayload.deserialize(payload_serialized)
assert payload == payload_deserialized