Security: `SecureSession`

Make security sessions(secio, noise) share the same implementation
`BaseSession` to avoid duplicate implementation of buffered read.
pull/406/head
mhchia 2020-02-17 23:33:45 +08:00
parent 2df47a943c
commit 3c2e835725
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
8 changed files with 150 additions and 196 deletions

View File

@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
class Closer(ABC): class Closer(ABC):
@abstractmethod
async def close(self) -> None: async def close(self) -> None:
... ...
@ -39,10 +40,6 @@ class MsgReader(ABC):
async def read_msg(self) -> bytes: async def read_msg(self) -> bytes:
... ...
# @abstractmethod
# async def next_msg_len(self) -> int:
# ...
class MsgWriter(ABC): class MsgWriter(ABC):
@abstractmethod @abstractmethod
@ -50,7 +47,7 @@ class MsgWriter(ABC):
... ...
class MsgReadWriter(MsgReader, MsgWriter): class MsgReadWriteCloser(MsgReader, MsgWriter, Closer):
pass pass
@ -64,5 +61,5 @@ class Encrypter(ABC):
... ...
class EncryptedMsgReadWriter(MsgReadWriter, Encrypter): class EncryptedMsgReadWriter(MsgReadWriteCloser, Encrypter):
pass """Read/write message with encryption/decryption."""

View File

@ -8,10 +8,9 @@ NOTE: currently missing the capability to indicate lengths by "varint" method.
# TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501 # TODO unify w/ https://github.com/libp2p/py-libp2p/blob/1aed52856f56a4b791696bbcbac31b5f9c2e88c9/libp2p/utils.py#L85-L99 # noqa: E501
from typing import Optional from typing import Optional
from libp2p.io.abc import MsgReadWriter, Reader, ReadWriteCloser from libp2p.io.abc import MsgReadWriteCloser, Reader, ReadWriteCloser
from libp2p.io.utils import read_exactly from libp2p.io.utils import read_exactly
BYTE_ORDER = "big" BYTE_ORDER = "big"
@ -26,12 +25,12 @@ def encode_msg_with_length(msg_bytes: bytes, size_len_bytes: int) -> bytes:
except OverflowError: except OverflowError:
raise ValueError( raise ValueError(
"msg_bytes is too large for `size_len_bytes` bytes length: " "msg_bytes is too large for `size_len_bytes` bytes length: "
f"msg_bytes={msg_bytes}, size_len_bytes={size_len_bytes}" f"msg_bytes={msg_bytes!r}, size_len_bytes={size_len_bytes}"
) )
return len_prefix + msg_bytes return len_prefix + msg_bytes
class BaseMsgReadWriter(MsgReadWriter): class BaseMsgReadWriter(MsgReadWriteCloser):
next_length: Optional[int] next_length: Optional[int]
read_write_closer: ReadWriteCloser read_write_closer: ReadWriteCloser
size_len_bytes: int size_len_bytes: int

View File

@ -0,0 +1,19 @@
from libp2p.io.abc import MsgReadWriteCloser, ReadWriteCloser
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
class PlaintextHandshakeReadWriter(MsgReadWriteCloser):
conn: ReadWriteCloser
def __init__(self, conn: ReadWriteCloser) -> None:
self.conn = conn
async def read_msg(self) -> bytes:
return await read_fixedint_prefixed(self.conn)
async def write_msg(self, msg: bytes) -> None:
encoded_msg_bytes = encode_fixedint_prefixed(msg)
await self.conn.write(encoded_msg_bytes)
async def close(self) -> None:
await self.conn.close()

View File

