Merge pull request #332 from dmuhs/refactor/bool-params
Refactor initiator -> is_initiator and other flags/functions
This commit is contained in:
commit
eaa800c356
|
@ -7,7 +7,7 @@ from .raw_connection_interface import IRawConnection
|
|||
class RawConnection(IRawConnection):
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
initiator: bool
|
||||
is_initiator: bool
|
||||
|
||||
_drain_lock: asyncio.Lock
|
||||
|
||||
|
@ -19,7 +19,7 @@ class RawConnection(IRawConnection):
|
|||
) -> None:
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.initiator = initiator
|
||||
self.is_initiator = initiator
|
||||
|
||||
self._drain_lock = asyncio.Lock()
|
||||
|
||||
|
|
|
@ -6,4 +6,4 @@ class IRawConnection(ReadWriteCloser):
|
|||
A Raw Connection provides a Reader and a Writer
|
||||
"""
|
||||
|
||||
initiator: bool
|
||||
is_initiator: bool
|
||||
|
|
|
@ -84,14 +84,14 @@ class Multiselect(IMultiselectMuxer):
|
|||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError(error)
|
||||
|
||||
if not validate_handshake(handshake_contents):
|
||||
if not is_valid_handshake(handshake_contents):
|
||||
raise MultiselectError(
|
||||
"multiselect protocol ID mismatch: "
|
||||
f"received handshake_contents={handshake_contents}"
|
||||
)
|
||||
|
||||
|
||||
def validate_handshake(handshake_contents: str) -> bool:
|
||||
def is_valid_handshake(handshake_contents: str) -> bool:
|
||||
"""
|
||||
Determine if handshake is valid and should be confirmed
|
||||
:param handshake_contents: contents of handshake message
|
||||
|
|
|
@ -33,7 +33,7 @@ class MultiselectClient(IMultiselectClient):
|
|||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError(str(error))
|
||||
|
||||
if not validate_handshake(handshake_contents):
|
||||
if not is_valid_handshake(handshake_contents):
|
||||
raise MultiselectClientError("multiselect protocol ID mismatch")
|
||||
|
||||
async def select_one_of(
|
||||
|
@ -86,7 +86,7 @@ class MultiselectClient(IMultiselectClient):
|
|||
raise MultiselectClientError("unrecognized response: " + response)
|
||||
|
||||
|
||||
def validate_handshake(handshake_contents: str) -> bool:
|
||||
def is_valid_handshake(handshake_contents: str) -> bool:
|
||||
"""
|
||||
Determine if handshake is valid and should be confirmed
|
||||
:param handshake_contents: contents of handshake message
|
||||
|
|
|
@ -9,7 +9,7 @@ message RPC {
|
|||
repeated Message publish = 2;
|
||||
|
||||
message SubOpts {
|
||||
optional bool subscribe = 1; // subscribe or unsubcribe
|
||||
optional bool subscribe = 1; // subscribe or unsubscribe
|
||||
optional string topicid = 2;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,14 +20,14 @@ class BaseSession(ISecureConn):
|
|||
self,
|
||||
local_peer: ID,
|
||||
local_private_key: PrivateKey,
|
||||
initiator: bool,
|
||||
is_initiator: bool,
|
||||
peer_id: Optional[ID] = None,
|
||||
) -> None:
|
||||
self.local_peer = local_peer
|
||||
self.local_private_key = local_private_key
|
||||
self.remote_peer_id = peer_id
|
||||
self.remote_permanent_pubkey = None
|
||||
self.initiator = initiator
|
||||
self.is_initiator = is_initiator
|
||||
|
||||
def get_local_peer(self) -> ID:
|
||||
return self.local_peer
|
||||
|
|
|
@ -29,10 +29,10 @@ class InsecureSession(BaseSession):
|
|||
local_peer: ID,
|
||||
local_private_key: PrivateKey,
|
||||
conn: ReadWriteCloser,
|
||||
initiator: bool,
|
||||
is_initiator: bool,
|
||||
peer_id: Optional[ID] = None,
|
||||
) -> None:
|
||||
super().__init__(local_peer, local_private_key, initiator, peer_id)
|
||||
super().__init__(local_peer, local_private_key, is_initiator, peer_id)
|
||||
self.conn = conn
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
|
@ -68,7 +68,7 @@ class InsecureSession(BaseSession):
|
|||
# Verify if the receive `ID` matches the one we originally initialize the session.
|
||||
# We only need to check it when we are the initiator, because only in that condition
|
||||
# we possibly knows the `ID` of the remote.
|
||||
if self.initiator and self.remote_peer_id != received_peer_id:
|
||||
if self.is_initiator and self.remote_peer_id != received_peer_id:
|
||||
raise HandshakeFailure(
|
||||
"remote peer sent unexpected peer ID. "
|
||||
f"expected={self.remote_peer_id} received={received_peer_id}"
|
||||
|
@ -97,7 +97,7 @@ class InsecureSession(BaseSession):
|
|||
self.remote_permanent_pubkey = received_pubkey
|
||||
# Only need to set peer's id when we don't know it before,
|
||||
# i.e. we are not the connection initiator.
|
||||
if not self.initiator:
|
||||
if not self.is_initiator:
|
||||
self.remote_peer_id = received_peer_id
|
||||
|
||||
# TODO: Store `pubkey` and `peer_id` to `PeerStore`
|
||||
|
|
|
@ -61,9 +61,9 @@ class SecureSession(BaseSession):
|
|||
remote_peer: PeerID,
|
||||
remote_encryption_parameters: AuthenticatedEncryptionParameters,
|
||||
conn: MsgIOReadWriter,
|
||||
initiator: bool,
|
||||
is_initiator: bool,
|
||||
) -> None:
|
||||
super().__init__(local_peer, local_private_key, initiator, remote_peer)
|
||||
super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
|
||||
self.conn = conn
|
||||
|
||||
self.local_encryption_parameters = local_encryption_parameters
|
||||
|
@ -371,7 +371,7 @@ def _mk_session_from(
|
|||
local_private_key: PrivateKey,
|
||||
session_parameters: SessionParameters,
|
||||
conn: MsgIOReadWriter,
|
||||
initiator: bool,
|
||||
is_initiator: bool,
|
||||
) -> SecureSession:
|
||||
key_set1, key_set2 = initialize_pair_for_encryption(
|
||||
session_parameters.local_encryption_parameters.cipher_type,
|
||||
|
@ -389,7 +389,7 @@ def _mk_session_from(
|
|||
session_parameters.remote_peer,
|
||||
key_set2,
|
||||
conn,
|
||||
initiator,
|
||||
is_initiator,
|
||||
)
|
||||
return session
|
||||
|
||||
|
@ -425,8 +425,10 @@ async def create_secure_session(
|
|||
except IOException:
|
||||
raise SecioException("connection closed")
|
||||
|
||||
initiator = remote_peer is not None
|
||||
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
|
||||
is_initiator = remote_peer is not None
|
||||
session = _mk_session_from(
|
||||
local_private_key, session_parameters, msg_io, is_initiator
|
||||
)
|
||||
|
||||
try:
|
||||
received_nonce = await _finish_handshake(session, remote_nonce)
|
||||
|
|
|
@ -76,18 +76,18 @@ class SecurityMultistream(ABC):
|
|||
return secure_conn
|
||||
|
||||
async def select_transport(
|
||||
self, conn: IRawConnection, initiator: bool
|
||||
self, conn: IRawConnection, is_initiator: bool
|
||||
) -> ISecureTransport:
|
||||
"""
|
||||
Select a transport that both us and the node on the
|
||||
other end of conn support and agree on
|
||||
:param conn: conn to choose a transport over
|
||||
:param initiator: true if we are the initiator, false otherwise
|
||||
:param is_initiator: true if we are the initiator, false otherwise
|
||||
:return: selected secure transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if initiator:
|
||||
if is_initiator:
|
||||
# Select protocol if initiator
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
list(self.transports.keys()), communicator
|
||||
|
|
|
@ -23,7 +23,7 @@ class IMuxedConn(ABC):
|
|||
|
||||
@property
|
||||
@abstractmethod
|
||||
def initiator(self) -> bool:
|
||||
def is_initiator(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -68,8 +68,8 @@ class Mplex(IMuxedConn):
|
|||
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
|
||||
|
||||
@property
|
||||
def initiator(self) -> bool:
|
||||
return self.secured_conn.initiator
|
||||
def is_initiator(self) -> bool:
|
||||
return self.secured_conn.is_initiator
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
|
|
|
@ -56,7 +56,7 @@ class MuxerMultistream:
|
|||
"""
|
||||
protocol: TProtocol
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.initiator:
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
tuple(self.transports.keys()), communicator
|
||||
)
|
||||
|
|
|
@ -33,13 +33,13 @@ class TransportUpgrader:
|
|||
pass
|
||||
|
||||
async def upgrade_security(
|
||||
self, raw_conn: IRawConnection, peer_id: ID, initiator: bool
|
||||
self, raw_conn: IRawConnection, peer_id: ID, is_initiator: bool
|
||||
) -> ISecureConn:
|
||||
"""
|
||||
Upgrade conn to a secured connection
|
||||
"""
|
||||
try:
|
||||
if initiator:
|
||||
if is_initiator:
|
||||
return await self.security_multistream.secure_outbound(
|
||||
raw_conn, peer_id
|
||||
)
|
||||
|
|
|
@ -9,11 +9,11 @@ from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session
|
|||
|
||||
|
||||
class InMemoryConnection(IRawConnection):
|
||||
def __init__(self, peer, initiator=False):
|
||||
def __init__(self, peer, is_initiator=False):
|
||||
self.peer = peer
|
||||
self.recv_queue = asyncio.Queue()
|
||||
self.send_queue = asyncio.Queue()
|
||||
self.initiator = initiator
|
||||
self.is_initiator = is_initiator
|
||||
|
||||
self.current_msg = None
|
||||
self.current_position = 0
|
||||
|
@ -73,7 +73,7 @@ async def test_create_secure_session():
|
|||
remote_key_pair = create_new_key_pair(b"b")
|
||||
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
|
||||
|
||||
local_conn = InMemoryConnection(local_peer, initiator=True)
|
||||
local_conn = InMemoryConnection(local_peer, is_initiator=True)
|
||||
remote_conn = InMemoryConnection(remote_peer)
|
||||
|
||||
local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn))
|
||||
|
|
|
@ -10,8 +10,8 @@ async def mplex_conn_pair(is_host_secure):
|
|||
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
|
||||
is_host_secure
|
||||
)
|
||||
assert mplex_conn_0.initiator
|
||||
assert not mplex_conn_1.initiator
|
||||
assert mplex_conn_0.is_initiator
|
||||
assert not mplex_conn_1.is_initiator
|
||||
try:
|
||||
yield mplex_conn_0, mplex_conn_1
|
||||
finally:
|
||||
|
|
Loading…
Reference in New Issue
Block a user