diff --git a/libp2p/security/base_session.py b/libp2p/security/base_session.py index cce1b6c..abc3abb 100644 --- a/libp2p/security/base_session.py +++ b/libp2p/security/base_session.py @@ -11,20 +11,22 @@ class BaseSession(ISecureConn): local_peer: ID local_private_key: PrivateKey - remote_peer_id: ID + remote_peer: ID remote_permanent_pubkey: PublicKey def __init__( self, + *, local_peer: ID, local_private_key: PrivateKey, + remote_peer: ID, + remote_permanent_pubkey: PublicKey, is_initiator: bool, - peer_id: Optional[ID] = None, ) -> None: self.local_peer = local_peer self.local_private_key = local_private_key - self.remote_peer_id = peer_id - self.remote_permanent_pubkey = None + self.remote_peer = remote_peer + self.remote_permanent_pubkey = remote_permanent_pubkey self.is_initiator = is_initiator def get_local_peer(self) -> ID: @@ -34,7 +36,7 @@ class BaseSession(ISecureConn): return self.local_private_key def get_remote_peer(self) -> ID: - return self.remote_peer_id + return self.remote_peer def get_remote_public_key(self) -> Optional[PublicKey]: return self.remote_permanent_pubkey diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 052d342..861ca71 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -1,5 +1,3 @@ -from typing import Optional - from libp2p.crypto.exceptions import MissingDeserializerError from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.pb import crypto_pb2 @@ -32,13 +30,21 @@ class PlaintextHandshakeReadWriter(BaseMsgReadWriter): class InsecureSession(BaseSession): def __init__( self, + *, local_peer: ID, local_private_key: PrivateKey, - conn: ReadWriteCloser, + remote_peer: ID, + remote_permanent_pubkey: PublicKey, is_initiator: bool, - peer_id: Optional[ID] = None, + conn: ReadWriteCloser, ) -> None: - super().__init__(local_peer, local_private_key, is_initiator, peer_id) + super().__init__( + local_peer=local_peer, + local_private_key=local_private_key, + remote_peer=remote_peer, + remote_permanent_pubkey=remote_permanent_pubkey, + is_initiator=is_initiator, + ) self.conn = conn async def write(self, data: bytes) -> None: @@ -102,11 +108,14 @@ async def run_handshake( ) secure_conn = InsecureSession( - local_peer, local_private_key, conn, is_initiator, received_peer_id + local_peer=local_peer, + local_private_key=local_private_key, + remote_peer=received_peer_id, + remote_permanent_pubkey=received_pubkey, + is_initiator=is_initiator, + conn=conn, ) - # Nothing is wrong. Store the `pubkey` and `peer_id` in the session. - secure_conn.remote_permanent_pubkey = received_pubkey # TODO: Store `pubkey` and `peer_id` to `PeerStore` return secure_conn diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index 64d4965..cb5e7bf 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -108,11 +108,12 @@ class PatternXX(BasePattern): ) transport_read_writer = NoiseTransportReadWriter(conn, noise_state) return SecureSession( - self.local_peer, - self.libp2p_privkey, - remote_peer_id_from_pubkey, - transport_read_writer, - False, + local_peer=self.local_peer, + local_private_key=self.libp2p_privkey, + remote_peer=remote_peer_id_from_pubkey, + remote_permanent_pubkey=remote_pubkey, + is_initiator=False, + conn=transport_read_writer, ) async def handshake_outbound( @@ -161,11 +162,11 @@ class PatternXX(BasePattern): "handshake is done but it is not marked as finished in `noise_state`" ) transport_read_writer = NoiseTransportReadWriter(conn, noise_state) - return SecureSession( - self.local_peer, - self.libp2p_privkey, - remote_peer, - transport_read_writer, - False, + local_peer=self.local_peer, + local_private_key=self.libp2p_privkey, + remote_peer=remote_peer_id_from_pubkey, + remote_permanent_pubkey=remote_pubkey, + is_initiator=True, + conn=transport_read_writer, ) diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index ba774d9..4759ccc 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -338,12 +338,16 @@ def _mk_session_from( if session_parameters.order < 0: key_set1, key_set2 = key_set2, key_set1 secio_read_writer = SecioMsgReadWriter(key_set1, key_set2, conn) + remote_permanent_pubkey = ( + session_parameters.remote_encryption_parameters.permanent_public_key + ) session = SecureSession( - session_parameters.local_peer, - local_private_key, - session_parameters.remote_peer, - secio_read_writer, - is_initiator, + local_peer=session_parameters.local_peer, + local_private_key=local_private_key, + remote_peer=session_parameters.remote_peer, + remote_permanent_pubkey=remote_permanent_pubkey, + is_initiator=is_initiator, + conn=secio_read_writer, ) return session diff --git a/libp2p/security/secure_session.py b/libp2p/security/secure_session.py index 9bbc00a..dbabd1a 100644 --- a/libp2p/security/secure_session.py +++ b/libp2p/security/secure_session.py @@ -1,6 +1,6 @@ import io -from libp2p.crypto.keys import PrivateKey +from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.io.abc import EncryptedMsgReadWriter from libp2p.peer.id import ID from libp2p.security.base_session import BaseSession @@ -13,13 +13,21 @@ class SecureSession(BaseSession): def __init__( self, + *, local_peer: ID, local_private_key: PrivateKey, remote_peer: ID, - conn: EncryptedMsgReadWriter, + remote_permanent_pubkey: PublicKey, is_initiator: bool, + conn: EncryptedMsgReadWriter, ) -> None: - super().__init__(local_peer, local_private_key, is_initiator, remote_peer) + super().__init__( + local_peer=local_peer, + local_private_key=local_private_key, + remote_peer=remote_peer, + remote_permanent_pubkey=remote_permanent_pubkey, + is_initiator=is_initiator, + ) self.conn = conn self._reset_internal_buffer()