diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index 899cc0e..ef39f36 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -1,7 +1,6 @@ from typing import Optional from libp2p.crypto.keys import PrivateKey, PublicKey -from libp2p.io.msgio import MsgIOReadWriter from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn @@ -14,7 +13,6 @@ class BaseSession(ISecureConn): local_peer: ID local_private_key: PrivateKey - conn: MsgIOReadWriter remote_peer_id: ID remote_permanent_pubkey: PublicKey @@ -22,27 +20,14 @@ class BaseSession(ISecureConn): self, local_peer: ID, local_private_key: PrivateKey, - conn: MsgIOReadWriter, peer_id: Optional[ID] = None, ) -> None: self.local_peer = local_peer self.local_private_key = local_private_key self.remote_peer_id = peer_id self.remote_permanent_pubkey = None - - self.conn = conn self.initiator = peer_id is not None - async def write(self, data: bytes) -> int: - await self.conn.write(data) - return len(data) - - async def read(self, n: int = -1) -> bytes: - return await self.conn.read(n) - - async def close(self) -> None: - await self.conn.close() - def get_local_peer(self) -> ID: return self.local_peer diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 3cb094e..c7b465d 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -1,7 +1,9 @@ -from libp2p.crypto.keys import PublicKey +from typing import Optional + +from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.pb import crypto_pb2 from libp2p.crypto.utils import pubkey_from_protobuf -from libp2p.io.msgio import MsgIOReadWriter +from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.base_session import BaseSession @@ -20,6 +22,26 @@ PLAINTEXT_PROTOCOL_ID = TProtocol("/plaintext/2.0.0") class InsecureSession(BaseSession): + def __init__( + self, + local_peer: ID, + local_private_key: PrivateKey, + conn: ReadWriteCloser, + peer_id: Optional[ID] = None, + ) -> None: + super().__init__(local_peer, local_private_key, peer_id) + self.conn = conn + + async def write(self, data: bytes) -> int: + await self.conn.write(data) + return len(data) + + async def read(self, n: int = -1) -> bytes: + return await self.conn.read(n) + + async def close(self) -> None: + await self.conn.close() + async def run_handshake(self) -> None: msg = make_exchange_message(self.local_private_key.get_public_key()) msg_bytes = msg.SerializeToString() @@ -77,8 +99,7 @@ class InsecureTransport(BaseSecureTransport): for an inbound connection (i.e. we are not the initiator) :return: secure connection object (that implements secure_conn_interface) """ - msg_io = MsgIOReadWriter(conn) - session = InsecureSession(self.local_peer, self.local_private_key, msg_io) + session = InsecureSession(self.local_peer, self.local_private_key, conn) await session.run_handshake() return session @@ -88,9 +109,8 @@ class InsecureTransport(BaseSecureTransport): for an inbound connection (i.e. we are the initiator) :return: secure connection object (that implements secure_conn_interface) """ - msg_io = MsgIOReadWriter(conn) session = InsecureSession( - self.local_peer, self.local_private_key, msg_io, peer_id + self.local_peer, self.local_private_key, conn, peer_id ) await session.run_handshake() return session diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index 10a5763..2291d6a 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -58,7 +58,8 @@ class SecureSession(BaseSession): remote_encryption_parameters: AuthenticatedEncryptionParameters, conn: MsgIOReadWriter, ) -> None: - super().__init__(local_peer, local_private_key, conn, remote_peer) + super().__init__(local_peer, local_private_key, remote_peer) + self.conn = conn self.local_encryption_parameters = local_encryption_parameters self.remote_encryption_parameters = remote_encryption_parameters