@ -13,8 +13,8 @@ from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.exceptions import HandshakeFailure from libp2p.security.exceptions import HandshakeFailure
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
from .io import PlaintextHandshakeReadWriter
from .pb import plaintext_pb2 from .pb import plaintext_pb2
# Reference: https://github.com/libp2p/go-libp2p-core/blob/master/sec/insecure/insecure.go # Reference: https://github.com/libp2p/go-libp2p-core/blob/master/sec/insecure/insecure.go
@ -44,60 +44,66 @@ class InsecureSession(BaseSession):
async def close(self) -> None: async def close(self) -> None:
await self.conn.close() await self.conn.close()
async def run_handshake(self) -> None:
"""Raise `HandshakeFailure` when handshake failed."""
msg = make_exchange_message(self.local_private_key.get_public_key())
msg_bytes = msg.SerializeToString()
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
try:
await self.write(encoded_msg_bytes)
except RawConnError as e:
raise HandshakeFailure("connection closed") from e
try: async def run_handshake(
remote_msg_bytes = await read_fixedint_prefixed(self.conn) local_peer: ID,
except RawConnError as e: local_private_key: PrivateKey,
raise HandshakeFailure("connection closed") from e conn: IRawConnection,
remote_msg = plaintext_pb2.Exchange() is_initiator: bool,
remote_msg.ParseFromString(remote_msg_bytes) remote_peer_id: ID,
received_peer_id = ID(remote_msg.id) ) -> ISecureConn:
"""Raise `HandshakeFailure` when handshake failed."""
msg = make_exchange_message(local_private_key.get_public_key())
msg_bytes = msg.SerializeToString()
read_writer = PlaintextHandshakeReadWriter(conn)
try:
await read_writer.write_msg(msg_bytes)
except RawConnError as e:
raise HandshakeFailure("connection closed") from e
# Verify if the receive `ID` matches the one we originally initialize the session. try:
# We only need to check it when we are the initiator, because only in that condition remote_msg_bytes = await read_writer.read_msg()
# we possibly knows the `ID` of the remote. except RawConnError as e:
if self.is_initiator and self.remote_peer_id != received_peer_id: raise HandshakeFailure("connection closed") from e
raise HandshakeFailure( remote_msg = plaintext_pb2.Exchange()
"remote peer sent unexpected peer ID. " remote_msg.ParseFromString(remote_msg_bytes)
f"expected={self.remote_peer_id} received={received_peer_id}" received_peer_id = ID(remote_msg.id)
)
# Verify if the given `pubkey` matches the given `peer_id` # Verify if the receive `ID` matches the one we originally initialize the session.
try: # We only need to check it when we are the initiator, because only in that condition
received_pubkey = deserialize_public_key( # we possibly knows the `ID` of the remote.
remote_msg.pubkey.SerializeToString() if is_initiator and remote_peer_id != received_peer_id:
) raise HandshakeFailure(
except ValueError as e: "remote peer sent unexpected peer ID. "
raise HandshakeFailure( f"expected={remote_peer_id} received={received_peer_id}"
f"unknown `key_type` of remote_msg.pubkey={remote_msg.pubkey}" )
) from e
except MissingDeserializerError as error:
raise HandshakeFailure() from error
peer_id_from_received_pubkey = ID.from_pubkey(received_pubkey)
if peer_id_from_received_pubkey != received_peer_id:
raise HandshakeFailure(
"peer id and pubkey from the remote mismatch: "
f"received_peer_id={received_peer_id}, remote_pubkey={received_pubkey}, "
f"peer_id_from_received_pubkey={peer_id_from_received_pubkey}"
)
# Nothing is wrong. Store the `pubkey` and `peer_id` in the session. # Verify if the given `pubkey` matches the given `peer_id`
self.remote_permanent_pubkey = received_pubkey try:
# Only need to set peer's id when we don't know it before, received_pubkey = deserialize_public_key(remote_msg.pubkey.SerializeToString())
# i.e. we are not the connection initiator. except ValueError as e:
if not self.is_initiator: raise HandshakeFailure(
self.remote_peer_id = received_peer_id f"unknown `key_type` of remote_msg.pubkey={remote_msg.pubkey}"
) from e
except MissingDeserializerError as error:
raise HandshakeFailure() from error
peer_id_from_received_pubkey = ID.from_pubkey(received_pubkey)
if peer_id_from_received_pubkey != received_peer_id:
raise HandshakeFailure(
"peer id and pubkey from the remote mismatch: "
f"received_peer_id={received_peer_id}, remote_pubkey={received_pubkey}, "
f"peer_id_from_received_pubkey={peer_id_from_received_pubkey}"
)
# TODO: Store `pubkey` and `peer_id` to `PeerStore` secure_conn = InsecureSession(
local_peer, local_private_key, conn, is_initiator, received_peer_id
)
# Nothing is wrong. Store the `pubkey` and `peer_id` in the session.
secure_conn.remote_permanent_pubkey = received_pubkey
# TODO: Store `pubkey` and `peer_id` to `PeerStore`
return secure_conn
class InsecureTransport(BaseSecureTransport): class InsecureTransport(BaseSecureTransport):
@ -113,9 +119,9 @@ class InsecureTransport(BaseSecureTransport):
:return: secure connection object (that implements secure_conn_interface) :return: secure connection object (that implements secure_conn_interface)
""" """
session = InsecureSession(self.local_peer, self.local_private_key, conn, False) return await run_handshake(
await session.run_handshake() self.local_peer, self.local_private_key, conn, False, None
return session )
async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn: async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn:
""" """
@ -124,11 +130,9 @@ class InsecureTransport(BaseSecureTransport):
:return: secure connection object (that implements secure_conn_interface) :return: secure connection object (that implements secure_conn_interface)
""" """
session = InsecureSession( return await run_handshake(
self.local_peer, self.local_private_key, conn, True, peer_id self.local_peer, self.local_private_key, conn, True, peer_id
) )
await session.run_handshake()
return session
def make_exchange_message(pubkey: PublicKey) -> plaintext_pb2.Exchange: def make_exchange_message(pubkey: PublicKey) -> plaintext_pb2.Exchange:

