diff --git a/libp2p/security/noise/exceptions.py b/libp2p/security/noise/exceptions.py new file mode 100644 index 0000000..5e6040e --- /dev/null +++ b/libp2p/security/noise/exceptions.py @@ -0,0 +1,9 @@ +from libp2p.security.exceptions import HandshakeFailure + + +class NoiseFailure(HandshakeFailure): + pass + + +class HandshakeHasNotFinished(NoiseFailure): + pass diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py new file mode 100644 index 0000000..6a6f7be --- /dev/null +++ b/libp2p/security/noise/io.py @@ -0,0 +1,108 @@ +from abc import ABC, abstractmethod +from typing import cast + +from noise.connection import NoiseConnection as NoiseState + +from libp2p.io.abc import ReadWriter +from libp2p.io.utils import read_exactly +from libp2p.network.connection.raw_connection_interface import IRawConnection + +SIZE_NOISE_MESSAGE_LEN = 2 +MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 +SIZE_NOISE_MESSAGE_BODY_LEN = 2 +MAX_NOISE_MESSAGE_BODY_LEN = MAX_NOISE_MESSAGE_LEN - SIZE_NOISE_MESSAGE_BODY_LEN +BYTE_ORDER = "big" + +# | Noise packet | +# < 2 bytes -><- 65535 -> +# | noise msg len | noise msg | +# | body len | body | padding | +# <-2 bytes-><- max=65533 bytes -> + + +def encode_data(data: bytes, size_len: int) -> bytes: + len_data = len(data) + try: + len_bytes = len_data.to_bytes(size_len, BYTE_ORDER) + except OverflowError as e: + raise ValueError from e + return len_bytes + data + + +class MsgReader(ABC): + @abstractmethod + async def read_msg(self) -> bytes: + ... + + +class MsgWriter(ABC): + @abstractmethod + async def write_msg(self, msg: bytes) -> None: + ... + + +class MsgReadWriter(MsgReader, MsgWriter): + pass + + +# TODO: Add comments +class NoisePacketReadWriter(MsgReadWriter): + """Encode and decode the low level noise messages.""" + + read_writer: ReadWriter + + def __init__(self, read_writer: ReadWriter) -> None: + self.read_writer = read_writer + + async def read_msg(self) -> bytes: + len_bytes = await read_exactly(self.read_writer, SIZE_NOISE_MESSAGE_LEN) + len_int = int.from_bytes(len_bytes, BYTE_ORDER) + return await read_exactly(self.read_writer, len_int) + + async def write_msg(self, msg: bytes) -> None: + encoded_data = encode_data(msg, SIZE_NOISE_MESSAGE_LEN) + await self.read_writer.write(encoded_data) + + +# TODO: Add comments +def encode_msg_body(msg_body: bytes) -> bytes: + encoded_msg_body = encode_data(msg_body, SIZE_NOISE_MESSAGE_BODY_LEN) + if len(encoded_msg_body) > MAX_NOISE_MESSAGE_BODY_LEN: + raise ValueError( + f"msg_body is too long: {len(msg_body)}, " + f"maximum={MAX_NOISE_MESSAGE_BODY_LEN}" + ) + # NOTE: Improvements: + # 1. Senders *SHOULD* use a source of random data to populate the padding field. + # 2. and *may* use any length of padding that does not cause the total length of + # the Noise message to exceed 65535 bytes. + # Ref: https://github.com/libp2p/specs/tree/master/noise#encrypted-payloads + return encoded_msg_body # + padding + + +def decode_msg_body(noise_msg: bytes) -> bytes: + len_body = int.from_bytes(noise_msg[:SIZE_NOISE_MESSAGE_BODY_LEN], BYTE_ORDER) + # Just ignore the padding + return noise_msg[ + SIZE_NOISE_MESSAGE_BODY_LEN : (SIZE_NOISE_MESSAGE_BODY_LEN + len_body) + ] + + +# TODO: Add comments +class NoiseHandshakeReadWriter(MsgReadWriter): + read_writer: MsgReadWriter + noise_state: NoiseState + + def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: + self.read_writer = NoisePacketReadWriter(cast(ReadWriter, conn)) + self.noise_state = noise_state + + async def write_msg(self, data: bytes) -> None: + noise_msg = encode_msg_body(data) + data_encrypted = self.noise_state.write_message(noise_msg) + await self.read_writer.write_msg(data_encrypted) + + async def read_msg(self) -> bytes: + noise_msg_encrypted = await self.read_writer.read_msg() + noise_msg = self.noise_state.read_message(noise_msg_encrypted) + return decode_msg_body(noise_msg) diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index dee9a61..a05e807 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -9,24 +9,8 @@ from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from .connection import NoiseConnection - -# FIXME: Choose a serious bound number. -NUM_BYTES_TO_READ = 2048 - - -# TODO: Merged into `BasePattern`? -class PreHandshakeConnection: - conn: IRawConnection - - def __init__(self, conn: IRawConnection) -> None: - self.conn = conn - - async def write_msg(self, data: bytes) -> None: - # TODO: - await self.conn.write(data) - - async def read_msg(self) -> bytes: - return await self.conn.read(NUM_BYTES_TO_READ) +from .exceptions import HandshakeHasNotFinished +from .io import NoiseHandshakeReadWriter class IPattern(ABC): @@ -66,25 +50,24 @@ class PatternXX(BasePattern): async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: noise_state = self.create_noise_state() - handshake_conn = PreHandshakeConnection(conn) noise_state.set_as_responder() noise_state.start_handshake() - msg_0_encrypted = await handshake_conn.read_msg() + read_writer = NoiseHandshakeReadWriter(conn, noise_state) # TODO: Parse and save the payload from the other side. - _ = noise_state.read_message(msg_0_encrypted) + _ = await read_writer.read_msg() # TODO: Send our payload. our_payload = b"server" - msg_1_encrypted = noise_state.write_message(our_payload) - await handshake_conn.write_msg(msg_1_encrypted) + await read_writer.write_msg(our_payload) - msg_2_encrypted = await handshake_conn.read_msg() # TODO: Parse and save another payload from the other side. - _ = noise_state.read_message(msg_2_encrypted) + _ = await read_writer.read_msg() # TODO: Add a specific exception if not noise_state.handshake_finished: - raise Exception + raise HandshakeHasNotFinished( + "handshake done but it is not marked as finished in `noise_state`" + ) # FIXME: `remote_peer` should be derived from the messages. return NoiseConnection(self.local_peer, self.libp2p_privkey, None, conn, False) @@ -93,19 +76,17 @@ class PatternXX(BasePattern): self, conn: IRawConnection, remote_peer: ID ) -> ISecureConn: noise_state = self.create_noise_state() - handshake_conn = PreHandshakeConnection(conn) + read_writer = NoiseHandshakeReadWriter(conn, noise_state) noise_state.set_as_initiator() noise_state.start_handshake() - msg_0 = noise_state.write_message() - await handshake_conn.write_msg(msg_0) - msg_1_encrypted = await handshake_conn.read_msg() + await read_writer.write_msg(b"") + # TODO: Parse and save the payload from the other side. - _ = noise_state.read_message(msg_1_encrypted) + _ = await read_writer.read_msg() # TODO: Send our payload. our_payload = b"client" - msg_2_encrypted = noise_state.write_message(our_payload) - await handshake_conn.write_msg(msg_2_encrypted) + await read_writer.write_msg(our_payload) # TODO: Add a specific exception if not noise_state.handshake_finished: diff --git a/tests/security/noise/test_msg_read_writer.py b/tests/security/noise/test_msg_read_writer.py new file mode 100644 index 0000000..47e9afa --- /dev/null +++ b/tests/security/noise/test_msg_read_writer.py @@ -0,0 +1,27 @@ +import pytest + +from libp2p.security.noise.io import MAX_NOISE_MESSAGE_LEN, NoisePacketReadWriter +from libp2p.tools.factories import raw_conn_factory + + +@pytest.mark.parametrize( + "noise_msg", + (b"", b"data", pytest.param(b"A" * MAX_NOISE_MESSAGE_LEN, id="maximum length")), +) +@pytest.mark.trio +async def test_noise_msg_read_write_round_trip(nursery, noise_msg): + async with raw_conn_factory(nursery) as conns: + reader, writer = ( + NoisePacketReadWriter(conns[0]), + NoisePacketReadWriter(conns[1]), + ) + await writer.write_msg(noise_msg) + assert (await reader.read_msg()) == noise_msg + + +@pytest.mark.trio +async def test_noise_msg_write_too_long(nursery): + async with raw_conn_factory(nursery) as conns: + writer = NoisePacketReadWriter(conns[0]) + with pytest.raises(ValueError): + await writer.write_msg(b"1" * (MAX_NOISE_MESSAGE_LEN + 1)) diff --git a/tests/security/test_noise.py b/tests/security/noise/test_noise.py similarity index 100% rename from tests/security/test_noise.py rename to tests/security/noise/test_noise.py