Restore initiator
flag to BaseSession
type
This commit is contained in:
parent
f38899e26e
commit
2025a5c7f1
|
@ -20,13 +20,14 @@ class BaseSession(ISecureConn):
|
||||||
self,
|
self,
|
||||||
local_peer: ID,
|
local_peer: ID,
|
||||||
local_private_key: PrivateKey,
|
local_private_key: PrivateKey,
|
||||||
|
initiator: bool,
|
||||||
peer_id: Optional[ID] = None,
|
peer_id: Optional[ID] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.local_peer = local_peer
|
self.local_peer = local_peer
|
||||||
self.local_private_key = local_private_key
|
self.local_private_key = local_private_key
|
||||||
self.remote_peer_id = peer_id
|
self.remote_peer_id = peer_id
|
||||||
self.remote_permanent_pubkey = None
|
self.remote_permanent_pubkey = None
|
||||||
self.initiator = peer_id is not None
|
self.initiator = initiator
|
||||||
|
|
||||||
def get_local_peer(self) -> ID:
|
def get_local_peer(self) -> ID:
|
||||||
return self.local_peer
|
return self.local_peer
|
||||||
|
|
|
@ -27,9 +27,10 @@ class InsecureSession(BaseSession):
|
||||||
local_peer: ID,
|
local_peer: ID,
|
||||||
local_private_key: PrivateKey,
|
local_private_key: PrivateKey,
|
||||||
conn: ReadWriteCloser,
|
conn: ReadWriteCloser,
|
||||||
|
initiator: bool,
|
||||||
peer_id: Optional[ID] = None,
|
peer_id: Optional[ID] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(local_peer, local_private_key, peer_id)
|
super().__init__(local_peer, local_private_key, initiator, peer_id)
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
|
|
||||||
async def write(self, data: bytes) -> int:
|
async def write(self, data: bytes) -> int:
|
||||||
|
@ -99,7 +100,7 @@ class InsecureTransport(BaseSecureTransport):
|
||||||
for an inbound connection (i.e. we are not the initiator)
|
for an inbound connection (i.e. we are not the initiator)
|
||||||
:return: secure connection object (that implements secure_conn_interface)
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
"""
|
"""
|
||||||
session = InsecureSession(self.local_peer, self.local_private_key, conn)
|
session = InsecureSession(self.local_peer, self.local_private_key, conn, False)
|
||||||
await session.run_handshake()
|
await session.run_handshake()
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
@ -110,7 +111,7 @@ class InsecureTransport(BaseSecureTransport):
|
||||||
:return: secure connection object (that implements secure_conn_interface)
|
:return: secure connection object (that implements secure_conn_interface)
|
||||||
"""
|
"""
|
||||||
session = InsecureSession(
|
session = InsecureSession(
|
||||||
self.local_peer, self.local_private_key, conn, peer_id
|
self.local_peer, self.local_private_key, conn, True, peer_id
|
||||||
)
|
)
|
||||||
await session.run_handshake()
|
await session.run_handshake()
|
||||||
return session
|
return session
|
||||||
|
|
|
@ -57,8 +57,9 @@ class SecureSession(BaseSession):
|
||||||
remote_peer: PeerID,
|
remote_peer: PeerID,
|
||||||
remote_encryption_parameters: AuthenticatedEncryptionParameters,
|
remote_encryption_parameters: AuthenticatedEncryptionParameters,
|
||||||
conn: MsgIOReadWriter,
|
conn: MsgIOReadWriter,
|
||||||
|
initiator: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(local_peer, local_private_key, remote_peer)
|
super().__init__(local_peer, local_private_key, initiator, remote_peer)
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
|
|
||||||
self.local_encryption_parameters = local_encryption_parameters
|
self.local_encryption_parameters = local_encryption_parameters
|
||||||
|
@ -359,6 +360,7 @@ def _mk_session_from(
|
||||||
local_private_key: PrivateKey,
|
local_private_key: PrivateKey,
|
||||||
session_parameters: SessionParameters,
|
session_parameters: SessionParameters,
|
||||||
conn: MsgIOReadWriter,
|
conn: MsgIOReadWriter,
|
||||||
|
initiator: bool,
|
||||||
) -> SecureSession:
|
) -> SecureSession:
|
||||||
key_set1, key_set2 = initialize_pair_for_encryption(
|
key_set1, key_set2 = initialize_pair_for_encryption(
|
||||||
session_parameters.local_encryption_parameters.cipher_type,
|
session_parameters.local_encryption_parameters.cipher_type,
|
||||||
|
@ -376,6 +378,7 @@ def _mk_session_from(
|
||||||
session_parameters.remote_peer,
|
session_parameters.remote_peer,
|
||||||
key_set2,
|
key_set2,
|
||||||
conn,
|
conn,
|
||||||
|
initiator,
|
||||||
)
|
)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
@ -406,7 +409,8 @@ async def create_secure_session(
|
||||||
await conn.close()
|
await conn.close()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
session = _mk_session_from(local_private_key, session_parameters, msg_io)
|
initiator = remote_peer is None
|
||||||
|
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
|
||||||
|
|
||||||
received_nonce = await _finish_handshake(session, remote_nonce)
|
received_nonce = await _finish_handshake(session, remote_nonce)
|
||||||
if received_nonce != local_nonce:
|
if received_nonce != local_nonce:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user