diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index a7df0e9..899cc0e 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -1,7 +1,7 @@ from typing import Optional from libp2p.crypto.keys import PrivateKey, PublicKey -from libp2p.network.connection.raw_connection_interface import IRawConnection +from libp2p.io.msgio import MsgIOReadWriter from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn @@ -14,7 +14,7 @@ class BaseSession(ISecureConn): local_peer: ID local_private_key: PrivateKey - conn: IRawConnection + conn: MsgIOReadWriter remote_peer_id: ID remote_permanent_pubkey: PublicKey @@ -22,7 +22,7 @@ class BaseSession(ISecureConn): self, local_peer: ID, local_private_key: PrivateKey, - conn: IRawConnection, + conn: MsgIOReadWriter, peer_id: Optional[ID] = None, ) -> None: self.local_peer = local_peer @@ -33,8 +33,9 @@ class BaseSession(ISecureConn): self.conn = conn self.initiator = peer_id is not None - async def write(self, data: bytes) -> None: + async def write(self, data: bytes) -> int: await self.conn.write(data) + return len(data) async def read(self, n: int = -1) -> bytes: return await self.conn.read(n) diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 8ad2e61..3cb094e 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -1,6 +1,7 @@ from libp2p.crypto.keys import PublicKey from libp2p.crypto.pb import crypto_pb2 from libp2p.crypto.utils import pubkey_from_protobuf +from libp2p.io.msgio import MsgIOReadWriter from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.base_session import BaseSession @@ -76,7 +77,8 @@ class InsecureTransport(BaseSecureTransport): for an inbound connection (i.e. we are not the initiator) :return: secure connection object (that implements secure_conn_interface) """ - session = InsecureSession(self.local_peer, self.local_private_key, conn) + msg_io = MsgIOReadWriter(conn) + session = InsecureSession(self.local_peer, self.local_private_key, msg_io) await session.run_handshake() return session @@ -86,8 +88,9 @@ class InsecureTransport(BaseSecureTransport): for an inbound connection (i.e. we are the initiator) :return: secure connection object (that implements secure_conn_interface) """ + msg_io = MsgIOReadWriter(conn) session = InsecureSession( - self.local_peer, self.local_private_key, conn, peer_id + self.local_peer, self.local_private_key, msg_io, peer_id ) await session.run_handshake() return session diff --git a/libp2p/utils.py b/libp2p/utils.py index a309064..e1c45fd 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -2,6 +2,7 @@ import itertools import math from libp2p.exceptions import ParseError +from libp2p.io.abc import Reader from libp2p.typing import StreamReader # Unsigned LEB128(varint codec) @@ -98,7 +99,7 @@ def encode_fixedint_prefixed(msg_bytes: bytes) -> bytes: return len_prefix + msg_bytes -async def read_fixedint_prefixed(reader: StreamReader) -> bytes: +async def read_fixedint_prefixed(reader: Reader) -> bytes: len_bytes = await reader.read(SIZE_LEN_BYTES) len_int = int.from_bytes(len_bytes, "big") return await reader.read(len_int)