Security: SecureSession

Make security sessions(secio, noise) share the same implementation
`BaseSession` to avoid duplicate implementation of buffered read.
This commit is contained in:
mhchia 2020-02-17 23:33:45 +08:00
parent 2df47a943c
commit 3c2e835725
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
8 changed files with 150 additions and 196 deletions

View File

@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
class Closer(ABC):
@abstractmethod
async def close(self) -> None:
...
@ -39,10 +40,6 @@ class MsgReader(ABC):
async def read_msg(self) -> bytes:
...
# @abstractmethod
# async def next_msg_len(self) -> int:
# ...
class MsgWriter(ABC):
@abstractmethod
@ -50,7 +47,7 @@ class MsgWriter(ABC):
...
class MsgReadWriter(MsgReader, MsgWriter):
class MsgReadWriteCloser(MsgReader, MsgWriter, Closer):
pass
@ -64,5 +61,5 @@ class Encrypter(ABC):
...
class EncryptedMsgReadWriter(MsgReadWriter, Encrypter):
pass
class EncryptedMsgReadWriter(MsgReadWriteCloser, Encrypter):
"""Read/write message with encryption/decryption."""

View File

@ -8,10 +8,9 @@ NOTE: currently missing the capability to indicate lengths by "varint" method.
# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501
from typing import Optional
from libp2p.io.abc import MsgReadWriter, Reader, ReadWriteCloser
from libp2p.io.abc import MsgReadWriteCloser, Reader, ReadWriteCloser
from libp2p.io.utils import read_exactly
BYTE_ORDER = "big"
@ -26,12 +25,12 @@ def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes:
except OverflowError:
raise ValueError(
"msg_bytes is too large for `size_len_bytes` bytes length: "
f"msg_bytes={msg_bytes}, size_len_bytes={size_len_bytes}"
f"msg_bytes={msg_bytes!r}, size_len_bytes={size_len_bytes}"
)
return len_prefix + msg_bytes
class BaseMsgReadWriter(MsgReadWriter):
class BaseMsgReadWriter(MsgReadWriteCloser):
next_length: Optional[int]
read_write_closer: ReadWriteCloser
size_len_bytes: int

View File

@ -0,0 +1,19 @@
from libp2p.io.abc import MsgReadWriteCloser, ReadWriteCloser
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
class PlaintextHandshakeReadWriter(MsgReadWriteCloser):
conn: ReadWriteCloser
def __init__(self, conn: ReadWriteCloser) -> None:
self.conn = conn
async def read_msg(self) -> bytes:
return await read_fixedint_prefixed(self.conn)
async def write_msg(self, msg: bytes) -> None:
encoded_msg_bytes = encode_fixedint_prefixed(msg)
await self.conn.write(encoded_msg_bytes)
async def close(self) -> None:
await self.conn.close()

View File

