diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index ef39f36..e91bc30 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -20,13 +20,14 @@ class BaseSession(ISecureConn): self, local_peer: ID, local_private_key: PrivateKey, + initiator: bool, peer_id: Optional[ID] = None, ) -> None: self.local_peer = local_peer self.local_private_key = local_private_key self.remote_peer_id = peer_id self.remote_permanent_pubkey = None - self.initiator = peer_id is not None + self.initiator = initiator def get_local_peer(self) -> ID: return self.local_peer diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index c7b465d..27efc86 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -27,9 +27,10 @@ class InsecureSession(BaseSession): local_peer: ID, local_private_key: PrivateKey, conn: ReadWriteCloser, + initiator: bool, peer_id: Optional[ID] = None, ) -> None: - super().__init__(local_peer, local_private_key, peer_id) + super().__init__(local_peer, local_private_key, initiator, peer_id) self.conn = conn async def write(self, data: bytes) -> int: @@ -99,7 +100,7 @@ class InsecureTransport(BaseSecureTransport): for an inbound connection (i.e. we are not the initiator) :return: secure connection object (that implements secure_conn_interface) """ - session = InsecureSession(self.local_peer, self.local_private_key, conn) + session = InsecureSession(self.local_peer, self.local_private_key, conn, False) await session.run_handshake() return session @@ -110,7 +111,7 @@ class InsecureTransport(BaseSecureTransport): :return: secure connection object (that implements secure_conn_interface) """ session = InsecureSession( - self.local_peer, self.local_private_key, conn, peer_id + self.local_peer, self.local_private_key, conn, True, peer_id ) await session.run_handshake() return session diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index 2291d6a..3dc6534 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -57,8 +57,9 @@ class SecureSession(BaseSession): remote_peer: PeerID, remote_encryption_parameters: AuthenticatedEncryptionParameters, conn: MsgIOReadWriter, + initiator: bool, ) -> None: - super().__init__(local_peer, local_private_key, remote_peer) + super().__init__(local_peer, local_private_key, initiator, remote_peer) self.conn = conn self.local_encryption_parameters = local_encryption_parameters @@ -359,6 +360,7 @@ def _mk_session_from( local_private_key: PrivateKey, session_parameters: SessionParameters, conn: MsgIOReadWriter, + initiator: bool, ) -> SecureSession: key_set1, key_set2 = initialize_pair_for_encryption( session_parameters.local_encryption_parameters.cipher_type, @@ -376,6 +378,7 @@ def _mk_session_from( session_parameters.remote_peer, key_set2, conn, + initiator, ) return session @@ -406,7 +409,8 @@ async def create_secure_session( await conn.close() raise e - session = _mk_session_from(local_private_key, session_parameters, msg_io) + initiator = remote_peer is None + session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator) received_nonce = await _finish_handshake(session, remote_nonce) if received_nonce != local_nonce: