py-libp2p/libp2p/security/noise/io.py

127 lines
4.3 KiB
Python
Raw Normal View History

from abc import ABC, abstractmethod
from typing import cast
from noise.connection import NoiseConnection as NoiseState
from libp2p.io.abc import ReadWriter
from libp2p.io.utils import read_exactly
from libp2p.network.connection.raw_connection_interface import IRawConnection
# Width, in bytes, of the length prefix placed in front of every noise message.
SIZE_NOISE_MESSAGE_LEN = 2
# Largest length that prefix can express (2**16 - 1 == 65535).
MAX_NOISE_MESSAGE_LEN = (1 << (8 * SIZE_NOISE_MESSAGE_LEN)) - 1
# Width, in bytes, of the body-length prefix inside a (decrypted) noise message.
SIZE_NOISE_MESSAGE_BODY_LEN = 2
# Space left for body (+ padding) once the body-length prefix is accounted for.
MAX_NOISE_MESSAGE_BODY_LEN = MAX_NOISE_MESSAGE_LEN - SIZE_NOISE_MESSAGE_BODY_LEN
# Byte order used for every length prefix in this module.
BYTE_ORDER = "big"
# Wire layout of one noise packet (length prefixes use BYTE_ORDER):
# |                     Noise packet                     |
# <-   2 bytes    -><-           65535 bytes            ->
# |  noise msg len  |             noise msg              |
#                   | body len |     body     | padding  |
#                   <-2 bytes -><-   max=65533 bytes    ->
def encode_data(data: bytes, size_len: int) -> bytes:
    """Prefix ``data`` with its length encoded on ``size_len`` bytes.

    The length is written as an unsigned integer in ``BYTE_ORDER`` byte order.

    :param data: payload to frame.
    :param size_len: width of the length prefix, in bytes.
    :return: the encoded length followed by ``data``.
    :raises ValueError: if ``len(data)`` cannot be represented in ``size_len``
        bytes.
    """
    len_data = len(data)
    try:
        len_bytes = len_data.to_bytes(size_len, BYTE_ORDER)
    except OverflowError as e:
        # Translate to ValueError (the type callers already expect) but keep a
        # message so oversize writes are diagnosable.
        raise ValueError(
            f"data is too long to encode its length in {size_len} byte(s): "
            f"{len_data}"
        ) from e
    return len_bytes + data
class MsgReader(ABC):
    """Abstract source of whole messages."""

    @abstractmethod
    async def read_msg(self) -> bytes:
        """Read and return the next message; subclasses define the framing."""
class MsgWriter(ABC):
    """Abstract sink for whole messages."""

    @abstractmethod
    async def write_msg(self, msg: bytes) -> None:
        """Write ``msg``; subclasses define the framing and transport."""
class MsgReadWriter(MsgReader, MsgWriter):
    """Combined interface: something that both reads and writes whole messages."""
class NoisePacketReadWriter(MsgReadWriter):
    """Frame noise messages on a raw stream with a fixed-size length prefix.

    Each packet is a ``SIZE_NOISE_MESSAGE_LEN``-byte length (``BYTE_ORDER``)
    followed by exactly that many bytes of noise message. This layer neither
    encrypts nor decrypts; it only adds and strips the outer length prefix.
    """

    # Underlying byte stream the packets are read from / written to.
    read_writer: ReadWriter

    def __init__(self, read_writer: ReadWriter) -> None:
        self.read_writer = read_writer

    async def read_msg(self) -> bytes:
        """Read one length-prefixed packet and return the bytes after the prefix."""
        header = await read_exactly(self.read_writer, SIZE_NOISE_MESSAGE_LEN)
        payload_len = int.from_bytes(header, BYTE_ORDER)
        payload = await read_exactly(self.read_writer, payload_len)
        return payload

    async def write_msg(self, msg: bytes) -> None:
        """Prepend the length of ``msg`` and write the packet to the stream.

        :raises ValueError: from ``encode_data`` if ``msg`` is too long for
            the length prefix.
        """
        packet = encode_data(msg, SIZE_NOISE_MESSAGE_LEN)
        await self.read_writer.write(packet)
def encode_msg_body(msg_body: bytes) -> bytes:
    """Frame ``msg_body`` as the plaintext of one noise message.

    The result is a ``SIZE_NOISE_MESSAGE_BODY_LEN``-byte body length
    (``BYTE_ORDER``) followed by ``msg_body``; no padding is appended.

    :param msg_body: payload to send.
    :return: body-length prefix + ``msg_body``.
    :raises ValueError: if ``msg_body`` is longer than
        ``MAX_NOISE_MESSAGE_BODY_LEN``.
    """
    # Compare the *body* length: MAX_NOISE_MESSAGE_BODY_LEN already excludes
    # the body-length prefix, so checking the encoded (prefixed) length would
    # wrongly reject bodies of MAX-1 and MAX bytes and report a length below
    # the stated maximum.
    # NOTE(review): encryption may add overhead (e.g. an AEAD tag), so a body
    # at this limit can still overflow the outer packet; in that case
    # encode_data raises ValueError when the packet is written.
    if len(msg_body) > MAX_NOISE_MESSAGE_BODY_LEN:
        raise ValueError(
            f"msg_body is too long: {len(msg_body)}, "
            f"maximum={MAX_NOISE_MESSAGE_BODY_LEN}"
        )
    encoded_msg_body = encode_data(msg_body, SIZE_NOISE_MESSAGE_BODY_LEN)
    # NOTE: Improvements:
    # 1. Senders *SHOULD* use a source of random data to populate the padding field.
    # 2. and *may* use any length of padding that does not cause the total length of
    # the Noise message to exceed 65535 bytes.
    # Ref: https://github.com/libp2p/specs/tree/master/noise#encrypted-payloads
    return encoded_msg_body  # + padding
def decode_msg_body(noise_msg: bytes) -> bytes:
    """Extract the body from a decrypted noise message.

    The first ``SIZE_NOISE_MESSAGE_BODY_LEN`` bytes hold the body length
    (``BYTE_ORDER``), followed by the body. Any bytes after the body are
    padding and are discarded. An empty message decodes to an empty body.

    :param noise_msg: decrypted noise message received from the peer.
    :return: the body bytes.
    :raises ValueError: if the declared body length exceeds the number of
        bytes actually present (malformed or truncated message).
    """
    len_body = int.from_bytes(noise_msg[:SIZE_NOISE_MESSAGE_BODY_LEN], BYTE_ORDER)
    available = max(len(noise_msg) - SIZE_NOISE_MESSAGE_BODY_LEN, 0)
    # The input comes from the remote peer: refuse to silently hand back a
    # shorter body than the one it declared.
    if len_body > available:
        raise ValueError(
            f"noise message body is truncated: declared {len_body} bytes, "
            f"only {available} available"
        )
    # Everything after the declared body is padding; ignore it.
    return noise_msg[
        SIZE_NOISE_MESSAGE_BODY_LEN : (SIZE_NOISE_MESSAGE_BODY_LEN + len_body)
    ]
class NoiseHandshakeReadWriter(MsgReadWriter):
    """Read/write handshake messages through ``noise_state``'s handshake API.

    Outgoing data is framed with ``encode_msg_body``, passed through
    ``noise_state.write_message`` and sent as one packet. Incoming packets go
    through ``noise_state.read_message`` and then ``decode_msg_body``.
    """

    # Packet-level framing over the raw connection.
    read_writer: MsgReadWriter
    # Noise state whose handshake methods process each message.
    noise_state: NoiseState

    def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
        # The cast is for the type checker only; conn is used as a ReadWriter.
        packet_rw = NoisePacketReadWriter(cast(ReadWriter, conn))
        self.read_writer = packet_rw
        self.noise_state = noise_state

    async def write_msg(self, data: bytes) -> None:
        """Frame ``data``, run it through the handshake state and send it."""
        framed = encode_msg_body(data)
        outgoing = self.noise_state.write_message(framed)
        await self.read_writer.write_msg(outgoing)

    async def read_msg(self) -> bytes:
        """Receive one handshake packet and return its decoded body."""
        incoming = await self.read_writer.read_msg()
        return decode_msg_body(self.noise_state.read_message(incoming))
class NoiseTransportReadWriter(MsgReadWriter):
    """Read/write messages protected by ``noise_state.encrypt``/``decrypt``.

    Unlike the handshake read-writer, this one uses the cipher methods rather
    than ``write_message``/``read_message``. Bodies are framed with
    ``encode_msg_body`` before encryption and unframed with
    ``decode_msg_body`` after decryption.
    """

    read_writer: MsgReadWriter  # packet framing over the raw connection
    noise_state: NoiseState  # provides encrypt()/decrypt()

    def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
        self.read_writer = NoisePacketReadWriter(cast(ReadWriter, conn))
        self.noise_state = noise_state

    async def write_msg(self, data: bytes) -> None:
        """Frame and encrypt ``data``, then send it as one packet."""
        plaintext = encode_msg_body(data)
        ciphertext = self.noise_state.encrypt(plaintext)
        await self.read_writer.write_msg(ciphertext)

    async def read_msg(self) -> bytes:
        """Receive one packet, decrypt it and return the decoded body."""
        ciphertext = await self.read_writer.read_msg()
        plaintext = self.noise_state.decrypt(ciphertext)
        return decode_msg_body(plaintext)