diff --git a/libp2p/crypto/rsa.py b/libp2p/crypto/rsa.py index ed8c215..9b788ca 100644 --- a/libp2p/crypto/rsa.py +++ b/libp2p/crypto/rsa.py @@ -1,5 +1,7 @@ +from Crypto.Hash import SHA256 import Crypto.PublicKey.RSA as RSA from Crypto.PublicKey.RSA import RsaKey +from Crypto.Signature import pkcs1_15 from libp2p.crypto.keys import KeyPair, KeyType, PrivateKey, PublicKey @@ -20,7 +22,13 @@ class RSAPublicKey(PublicKey): return KeyType.RSA def verify(self, data: bytes, signature: bytes) -> bool: - raise NotImplementedError + h = SHA256.new(data) + try: + # NOTE: the typing in ``pycryptodome`` is wrong on the arguments to ``verify``. + pkcs1_15.new(self.impl).verify(h, signature) # type: ignore + except (ValueError, TypeError): + return False + return True class RSAPrivateKey(PrivateKey): @@ -39,7 +47,9 @@ class RSAPrivateKey(PrivateKey): return KeyType.RSA def sign(self, data: bytes) -> bytes: - raise NotImplementedError + h = SHA256.new(data) + # NOTE: the typing in ``pycryptodome`` is wrong on the arguments to ``sign``. + return pkcs1_15.new(self.impl).sign(h) # type: ignore def get_public_key(self) -> PublicKey: return RSAPublicKey(self.impl.publickey()) diff --git a/libp2p/crypto/serialization.py b/libp2p/crypto/serialization.py index 5b6b276..dedcf85 100644 --- a/libp2p/crypto/serialization.py +++ b/libp2p/crypto/serialization.py @@ -1,8 +1,10 @@ from libp2p.crypto.keys import KeyType, PrivateKey, PublicKey +from libp2p.crypto.rsa import RSAPublicKey from libp2p.crypto.secp256k1 import Secp256k1PrivateKey, Secp256k1PublicKey key_type_to_public_key_deserializer = { - KeyType.Secp256k1.value: Secp256k1PublicKey.from_bytes + KeyType.Secp256k1.value: Secp256k1PublicKey.from_bytes, + KeyType.RSA.value: RSAPublicKey.from_bytes, } key_type_to_private_key_deserializer = { diff --git a/libp2p/io/msgio.py b/libp2p/io/msgio.py index 03004b0..f60b0ff 100644 --- a/libp2p/io/msgio.py +++ b/libp2p/io/msgio.py @@ -25,16 +25,6 @@ def encode_msg_with_length(msg_bytes: bytes) -> bytes: return len_prefix + msg_bytes -# NOTE: temporary for this PR -encode = encode_msg_with_length - - -# NOTE: temporary for this PR -async def read_next_message(reader: Reader) -> bytes: - length = await read_length(reader) - return await reader.read(length) - - class MsgIOWriter(WriteCloser): write_closer: WriteCloser diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index a7df0e9..ef39f36 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -1,7 +1,6 @@ from typing import Optional from libp2p.crypto.keys import PrivateKey, PublicKey -from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn @@ -14,7 +13,6 @@ class BaseSession(ISecureConn): local_peer: ID local_private_key: PrivateKey - conn: IRawConnection remote_peer_id: ID remote_permanent_pubkey: PublicKey @@ -22,26 +20,14 @@ class BaseSession(ISecureConn): self, local_peer: ID, local_private_key: PrivateKey, - conn: IRawConnection, peer_id: Optional[ID] = None, ) -> None: self.local_peer = local_peer self.local_private_key = local_private_key self.remote_peer_id = peer_id self.remote_permanent_pubkey = None - - self.conn = conn self.initiator = peer_id is not None - async def write(self, data: bytes) -> None: - await self.conn.write(data) - - async def read(self, n: int = -1) -> bytes: - return await self.conn.read(n) - - async def close(self) -> None: - await self.conn.close() - def get_local_peer(self) -> ID: return self.local_peer diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 8ad2e61..c7b465d 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -1,6 +1,9 @@ -from libp2p.crypto.keys import PublicKey +from typing import Optional + +from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.pb import crypto_pb2 from libp2p.crypto.utils import pubkey_from_protobuf +from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.base_session import BaseSession @@ -19,6 +22,26 @@ PLAINTEXT_PROTOCOL_ID = TProtocol("/plaintext/2.0.0") class InsecureSession(BaseSession): + def __init__( + self, + local_peer: ID, + local_private_key: PrivateKey, + conn: ReadWriteCloser, + peer_id: Optional[ID] = None, + ) -> None: + super().__init__(local_peer, local_private_key, peer_id) + self.conn = conn + + async def write(self, data: bytes) -> int: + await self.conn.write(data) + return len(data) + + async def read(self, n: int = -1) -> bytes: + return await self.conn.read(n) + + async def close(self) -> None: + await self.conn.close() + async def run_handshake(self) -> None: msg = make_exchange_message(self.local_private_key.get_public_key()) msg_bytes = msg.SerializeToString() diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index e2259e8..2291d6a 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import io import itertools from typing import Optional, Tuple @@ -15,8 +16,7 @@ from libp2p.crypto.ecc import ECCPublicKey from libp2p.crypto.key_exchange import create_ephemeral_key_pair from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.serialization import deserialize_public_key -from libp2p.io.msgio import encode as encode_message -from libp2p.io.msgio import read_next_message +from libp2p.io.msgio import MsgIOReadWriter from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID as PeerID from libp2p.security.base_session import BaseSession @@ -45,6 +45,10 @@ DEFAULT_SUPPORTED_HASHES = "SHA256" class SecureSession(BaseSession): + buf: io.BytesIO + low_watermark: int + high_watermark: int + def __init__( self, local_peer: PeerID, @@ -52,38 +56,80 @@ class SecureSession(BaseSession): local_encryption_parameters: AuthenticatedEncryptionParameters, remote_peer: PeerID, remote_encryption_parameters: AuthenticatedEncryptionParameters, - conn: IRawConnection, + conn: MsgIOReadWriter, ) -> None: - super().__init__(local_peer, local_private_key, conn, remote_peer) + super().__init__(local_peer, local_private_key, remote_peer) + self.conn = conn self.local_encryption_parameters = local_encryption_parameters self.remote_encryption_parameters = remote_encryption_parameters self._initialize_authenticated_encryption_for_local_peer() self._initialize_authenticated_encryption_for_remote_peer() + self._reset_internal_buffer() + def _initialize_authenticated_encryption_for_local_peer(self) -> None: self.local_encrypter = Encrypter(self.local_encryption_parameters) def _initialize_authenticated_encryption_for_remote_peer(self) -> None: self.remote_encrypter = Encrypter(self.remote_encryption_parameters) - async def read(self, n: int = -1) -> bytes: - return await self._read_msg() + async def next_msg_len(self) -> int: + return await self.conn.next_msg_len() - async def _read_msg(self) -> bytes: - # TODO do we need to serialize reads? - msg = await read_next_message(self.conn) + def _reset_internal_buffer(self) -> None: + self.buf = io.BytesIO() + self.low_watermark = 0 + self.high_watermark = 0 + + def _drain(self, n: int) -> bytes: + if self.low_watermark == self.high_watermark: + return bytes() + + data = self.buf.getbuffer()[self.low_watermark : self.high_watermark] + + if n < 0: + n = len(data) + result = data[:n].tobytes() + self.low_watermark += len(result) + + if self.low_watermark == self.high_watermark: + del data # free the memoryview so we can free the underlying BytesIO + self.buf.close() + self._reset_internal_buffer() + return result + + async def _fill(self) -> None: + msg = await self.read_msg() + self.buf.write(msg) + self.low_watermark = 0 + self.high_watermark = len(msg) + + async def read(self, n: int = -1) -> bytes: + data_from_buffer = self._drain(n) + if len(data_from_buffer) > 0: + return data_from_buffer + + next_length = await self.next_msg_len() + + if n < next_length: + await self._fill() + return self._drain(n) + else: + return await self.read_msg() + + async def read_msg(self) -> bytes: + msg = await self.conn.read_msg() return self.remote_encrypter.decrypt_if_valid(msg) - async def write(self, data: bytes) -> None: - await self._write_msg(data) + async def write(self, data: bytes) -> int: + await self.write_msg(data) + return len(data) - async def _write_msg(self, data: bytes) -> None: - # TODO do we need to serialize writes? - encrypted_data = self.local_encrypter.encrypt(data) + async def write_msg(self, msg: bytes) -> None: + encrypted_data = self.local_encrypter.encrypt(msg) tag = self.local_encrypter.authenticate(encrypted_data) - msg = encode_message(encrypted_data + tag) - await self.conn.write(msg) + await self.conn.write_msg(encrypted_data + tag) @dataclass(frozen=True) @@ -156,9 +202,9 @@ class SessionParameters: pass -async def _response_to_msg(conn: IRawConnection, msg: bytes) -> bytes: - await conn.write(encode_message(msg)) - return await read_next_message(conn) +async def _response_to_msg(read_writer: MsgIOReadWriter, msg: bytes) -> bytes: + await read_writer.write_msg(msg) + return await read_writer.read_msg() def _mk_multihash_sha256(data: bytes) -> bytes: @@ -220,7 +266,7 @@ async def _establish_session_parameters( local_peer: PeerID, local_private_key: PrivateKey, remote_peer: Optional[PeerID], - conn: IRawConnection, + conn: MsgIOReadWriter, nonce: bytes, ) -> Tuple[SessionParameters, bytes]: # establish shared encryption parameters @@ -312,7 +358,7 @@ async def _establish_session_parameters( def _mk_session_from( local_private_key: PrivateKey, session_parameters: SessionParameters, - conn: IRawConnection, + conn: MsgIOReadWriter, ) -> SecureSession: key_set1, key_set2 = initialize_pair_for_encryption( session_parameters.local_encryption_parameters.cipher_type, @@ -334,9 +380,9 @@ def _mk_session_from( return session -async def _finish_handshake(session: ISecureConn, remote_nonce: bytes) -> bytes: - await session.write(remote_nonce) - return await session.read() +async def _finish_handshake(session: SecureSession, remote_nonce: bytes) -> bytes: + await session.write_msg(remote_nonce) + return await session.read_msg() async def create_secure_session( @@ -351,15 +397,16 @@ async def create_secure_session( If successful, return an object that provides secure communication to the ``remote_peer``. """ + msg_io = MsgIOReadWriter(conn) try: session_parameters, remote_nonce = await _establish_session_parameters( - local_peer, local_private_key, remote_peer, conn, local_nonce + local_peer, local_private_key, remote_peer, msg_io, local_nonce ) except SecioException as e: await conn.close() raise e - session = _mk_session_from(local_private_key, session_parameters, conn) + session = _mk_session_from(local_private_key, session_parameters, msg_io) received_nonce = await _finish_handshake(session, remote_nonce) if received_nonce != local_nonce: diff --git a/libp2p/utils.py b/libp2p/utils.py index a309064..e1c45fd 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -2,6 +2,7 @@ import itertools import math from libp2p.exceptions import ParseError +from libp2p.io.abc import Reader from libp2p.typing import StreamReader # Unsigned LEB128(varint codec) @@ -98,7 +99,7 @@ def encode_fixedint_prefixed(msg_bytes: bytes) -> bytes: return len_prefix + msg_bytes -async def read_fixedint_prefixed(reader: StreamReader) -> bytes: +async def read_fixedint_prefixed(reader: Reader) -> bytes: len_bytes = await reader.read(SIZE_LEN_BYTES) len_int = int.from_bytes(len_bytes, "big") return await reader.read(len_int)