Merge pull request #332 from dmuhs/refactor/bool-params

Refactor initiator -> is_initiator and other flags/functions
This commit is contained in:
Alex Stokes 2019-10-25 09:38:52 +08:00 committed by GitHub
commit eaa800c356
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 37 additions and 35 deletions

View File

@ -7,7 +7,7 @@ from .raw_connection_interface import IRawConnection
class RawConnection(IRawConnection):
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
initiator: bool
is_initiator: bool
_drain_lock: asyncio.Lock
@ -19,7 +19,7 @@ class RawConnection(IRawConnection):
) -> None:
self.reader = reader
self.writer = writer
self.initiator = initiator
self.is_initiator = initiator
self._drain_lock = asyncio.Lock()

View File

@ -6,4 +6,4 @@ class IRawConnection(ReadWriteCloser):
A Raw Connection provides a Reader and a Writer
"""
initiator: bool
is_initiator: bool

View File

@ -84,14 +84,14 @@ class Multiselect(IMultiselectMuxer):
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
if not validate_handshake(handshake_contents):
if not is_valid_handshake(handshake_contents):
raise MultiselectError(
"multiselect protocol ID mismatch: "
f"received handshake_contents={handshake_contents}"
)
def validate_handshake(handshake_contents: str) -> bool:
def is_valid_handshake(handshake_contents: str) -> bool:
"""
Determine if handshake is valid and should be confirmed
:param handshake_contents: contents of handshake message

View File