View File

@ -1,14 +1,11 @@
from abc import ABC, abstractmethod
from typing import cast from typing import cast
from noise.connection import NoiseConnection as NoiseState from noise.connection import NoiseConnection as NoiseState
from libp2p.io.abc import ReadWriteCloser, MsgReadWriter, EncryptedMsgReadWriter from libp2p.io.abc import EncryptedMsgReadWriter, MsgReadWriteCloser, ReadWriteCloser
from libp2p.io.msgio import BaseMsgReadWriter, encode_msg_with_length from libp2p.io.msgio import BaseMsgReadWriter, encode_msg_with_length
from libp2p.io.utils import read_exactly
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
SIZE_NOISE_MESSAGE_LEN = 2 SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2 SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -50,7 +47,14 @@ def decode_msg_body(noise_msg: bytes) -> bytes:
class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
read_writer: MsgReadWriter """
The base implementation of noise message reader/writer.
`encrypt` and `decrypt` are not implemented here, which should be
implemented by the subclasses.
"""
read_writer: MsgReadWriteCloser
noise_state: NoiseState noise_state: NoiseState
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
@ -67,6 +71,9 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
noise_msg = self.decrypt(noise_msg_encrypted) noise_msg = self.decrypt(noise_msg_encrypted)
return decode_msg_body(noise_msg) return decode_msg_body(noise_msg)
async def close(self) -> None:
await self.read_writer.close()
class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter): class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter):
def encrypt(self, data: bytes) -> bytes: def encrypt(self, data: bytes) -> bytes:

View File

@ -8,15 +8,15 @@ from libp2p.crypto.keys import PrivateKey
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_session import SecureSession
from .connection import NoiseConnection
from .exceptions import ( from .exceptions import (
HandshakeHasNotFinished, HandshakeHasNotFinished,
InvalidSignature, InvalidSignature,
NoiseStateError, NoiseStateError,
PeerIDMismatchesPubkey, PeerIDMismatchesPubkey,
) )
from .io import encode_msg_body, decode_msg_body, NoiseHandshakeReadWriter from .io import NoiseHandshakeReadWriter, NoiseTransportReadWriter
from .messages import ( from .messages import (
NoiseHandshakePayload, NoiseHandshakePayload,
make_handshake_payload_sig, make_handshake_payload_sig,
@ -56,16 +56,6 @@ class BasePattern(IPattern):
) )
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature) return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
async def write_msg(self, conn: IRawConnection, data: bytes) -> None:
noise_msg = encode_msg_body(data)
data_encrypted = self.noise_state.write_message(noise_msg)
await self.read_writer.write_msg(data_encrypted)
async def read_msg(self) -> bytes:
noise_msg_encrypted = await self.read_writer.read_msg()
noise_msg = self.noise_state.read_message(noise_msg_encrypted)
return decode_msg_body(noise_msg)
class PatternXX(BasePattern): class PatternXX(BasePattern):
def __init__( def __init__(
@ -116,14 +106,13 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished( raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`" "handshake is done but it is not marked as finished in `noise_state`"
) )
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return NoiseConnection( return SecureSession(
self.local_peer, self.local_peer,
self.libp2p_privkey, self.libp2p_privkey,
remote_peer_id_from_pubkey, remote_peer_id_from_pubkey,
conn, transport_read_writer,
False, False,
noise_state,
) )
async def handshake_outbound( async def handshake_outbound(
@ -171,7 +160,12 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished( raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`" "handshake is done but it is not marked as finished in `noise_state`"
) )
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return NoiseConnection( return SecureSession(
self.local_peer, self.libp2p_privkey, remote_peer, conn, False, noise_state self.local_peer,
self.libp2p_privkey,
remote_peer,
transport_read_writer,
False,
) )

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass from dataclasses import dataclass
import io
import itertools import itertools
from typing import Optional, Tuple from typing import Optional, Tuple
@ -18,13 +17,14 @@ from libp2p.crypto.exceptions import MissingDeserializerError
from libp2p.crypto.key_exchange import create_ephemeral_key_pair from libp2p.crypto.key_exchange import create_ephemeral_key_pair
from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.serialization import deserialize_public_key from libp2p.crypto.serialization import deserialize_public_key
from libp2p.io.abc import EncryptedMsgReadWriter
from libp2p.io.exceptions import DecryptionFailedException, IOException from libp2p.io.exceptions import DecryptionFailedException, IOException
from libp2p.io.msgio import BaseMsgReadWriter from libp2p.io.msgio import BaseMsgReadWriter
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID as PeerID from libp2p.peer.id import ID as PeerID
from libp2p.security.base_session import BaseSession
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_session import SecureSession
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .exceptions import ( from .exceptions import (
@ -54,30 +54,20 @@ class MsgIOReadWriter(BaseMsgReadWriter):
size_len_bytes = SIZE_SECIO_LEN_BYTES size_len_bytes = SIZE_SECIO_LEN_BYTES
class SecureSession(BaseSession): class SecioMsgReadWriter(EncryptedMsgReadWriter):
buf: io.BytesIO read_writer: MsgIOReadWriter
low_watermark: int
high_watermark: int
def __init__( def __init__(
self, self,
local_peer: PeerID,
local_private_key: PrivateKey,
local_encryption_parameters: AuthenticatedEncryptionParameters, local_encryption_parameters: AuthenticatedEncryptionParameters,
remote_peer: PeerID,
remote_encryption_parameters: AuthenticatedEncryptionParameters, remote_encryption_parameters: AuthenticatedEncryptionParameters,
conn: MsgIOReadWriter, read_writer: MsgIOReadWriter,
is_initiator: bool,
) -> None: ) -> None:
super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
self.conn = conn
self.local_encryption_parameters = local_encryption_parameters self.local_encryption_parameters = local_encryption_parameters
self.remote_encryption_parameters = remote_encryption_parameters self.remote_encryption_parameters = remote_encryption_parameters
self._initialize_authenticated_encryption_for_local_peer() self._initialize_authenticated_encryption_for_local_peer()
self._initialize_authenticated_encryption_for_remote_peer() self._initialize_authenticated_encryption_for_remote_peer()
self.read_writer = read_writer
self._reset_internal_buffer()
def _initialize_authenticated_encryption_for_local_peer(self) -> None: def _initialize_authenticated_encryption_for_local_peer(self) -> None:
self.local_encrypter = Encrypter(self.local_encryption_parameters) self.local_encrypter = Encrypter(self.local_encryption_parameters)
@ -85,68 +75,28 @@ class SecureSession(BaseSession):
def _initialize_authenticated_encryption_for_remote_peer(self) -> None: def _initialize_authenticated_encryption_for_remote_peer(self) -> None:
self.remote_encrypter = Encrypter(self.remote_encryption_parameters) self.remote_encrypter = Encrypter(self.remote_encryption_parameters)
async def next_msg_len(self) -> int: def encrypt(self, data: bytes) -> bytes:
return await self.conn.next_msg_len() encrypted_data = self.local_encrypter.encrypt(data)
tag = self.local_encrypter.authenticate(encrypted_data)
return encrypted_data + tag
def _reset_internal_buffer(self) -> None: def decrypt(self, data: bytes) -> bytes:
self.buf = io.BytesIO()
self.low_watermark = 0
self.high_watermark = 0
def _drain(self, n: int) -> bytes:
if self.low_watermark == self.high_watermark:
return bytes()
data = self.buf.getbuffer()[self.low_watermark : self.high_watermark]
if n is None:
n = len(data)
result = data[:n].tobytes()
self.low_watermark += len(result)
if self.low_watermark == self.high_watermark:
del data # free the memoryview so we can free the underlying BytesIO
self.buf.close()
self._reset_internal_buffer()
return result
async def _fill(self) -> None:
msg = await self.read_msg()
self.buf.write(msg)
self.low_watermark = 0
self.high_watermark = len(msg)
async def read(self, n: int = None) -> bytes:
if n == 0:
return bytes()
data_from_buffer = self._drain(n)
if len(data_from_buffer) > 0:
return data_from_buffer
next_length = await self.next_msg_len()
if n < next_length:
await self._fill()
return self._drain(n)
else:
return await self.read_msg()
async def read_msg(self) -> bytes:
msg = await self.conn.read_msg()
try: try:
decrypted_msg = self.remote_encrypter.decrypt_if_valid(msg) decrypted_data = self.remote_encrypter.decrypt_if_valid(data)
except InvalidMACException as e: except InvalidMACException as e:
raise DecryptionFailedException() from e raise DecryptionFailedException() from e
return decrypted_msg return decrypted_data
async def write(self, data: bytes) -> None:
await self.write_msg(data)
async def write_msg(self, msg: bytes) -> None: async def write_msg(self, msg: bytes) -> None:
encrypted_data = self.local_encrypter.encrypt(msg) data_encrypted = self.encrypt(msg)
tag = self.local_encrypter.authenticate(encrypted_data) await self.read_writer.write_msg(data_encrypted)
await self.conn.write_msg(encrypted_data + tag)
async def read_msg(self) -> bytes:
msg_encrypted = await self.read_writer.read_msg()
return self.decrypt(msg_encrypted)
async def close(self) -> None:
await self.read_writer.close()
@dataclass(frozen=True) @dataclass(frozen=True)
@ -387,22 +337,20 @@ def _mk_session_from(
if session_parameters.order < 0: if session_parameters.order < 0:
key_set1, key_set2 = key_set2, key_set1 key_set1, key_set2 = key_set2, key_set1
secio_read_writer = SecioMsgReadWriter(key_set1, key_set2, conn)
session = SecureSession( session = SecureSession(
session_parameters.local_peer, session_parameters.local_peer,
local_private_key, local_private_key,
key_set1,
session_parameters.remote_peer, session_parameters.remote_peer,
key_set2, secio_read_writer,
conn,
is_initiator, is_initiator,
) )
return session return session
async def _finish_handshake(session: SecureSession, remote_nonce: bytes) -> bytes: async def _finish_handshake(session: SecureSession, remote_nonce: bytes) -> bytes:
await session.write_msg(remote_nonce) await session.conn.write_msg(remote_nonce)
return await session.read_msg() return await session.conn.read_msg()
async def create_secure_session( async def create_secure_session(

View File

@ -1,43 +1,29 @@
import io import io
from noise.connection import NoiseConnection as NoiseState
from libp2p.crypto.keys import PrivateKey from libp2p.crypto.keys import PrivateKey
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.io.abc import EncryptedMsgReadWriter
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.base_session import BaseSession from libp2p.security.base_session import BaseSession
from libp2p.security.noise.io import MsgReadWriter, NoiseTransportReadWriter
class NoiseConnection(BaseSession): class SecureSession(BaseSession):
buf: io.BytesIO buf: io.BytesIO
low_watermark: int low_watermark: int
high_watermark: int high_watermark: int
read_writer: IRawConnection
noise_state: NoiseState
def __init__( def __init__(
self, self,
local_peer: ID, local_peer: ID,
local_private_key: PrivateKey, local_private_key: PrivateKey,
remote_peer: ID, remote_peer: ID,
conn: IRawConnection, conn: EncryptedMsgReadWriter,
is_initiator: bool, is_initiator: bool,
noise_state: NoiseState,
# remote_permanent_pubkey
) -> None: ) -> None:
super().__init__(local_peer, local_private_key, is_initiator, remote_peer) super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
self.conn = conn self.conn = conn
self.noise_state = noise_state
self._reset_internal_buffer() self._reset_internal_buffer()
def get_msg_read_writer(self) -> MsgReadWriter:
return NoiseTransportReadWriter(self.conn, self.noise_state)
async def close(self) -> None:
await self.conn.close()
def _reset_internal_buffer(self) -> None: def _reset_internal_buffer(self) -> None:
self.buf = io.BytesIO() self.buf = io.BytesIO()
self.low_watermark = 0 self.low_watermark = 0
@ -60,6 +46,11 @@ class NoiseConnection(BaseSession):
self._reset_internal_buffer() self._reset_internal_buffer()
return result return result
def _fill(self, msg: bytes) -> None:
self.buf.write(msg)
self.low_watermark = 0
self.high_watermark = len(msg)
async def read(self, n: int = None) -> bytes: async def read(self, n: int = None) -> bytes:
if n == 0: if n == 0:
return bytes() return bytes()
@ -68,21 +59,16 @@ class NoiseConnection(BaseSession):
if len(data_from_buffer) > 0: if len(data_from_buffer) > 0:
return data_from_buffer return data_from_buffer
msg = await self.read_msg() msg = await self.conn.read_msg()
if n < len(msg): if n < len(msg):
self.buf.write(msg) self._fill(msg)
self.low_watermark = 0
self.high_watermark = len(msg)
return self._drain(n) return self._drain(n)
else: else:
return msg return msg
async def read_msg(self) -> bytes:
return await self.get_msg_read_writer().read_msg()
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
await self.write_msg(data) await self.conn.write_msg(data)
async def write_msg(self, msg: bytes) -> None: async def close(self) -> None:
await self.get_msg_read_writer().write_msg(msg) await self.conn.close()