@ -13,8 +13,8 @@ from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.exceptions import HandshakeFailure
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.typing import TProtocol
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
from .io import PlaintextHandshakeReadWriter
from .pb import plaintext_pb2
# Reference: https://github.com/libp2p/go-libp2p-core/blob/master/sec/insecure/insecure.go
@ -44,60 +44,66 @@ class InsecureSession(BaseSession):
async def close(self) -> None:
await self.conn.close()
async def run_handshake(self) -> None:
"""Raise `HandshakeFailure` when handshake failed."""
msg = make_exchange_message(self.local_private_key.get_public_key())
msg_bytes = msg.SerializeToString()
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
try:
await self.write(encoded_msg_bytes)
except RawConnError as e:
raise HandshakeFailure("connection closed") from e
try:
remote_msg_bytes = await read_fixedint_prefixed(self.conn)
except RawConnError as e:
raise HandshakeFailure("connection closed") from e
remote_msg = plaintext_pb2.Exchange()
remote_msg.ParseFromString(remote_msg_bytes)
received_peer_id = ID(remote_msg.id)
async def run_handshake(
local_peer: ID,
local_private_key: PrivateKey,
conn: IRawConnection,
is_initiator: bool,
remote_peer_id: ID,
) -> ISecureConn:
"""Raise `HandshakeFailure` when handshake failed."""
msg = make_exchange_message(local_private_key.get_public_key())
msg_bytes = msg.SerializeToString()
read_writer = PlaintextHandshakeReadWriter(conn)
try:
await read_writer.write_msg(msg_bytes)
except RawConnError as e:
raise HandshakeFailure("connection closed") from e
# Verify if the receive `ID` matches the one we originally initialize the session.
# We only need to check it when we are the initiator, because only in that condition
# we possibly knows the `ID` of the remote.
if self.is_initiator and self.remote_peer_id != received_peer_id:
raise HandshakeFailure(
"remote peer sent unexpected peer ID. "
f"expected={self.remote_peer_id} received={received_peer_id}"
)
try:
remote_msg_bytes = await read_writer.read_msg()
except RawConnError as e:
raise HandshakeFailure("connection closed") from e
remote_msg = plaintext_pb2.Exchange()
remote_msg.ParseFromString(remote_msg_bytes)
received_peer_id = ID(remote_msg.id)
# Verify if the given `pubkey` matches the given `peer_id`
try:
received_pubkey = deserialize_public_key(
remote_msg.pubkey.SerializeToString()
)
except ValueError as e:
raise HandshakeFailure(
f"unknown `key_type` of remote_msg.pubkey={remote_msg.pubkey}"
) from e
except MissingDeserializerError as error:
raise HandshakeFailure() from error
peer_id_from_received_pubkey = ID.from_pubkey(received_pubkey)
if peer_id_from_received_pubkey != received_peer_id:
raise HandshakeFailure(
"peer id and pubkey from the remote mismatch: "
f"received_peer_id={received_peer_id}, remote_pubkey={received_pubkey}, "
f"peer_id_from_received_pubkey={peer_id_from_received_pubkey}"
)
# Verify if the receive `ID` matches the one we originally initialize the session.
# We only need to check it when we are the initiator, because only in that condition
# we possibly knows the `ID` of the remote.
if is_initiator and remote_peer_id != received_peer_id:
raise HandshakeFailure(
"remote peer sent unexpected peer ID. "
f"expected={remote_peer_id} received={received_peer_id}"
)
# Nothing is wrong. Store the `pubkey` and `peer_id` in the session.
self.remote_permanent_pubkey = received_pubkey
# Only need to set peer's id when we don't know it before,
# i.e. we are not the connection initiator.
if not self.is_initiator:
self.remote_peer_id = received_peer_id
# Verify if the given `pubkey` matches the given `peer_id`
try:
received_pubkey = deserialize_public_key(remote_msg.pubkey.SerializeToString())
except ValueError as e:
raise HandshakeFailure(
f"unknown `key_type` of remote_msg.pubkey={remote_msg.pubkey}"
) from e
except MissingDeserializerError as error:
raise HandshakeFailure() from error
peer_id_from_received_pubkey = ID.from_pubkey(received_pubkey)
if peer_id_from_received_pubkey != received_peer_id:
raise HandshakeFailure(
"peer id and pubkey from the remote mismatch: "
f"received_peer_id={received_peer_id}, remote_pubkey={received_pubkey}, "
f"peer_id_from_received_pubkey={peer_id_from_received_pubkey}"
)
# TODO: Store `pubkey` and `peer_id` to `PeerStore`
secure_conn = InsecureSession(
local_peer, local_private_key, conn, is_initiator, received_peer_id
)
# Nothing is wrong. Store the `pubkey` and `peer_id` in the session.
secure_conn.remote_permanent_pubkey = received_pubkey
# TODO: Store `pubkey` and `peer_id` to `PeerStore`
return secure_conn
class InsecureTransport(BaseSecureTransport):
@ -113,9 +119,9 @@ class InsecureTransport(BaseSecureTransport):
:return: secure connection object (that implements secure_conn_interface)
"""
session = InsecureSession(self.local_peer, self.local_private_key, conn, False)
await session.run_handshake()
return session
return await run_handshake(
self.local_peer, self.local_private_key, conn, False, None
)
async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn:
"""
@ -124,11 +130,9 @@ class InsecureTransport(BaseSecureTransport):
:return: secure connection object (that implements secure_conn_interface)
"""
session = InsecureSession(
return await run_handshake(
self.local_peer, self.local_private_key, conn, True, peer_id
)
await session.run_handshake()
return session
def make_exchange_message(pubkey: PublicKey) -> plaintext_pb2.Exchange:

View File

@ -1,14 +1,11 @@
from abc import ABC, abstractmethod
from typing import cast
from noise.connection import NoiseConnection as NoiseState
from libp2p.io.abc import ReadWriteCloser, MsgReadWriter, EncryptedMsgReadWriter
from libp2p.io.abc import EncryptedMsgReadWriter, MsgReadWriteCloser, ReadWriteCloser
from libp2p.io.msgio import BaseMsgReadWriter, encode_msg_with_length
from libp2p.io.utils import read_exactly
from libp2p.network.connection.raw_connection_interface import IRawConnection
SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -50,7 +47,14 @@ def decode_msg_body(noise_msg: bytes) -> bytes:
class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
read_writer: MsgReadWriter
"""
The base implementation of noise message reader/writer.
`encrypt` and `decrypt` are not implemented here, which should be
implemented by the subclasses.
"""
read_writer: MsgReadWriteCloser
noise_state: NoiseState
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
@ -67,6 +71,9 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
noise_msg = self.decrypt(noise_msg_encrypted)
return decode_msg_body(noise_msg)
async def close(self) -> None:
await self.read_writer.close()
class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter):
def encrypt(self, data: bytes) -> bytes:

View File

@ -8,15 +8,15 @@ from libp2p.crypto.keys import PrivateKey
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_session import SecureSession
from .connection import NoiseConnection
from .exceptions import (
HandshakeHasNotFinished,
InvalidSignature,
NoiseStateError,
PeerIDMismatchesPubkey,
)
from .io import encode_msg_body, decode_msg_body, NoiseHandshakeReadWriter
from .io import NoiseHandshakeReadWriter, NoiseTransportReadWriter
from .messages import (
NoiseHandshakePayload,
make_handshake_payload_sig,
@ -56,16 +56,6 @@ class BasePattern(IPattern):
)
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
async def write_msg(self, conn: IRawConnection, data: bytes) -> None:
noise_msg = encode_msg_body(data)
data_encrypted = self.noise_state.write_message(noise_msg)
await self.read_writer.write_msg(data_encrypted)
async def read_msg(self) -> bytes:
noise_msg_encrypted = await self.read_writer.read_msg()
noise_msg = self.noise_state.read_message(noise_msg_encrypted)
return decode_msg_body(noise_msg)
class PatternXX(BasePattern):
def __init__(
@ -116,14 +106,13 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
return NoiseConnection(
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return SecureSession(
self.local_peer,
self.libp2p_privkey,
remote_peer_id_from_pubkey,
conn,
transport_read_writer,
False,
noise_state,
)
async def handshake_outbound(
@ -171,7 +160,12 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return NoiseConnection(
self.local_peer, self.libp2p_privkey, remote_peer, conn, False, noise_state
return SecureSession(
self.local_peer,
self.libp2p_privkey,
remote_peer,
transport_read_writer,
False,
)

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
import io
import itertools
from typing import Optional, Tuple
@ -18,13 +17,14 @@ from libp2p.crypto.exceptions import MissingDeserializerError
from libp2p.crypto.key_exchange import create_ephemeral_key_pair
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.serialization import deserialize_public_key
from libp2p.io.abc import EncryptedMsgReadWriter
from libp2p.io.exceptions import DecryptionFailedException, IOException
from libp2p.io.msgio import BaseMsgReadWriter
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID as PeerID
from libp2p.security.base_session import BaseSession
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_session import SecureSession
from libp2p.typing import TProtocol
from .exceptions import (
@ -54,30 +54,20 @@ class MsgIOReadWriter(BaseMsgReadWriter):
size_len_bytes = SIZE_SECIO_LEN_BYTES
class SecureSession(BaseSession):
buf: io.BytesIO
low_watermark: int
high_watermark: int
class SecioMsgReadWriter(EncryptedMsgReadWriter):
read_writer: MsgIOReadWriter
def __init__(
self,
local_peer: PeerID,
local_private_key: PrivateKey,
local_encryption_parameters: AuthenticatedEncryptionParameters,
remote_peer: PeerID,
remote_encryption_parameters: AuthenticatedEncryptionParameters,
conn: MsgIOReadWriter,
is_initiator: bool,
read_writer: MsgIOReadWriter,
) -> None:
super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
self.conn = conn
self.local_encryption_parameters = local_encryption_parameters
self.remote_encryption_parameters = remote_encryption_parameters
self._initialize_authenticated_encryption_for_local_peer()
self._initialize_authenticated_encryption_for_remote_peer()
self._reset_internal_buffer()
self.read_writer = read_writer
def _initialize_authenticated_encryption_for_local_peer(self) -> None:
self.local_encrypter = Encrypter(self.local_encryption_parameters)
@ -85,68 +75,28 @@ class SecureSession(BaseSession):
def _initialize_authenticated_encryption_for_remote_peer(self) -> None:
self.remote_encrypter = Encrypter(self.remote_encryption_parameters)
async def next_msg_len(self) -> int:
return await self.conn.next_msg_len()
def encrypt(self, data: bytes) -> bytes:
encrypted_data = self.local_encrypter.encrypt(data)
tag = self.local_encrypter.authenticate(encrypted_data)
return encrypted_data + tag
def _reset_internal_buffer(self) -> None:
self.buf = io.BytesIO()
self.low_watermark = 0
self.high_watermark = 0
def _drain(self, n: int) -> bytes:
if self.low_watermark == self.high_watermark:
return bytes()
data = self.buf.getbuffer()[self.low_watermark : self.high_watermark]
if n is None:
n = len(data)
result = data[:n].tobytes()
self.low_watermark += len(result)
if self.low_watermark == self.high_watermark:
del data # free the memoryview so we can free the underlying BytesIO
self.buf.close()
self._reset_internal_buffer()
return result
async def _fill(self) -> None:
msg = await self.read_msg()
self.buf.write(msg)
self.low_watermark = 0
self.high_watermark = len(msg)
async def read(self, n: int = None) -> bytes:
if n == 0:
return bytes()
data_from_buffer = self._drain(n)
if len(data_from_buffer) > 0:
return data_from_buffer
next_length = await self.next_msg_len()
if n < next_length:
await self._fill()
return self._drain(n)
else:
return await self.read_msg()
async def read_msg(self) -> bytes:
msg = await self.conn.read_msg()
def decrypt(self, data: bytes) -> bytes:
try:
decrypted_msg = self.remote_encrypter.decrypt_if_valid(msg)
decrypted_data = self.remote_encrypter.decrypt_if_valid(data)
except InvalidMACException as e:
raise DecryptionFailedException() from e
return decrypted_msg
async def write(self, data: bytes) -> None:
await self.write_msg(data)
return decrypted_data
async def write_msg(self, msg: bytes) -> None:
encrypted_data = self.local_encrypter.encrypt(msg)
tag = self.local_encrypter.authenticate(encrypted_data)
await self.conn.write_msg(encrypted_data + tag)
data_encrypted = self.encrypt(msg)
await self.read_writer.write_msg(data_encrypted)
async def read_msg(self) -> bytes:
msg_encrypted = await self.read_writer.read_msg()
return self.decrypt(msg_encrypted)
async def close(self) -> None:
await self.read_writer.close()
@dataclass(frozen=True)
@ -387,22 +337,20 @@ def _mk_session_from(
if session_parameters.order < 0:
key_set1, key_set2 = key_set2, key_set1
secio_read_writer = SecioMsgReadWriter(key_set1, key_set2, conn)
session = SecureSession(
session_parameters.local_peer,
local_private_key,
key_set1,
session_parameters.remote_peer,
key_set2,
conn,
secio_read_writer,
is_initiator,
)
return session
async def _finish_handshake(session: SecureSession, remote_nonce: bytes) -> bytes:
await session.write_msg(remote_nonce)
return await session.read_msg()
await session.conn.write_msg(remote_nonce)
return await session.conn.read_msg()
async def create_secure_session(

View File

@ -1,43 +1,29 @@
import io
from noise.connection import NoiseConnection as NoiseState
from libp2p.crypto.keys import PrivateKey
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.io.abc import EncryptedMsgReadWriter
from libp2p.peer.id import ID
from libp2p.security.base_session import BaseSession
from libp2p.security.noise.io import MsgReadWriter, NoiseTransportReadWriter
class NoiseConnection(BaseSession):
class SecureSession(BaseSession):
buf: io.BytesIO
low_watermark: int
high_watermark: int
read_writer: IRawConnection
noise_state: NoiseState
def __init__(
self,
local_peer: ID,
local_private_key: PrivateKey,
remote_peer: ID,
conn: IRawConnection,
conn: EncryptedMsgReadWriter,
is_initiator: bool,
noise_state: NoiseState,
# remote_permanent_pubkey
) -> None:
super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
self.conn = conn
self.noise_state = noise_state
self._reset_internal_buffer()
def get_msg_read_writer(self) -> MsgReadWriter:
return NoiseTransportReadWriter(self.conn, self.noise_state)
async def close(self) -> None:
await self.conn.close()
def _reset_internal_buffer(self) -> None:
self.buf = io.BytesIO()
self.low_watermark = 0
@ -60,6 +46,11 @@ class NoiseConnection(BaseSession):
self._reset_internal_buffer()
return result
def _fill(self, msg: bytes) -> None:
self.buf.write(msg)
self.low_watermark = 0
self.high_watermark = len(msg)
async def read(self, n: int = None) -> bytes:
if n == 0:
return bytes()
@ -68,21 +59,16 @@ class NoiseConnection(BaseSession):
if len(data_from_buffer) > 0:
return data_from_buffer
msg = await self.read_msg()
msg = await self.conn.read_msg()
if n < len(msg):
self.buf.write(msg)
self.low_watermark = 0
self.high_watermark = len(msg)
self._fill(msg)
return self._drain(n)
else:
return msg
async def read_msg(self) -> bytes:
return await self.get_msg_read_writer().read_msg()
async def write(self, data: bytes) -> None:
await self.write_msg(data)
await self.conn.write_msg(data)
async def write_msg(self, msg: bytes) -> None:
await self.get_msg_read_writer().write_msg(msg)
async def close(self) -> None:
await self.conn.close()