@ -33,7 +33,7 @@ class MultiselectClient(IMultiselectClient):
except MultiselectCommunicatorError as error:
raise MultiselectClientError(str(error))
if not validate_handshake(handshake_contents):
if not is_valid_handshake(handshake_contents):
raise MultiselectClientError("multiselect protocol ID mismatch")
async def select_one_of(
@ -86,7 +86,7 @@ class MultiselectClient(IMultiselectClient):
raise MultiselectClientError("unrecognized response: " + response)
def validate_handshake(handshake_contents: str) -> bool:
def is_valid_handshake(handshake_contents: str) -> bool:
"""
Determine if handshake is valid and should be confirmed
:param handshake_contents: contents of handshake message

View File

@ -9,7 +9,7 @@ message RPC {
repeated Message publish = 2;
message SubOpts {
optional bool subscribe = 1; // subscribe or unsubcribe
optional bool subscribe = 1; // subscribe or unsubscribe
optional string topicid = 2;
}
@ -75,4 +75,4 @@ message TopicDescriptor {
WOT = 2; // web of trust, certificates can allow publisher set to grow
}
}
}
}

View File

@ -20,14 +20,14 @@ class BaseSession(ISecureConn):
self,
local_peer: ID,
local_private_key: PrivateKey,
initiator: bool,
is_initiator: bool,
peer_id: Optional[ID] = None,
) -> None:
self.local_peer = local_peer
self.local_private_key = local_private_key
self.remote_peer_id = peer_id
self.remote_permanent_pubkey = None
self.initiator = initiator
self.is_initiator = is_initiator
def get_local_peer(self) -> ID:
return self.local_peer

View File

@ -29,10 +29,10 @@ class InsecureSession(BaseSession):
local_peer: ID,
local_private_key: PrivateKey,
conn: ReadWriteCloser,
initiator: bool,
is_initiator: bool,
peer_id: Optional[ID] = None,
) -> None:
super().__init__(local_peer, local_private_key, initiator, peer_id)
super().__init__(local_peer, local_private_key, is_initiator, peer_id)
self.conn = conn
async def write(self, data: bytes) -> int:
@ -68,7 +68,7 @@ class InsecureSession(BaseSession):
# Verify if the receive `ID` matches the one we originally initialize the session.
# We only need to check it when we are the initiator, because only in that condition
# we possibly knows the `ID` of the remote.
if self.initiator and self.remote_peer_id != received_peer_id:
if self.is_initiator and self.remote_peer_id != received_peer_id:
raise HandshakeFailure(
"remote peer sent unexpected peer ID. "
f"expected={self.remote_peer_id} received={received_peer_id}"
@ -97,7 +97,7 @@ class InsecureSession(BaseSession):
self.remote_permanent_pubkey = received_pubkey
# Only need to set peer's id when we don't know it before,
# i.e. we are not the connection initiator.
if not self.initiator:
if not self.is_initiator:
self.remote_peer_id = received_peer_id
# TODO: Store `pubkey` and `peer_id` to `PeerStore`

View File

@ -61,9 +61,9 @@ class SecureSession(BaseSession):
remote_peer: PeerID,
remote_encryption_parameters: AuthenticatedEncryptionParameters,
conn: MsgIOReadWriter,
initiator: bool,
is_initiator: bool,
) -> None:
super().__init__(local_peer, local_private_key, initiator, remote_peer)
super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
self.conn = conn
self.local_encryption_parameters = local_encryption_parameters
@ -371,7 +371,7 @@ def _mk_session_from(
local_private_key: PrivateKey,
session_parameters: SessionParameters,
conn: MsgIOReadWriter,
initiator: bool,
is_initiator: bool,
) -> SecureSession:
key_set1, key_set2 = initialize_pair_for_encryption(
session_parameters.local_encryption_parameters.cipher_type,
@ -389,7 +389,7 @@ def _mk_session_from(
session_parameters.remote_peer,
key_set2,
conn,
initiator,
is_initiator,
)
return session
@ -425,8 +425,10 @@ async def create_secure_session(
except IOException:
raise SecioException("connection closed")
initiator = remote_peer is not None
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
is_initiator = remote_peer is not None
session = _mk_session_from(
local_private_key, session_parameters, msg_io, is_initiator
)
try:
received_nonce = await _finish_handshake(session, remote_nonce)

View File

@ -76,18 +76,18 @@ class SecurityMultistream(ABC):
return secure_conn
async def select_transport(
self, conn: IRawConnection, initiator: bool
self, conn: IRawConnection, is_initiator: bool
) -> ISecureTransport:
"""
Select a transport that both us and the node on the
other end of conn support and agree on
:param conn: conn to choose a transport over
:param initiator: true if we are the initiator, false otherwise
:param is_initiator: true if we are the initiator, false otherwise
:return: selected secure transport
"""
protocol: TProtocol
communicator = MultiselectCommunicator(conn)
if initiator:
if is_initiator:
# Select protocol if initiator
protocol = await self.multiselect_client.select_one_of(
list(self.transports.keys()), communicator

View File

@ -23,7 +23,7 @@ class IMuxedConn(ABC):
@property
@abstractmethod
def initiator(self) -> bool:
def is_initiator(self) -> bool:
pass
@abstractmethod

View File

@ -68,8 +68,8 @@ class Mplex(IMuxedConn):
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
@property
def initiator(self) -> bool:
return self.secured_conn.initiator
def is_initiator(self) -> bool:
return self.secured_conn.is_initiator
async def close(self) -> None:
"""

View File

@ -56,7 +56,7 @@ class MuxerMultistream:
"""
protocol: TProtocol
communicator = MultiselectCommunicator(conn)
if conn.initiator:
if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of(
tuple(self.transports.keys()), communicator
)

View File

@ -33,13 +33,13 @@ class TransportUpgrader:
pass
async def upgrade_security(
self, raw_conn: IRawConnection, peer_id: ID, initiator: bool
self, raw_conn: IRawConnection, peer_id: ID, is_initiator: bool
) -> ISecureConn:
"""
Upgrade conn to a secured connection
"""
try:
if initiator:
if is_initiator:
return await self.security_multistream.secure_outbound(
raw_conn, peer_id
)

View File

@ -9,11 +9,11 @@ from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session
class InMemoryConnection(IRawConnection):
def __init__(self, peer, initiator=False):
def __init__(self, peer, is_initiator=False):
self.peer = peer
self.recv_queue = asyncio.Queue()
self.send_queue = asyncio.Queue()
self.initiator = initiator
self.is_initiator = is_initiator
self.current_msg = None
self.current_position = 0
@ -73,7 +73,7 @@ async def test_create_secure_session():
remote_key_pair = create_new_key_pair(b"b")
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
local_conn = InMemoryConnection(local_peer, initiator=True)
local_conn = InMemoryConnection(local_peer, is_initiator=True)
remote_conn = InMemoryConnection(remote_peer)
local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn))

View File

@ -10,8 +10,8 @@ async def mplex_conn_pair(is_host_secure):
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_host_secure
)
assert mplex_conn_0.initiator
assert not mplex_conn_1.initiator
assert mplex_conn_0.is_initiator
assert not mplex_conn_1.is_initiator
try:
yield mplex_conn_0, mplex_conn_1
finally: