Support read/write of noise msg and payload
This commit is contained in:
parent
1f881e0464
commit
8a4ebd4cbb
9
libp2p/security/noise/exceptions.py
Normal file
9
libp2p/security/noise/exceptions.py
Normal file
@ -0,0 +1,9 @@
|
||||
from libp2p.security.exceptions import HandshakeFailure
|
||||
|
||||
|
||||
class NoiseFailure(HandshakeFailure):
|
||||
pass
|
||||
|
||||
|
||||
class HandshakeHasNotFinished(NoiseFailure):
|
||||
pass
|
108
libp2p/security/noise/io.py
Normal file
108
libp2p/security/noise/io.py
Normal file
@ -0,0 +1,108 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import cast
|
||||
|
||||
from noise.connection import NoiseConnection as NoiseState
|
||||
|
||||
from libp2p.io.abc import ReadWriter
|
||||
from libp2p.io.utils import read_exactly
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
|
||||
SIZE_NOISE_MESSAGE_LEN = 2
|
||||
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
||||
SIZE_NOISE_MESSAGE_BODY_LEN = 2
|
||||
MAX_NOISE_MESSAGE_BODY_LEN = MAX_NOISE_MESSAGE_LEN - SIZE_NOISE_MESSAGE_BODY_LEN
|
||||
BYTE_ORDER = "big"
|
||||
|
||||
# | Noise packet |
|
||||
# < 2 bytes -><- 65535 ->
|
||||
# | noise msg len | noise msg |
|
||||
# | body len | body | padding |
|
||||
# <-2 bytes-><- max=65533 bytes ->
|
||||
|
||||
|
||||
def encode_data(data: bytes, size_len: int) -> bytes:
|
||||
len_data = len(data)
|
||||
try:
|
||||
len_bytes = len_data.to_bytes(size_len, BYTE_ORDER)
|
||||
except OverflowError as e:
|
||||
raise ValueError from e
|
||||
return len_bytes + data
|
||||
|
||||
|
||||
class MsgReader(ABC):
|
||||
@abstractmethod
|
||||
async def read_msg(self) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
class MsgWriter(ABC):
|
||||
@abstractmethod
|
||||
async def write_msg(self, msg: bytes) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MsgReadWriter(MsgReader, MsgWriter):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Add comments
|
||||
class NoisePacketReadWriter(MsgReadWriter):
|
||||
"""Encode and decode the low level noise messages."""
|
||||
|
||||
read_writer: ReadWriter
|
||||
|
||||
def __init__(self, read_writer: ReadWriter) -> None:
|
||||
self.read_writer = read_writer
|
||||
|
||||
async def read_msg(self) -> bytes:
|
||||
len_bytes = await read_exactly(self.read_writer, SIZE_NOISE_MESSAGE_LEN)
|
||||
len_int = int.from_bytes(len_bytes, BYTE_ORDER)
|
||||
return await read_exactly(self.read_writer, len_int)
|
||||
|
||||
async def write_msg(self, msg: bytes) -> None:
|
||||
encoded_data = encode_data(msg, SIZE_NOISE_MESSAGE_LEN)
|
||||
await self.read_writer.write(encoded_data)
|
||||
|
||||
|
||||
# TODO: Add comments
|
||||
def encode_msg_body(msg_body: bytes) -> bytes:
|
||||
encoded_msg_body = encode_data(msg_body, SIZE_NOISE_MESSAGE_BODY_LEN)
|
||||
if len(encoded_msg_body) > MAX_NOISE_MESSAGE_BODY_LEN:
|
||||
raise ValueError(
|
||||
f"msg_body is too long: {len(msg_body)}, "
|
||||
f"maximum={MAX_NOISE_MESSAGE_BODY_LEN}"
|
||||
)
|
||||
# NOTE: Improvements:
|
||||
# 1. Senders *SHOULD* use a source of random data to populate the padding field.
|
||||
# 2. and *may* use any length of padding that does not cause the total length of
|
||||
# the Noise message to exceed 65535 bytes.
|
||||
# Ref: https://github.com/libp2p/specs/tree/master/noise#encrypted-payloads
|
||||
return encoded_msg_body # + padding
|
||||
|
||||
|
||||
def decode_msg_body(noise_msg: bytes) -> bytes:
|
||||
len_body = int.from_bytes(noise_msg[:SIZE_NOISE_MESSAGE_BODY_LEN], BYTE_ORDER)
|
||||
# Just ignore the padding
|
||||
return noise_msg[
|
||||
SIZE_NOISE_MESSAGE_BODY_LEN : (SIZE_NOISE_MESSAGE_BODY_LEN + len_body)
|
||||
]
|
||||
|
||||
|
||||
# TODO: Add comments
|
||||
class NoiseHandshakeReadWriter(MsgReadWriter):
|
||||
read_writer: MsgReadWriter
|
||||
noise_state: NoiseState
|
||||
|
||||
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
|
||||
self.read_writer = NoisePacketReadWriter(cast(ReadWriter, conn))
|
||||
self.noise_state = noise_state
|
||||
|
||||
async def write_msg(self, data: bytes) -> None:
|
||||
noise_msg = encode_msg_body(data)
|
||||
data_encrypted = self.noise_state.write_message(noise_msg)
|
||||
await self.read_writer.write_msg(data_encrypted)
|
||||
|
||||
async def read_msg(self) -> bytes:
|
||||
noise_msg_encrypted = await self.read_writer.read_msg()
|
||||
noise_msg = self.noise_state.read_message(noise_msg_encrypted)
|
||||
return decode_msg_body(noise_msg)
|
@ -9,24 +9,8 @@ from libp2p.peer.id import ID
|
||||
from libp2p.security.secure_conn_interface import ISecureConn
|
||||
|
||||
from .connection import NoiseConnection
|
||||
|
||||
# FIXME: Choose a serious bound number.
|
||||
NUM_BYTES_TO_READ = 2048
|
||||
|
||||
|
||||
# TODO: Merged into `BasePattern`?
|
||||
class PreHandshakeConnection:
|
||||
conn: IRawConnection
|
||||
|
||||
def __init__(self, conn: IRawConnection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
async def write_msg(self, data: bytes) -> None:
|
||||
# TODO:
|
||||
await self.conn.write(data)
|
||||
|
||||
async def read_msg(self) -> bytes:
|
||||
return await self.conn.read(NUM_BYTES_TO_READ)
|
||||
from .exceptions import HandshakeHasNotFinished
|
||||
from .io import NoiseHandshakeReadWriter
|
||||
|
||||
|
||||
class IPattern(ABC):
|
||||
@ -66,25 +50,24 @@ class PatternXX(BasePattern):
|
||||
|
||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||
noise_state = self.create_noise_state()
|
||||
handshake_conn = PreHandshakeConnection(conn)
|
||||
noise_state.set_as_responder()
|
||||
noise_state.start_handshake()
|
||||
msg_0_encrypted = await handshake_conn.read_msg()
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
# TODO: Parse and save the payload from the other side.
|
||||
_ = noise_state.read_message(msg_0_encrypted)
|
||||
_ = await read_writer.read_msg()
|
||||
|
||||
# TODO: Send our payload.
|
||||
our_payload = b"server"
|
||||
msg_1_encrypted = noise_state.write_message(our_payload)
|
||||
await handshake_conn.write_msg(msg_1_encrypted)
|
||||
await read_writer.write_msg(our_payload)
|
||||
|
||||
msg_2_encrypted = await handshake_conn.read_msg()
|
||||
# TODO: Parse and save another payload from the other side.
|
||||
_ = noise_state.read_message(msg_2_encrypted)
|
||||
_ = await read_writer.read_msg()
|
||||
|
||||
# TODO: Add a specific exception
|
||||
if not noise_state.handshake_finished:
|
||||
raise Exception
|
||||
raise HandshakeHasNotFinished(
|
||||
"handshake done but it is not marked as finished in `noise_state`"
|
||||
)
|
||||
|
||||
# FIXME: `remote_peer` should be derived from the messages.
|
||||
return NoiseConnection(self.local_peer, self.libp2p_privkey, None, conn, False)
|
||||
@ -93,19 +76,17 @@ class PatternXX(BasePattern):
|
||||
self, conn: IRawConnection, remote_peer: ID
|
||||
) -> ISecureConn:
|
||||
noise_state = self.create_noise_state()
|
||||
handshake_conn = PreHandshakeConnection(conn)
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
noise_state.set_as_initiator()
|
||||
noise_state.start_handshake()
|
||||
msg_0 = noise_state.write_message()
|
||||
await handshake_conn.write_msg(msg_0)
|
||||
msg_1_encrypted = await handshake_conn.read_msg()
|
||||
await read_writer.write_msg(b"")
|
||||
|
||||
# TODO: Parse and save the payload from the other side.
|
||||
_ = noise_state.read_message(msg_1_encrypted)
|
||||
_ = await read_writer.read_msg()
|
||||
|
||||
# TODO: Send our payload.
|
||||
our_payload = b"client"
|
||||
msg_2_encrypted = noise_state.write_message(our_payload)
|
||||
await handshake_conn.write_msg(msg_2_encrypted)
|
||||
await read_writer.write_msg(our_payload)
|
||||
|
||||
# TODO: Add a specific exception
|
||||
if not noise_state.handshake_finished:
|
||||
|
27
tests/security/noise/test_msg_read_writer.py
Normal file
27
tests/security/noise/test_msg_read_writer.py
Normal file
@ -0,0 +1,27 @@
|
||||
import pytest
|
||||
|
||||
from libp2p.security.noise.io import MAX_NOISE_MESSAGE_LEN, NoisePacketReadWriter
|
||||
from libp2p.tools.factories import raw_conn_factory
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"noise_msg",
|
||||
(b"", b"data", pytest.param(b"A" * MAX_NOISE_MESSAGE_LEN, id="maximum length")),
|
||||
)
|
||||
@pytest.mark.trio
|
||||
async def test_noise_msg_read_write_round_trip(nursery, noise_msg):
|
||||
async with raw_conn_factory(nursery) as conns:
|
||||
reader, writer = (
|
||||
NoisePacketReadWriter(conns[0]),
|
||||
NoisePacketReadWriter(conns[1]),
|
||||
)
|
||||
await writer.write_msg(noise_msg)
|
||||
assert (await reader.read_msg()) == noise_msg
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_noise_msg_write_too_long(nursery):
|
||||
async with raw_conn_factory(nursery) as conns:
|
||||
writer = NoisePacketReadWriter(conns[0])
|
||||
with pytest.raises(ValueError):
|
||||
await writer.write_msg(b"1" * (MAX_NOISE_MESSAGE_LEN + 1))
|
Loading…
x
Reference in New Issue
Block a user