Restore initiator flag to BaseSession type

This commit is contained in:
Alex Stokes 2019-09-08 15:37:41 -04:00
parent f38899e26e
commit 2025a5c7f1
No known key found for this signature in database
GPG Key ID: 51CE1721B245C086
3 changed files with 12 additions and 6 deletions

View File

@ -20,13 +20,14 @@ class BaseSession(ISecureConn):
self, self,
local_peer: ID, local_peer: ID,
local_private_key: PrivateKey, local_private_key: PrivateKey,
initiator: bool,
peer_id: Optional[ID] = None, peer_id: Optional[ID] = None,
) -> None: ) -> None:
self.local_peer = local_peer self.local_peer = local_peer
self.local_private_key = local_private_key self.local_private_key = local_private_key
self.remote_peer_id = peer_id self.remote_peer_id = peer_id
self.remote_permanent_pubkey = None self.remote_permanent_pubkey = None
self.initiator = peer_id is not None self.initiator = initiator
def get_local_peer(self) -> ID: def get_local_peer(self) -> ID:
return self.local_peer return self.local_peer

View File

@ -27,9 +27,10 @@ class InsecureSession(BaseSession):
local_peer: ID, local_peer: ID,
local_private_key: PrivateKey, local_private_key: PrivateKey,
conn: ReadWriteCloser, conn: ReadWriteCloser,
initiator: bool,
peer_id: Optional[ID] = None, peer_id: Optional[ID] = None,
) -> None: ) -> None:
super().__init__(local_peer, local_private_key, peer_id) super().__init__(local_peer, local_private_key, initiator, peer_id)
self.conn = conn self.conn = conn
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
@ -99,7 +100,7 @@ class InsecureTransport(BaseSecureTransport):
for an inbound connection (i.e. we are not the initiator) for an inbound connection (i.e. we are not the initiator)
:return: secure connection object (that implements secure_conn_interface) :return: secure connection object (that implements secure_conn_interface)
""" """
session = InsecureSession(self.local_peer, self.local_private_key, conn) session = InsecureSession(self.local_peer, self.local_private_key, conn, False)
await session.run_handshake() await session.run_handshake()
return session return session
@ -110,7 +111,7 @@ class InsecureTransport(BaseSecureTransport):
:return: secure connection object (that implements secure_conn_interface) :return: secure connection object (that implements secure_conn_interface)
""" """
session = InsecureSession( session = InsecureSession(
self.local_peer, self.local_private_key, conn, peer_id self.local_peer, self.local_private_key, conn, True, peer_id
) )
await session.run_handshake() await session.run_handshake()
return session return session

View File

@ -57,8 +57,9 @@ class SecureSession(BaseSession):
remote_peer: PeerID, remote_peer: PeerID,
remote_encryption_parameters: AuthenticatedEncryptionParameters, remote_encryption_parameters: AuthenticatedEncryptionParameters,
conn: MsgIOReadWriter, conn: MsgIOReadWriter,
initiator: bool,
) -> None: ) -> None:
super().__init__(local_peer, local_private_key, remote_peer) super().__init__(local_peer, local_private_key, initiator, remote_peer)
self.conn = conn self.conn = conn
self.local_encryption_parameters = local_encryption_parameters self.local_encryption_parameters = local_encryption_parameters
@ -359,6 +360,7 @@ def _mk_session_from(
local_private_key: PrivateKey, local_private_key: PrivateKey,
session_parameters: SessionParameters, session_parameters: SessionParameters,
conn: MsgIOReadWriter, conn: MsgIOReadWriter,
initiator: bool,
) -> SecureSession: ) -> SecureSession:
key_set1, key_set2 = initialize_pair_for_encryption( key_set1, key_set2 = initialize_pair_for_encryption(
session_parameters.local_encryption_parameters.cipher_type, session_parameters.local_encryption_parameters.cipher_type,
@ -376,6 +378,7 @@ def _mk_session_from(
session_parameters.remote_peer, session_parameters.remote_peer,
key_set2, key_set2,
conn, conn,
initiator,
) )
return session return session
@ -406,7 +409,8 @@ async def create_secure_session(
await conn.close() await conn.close()
raise e raise e
session = _mk_session_from(local_private_key, session_parameters, msg_io) initiator = remote_peer is None
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
received_nonce = await _finish_handshake(session, remote_nonce) received_nonce = await _finish_handshake(session, remote_nonce)
if received_nonce != local_nonce: if received_nonce != local_nonce: