Merge pull request #332 from dmuhs/refactor/bool-params

Refactor initiator -> is_initiator and other flags/functions
This commit is contained in:
Alex Stokes 2019-10-25 09:38:52 +08:00 committed by GitHub
commit eaa800c356
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 37 additions and 35 deletions

View File

@ -7,7 +7,7 @@ from .raw_connection_interface import IRawConnection
class RawConnection(IRawConnection): class RawConnection(IRawConnection):
reader: asyncio.StreamReader reader: asyncio.StreamReader
writer: asyncio.StreamWriter writer: asyncio.StreamWriter
initiator: bool is_initiator: bool
_drain_lock: asyncio.Lock _drain_lock: asyncio.Lock
@ -19,7 +19,7 @@ class RawConnection(IRawConnection):
) -> None: ) -> None:
self.reader = reader self.reader = reader
self.writer = writer self.writer = writer
self.initiator = initiator self.is_initiator = initiator
self._drain_lock = asyncio.Lock() self._drain_lock = asyncio.Lock()

View File

@ -6,4 +6,4 @@ class IRawConnection(ReadWriteCloser):
A Raw Connection provides a Reader and a Writer A Raw Connection provides a Reader and a Writer
""" """
initiator: bool is_initiator: bool

View File

@ -84,14 +84,14 @@ class Multiselect(IMultiselectMuxer):
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectError(error) raise MultiselectError(error)
if not validate_handshake(handshake_contents): if not is_valid_handshake(handshake_contents):
raise MultiselectError( raise MultiselectError(
"multiselect protocol ID mismatch: " "multiselect protocol ID mismatch: "
f"received handshake_contents={handshake_contents}" f"received handshake_contents={handshake_contents}"
) )
def validate_handshake(handshake_contents: str) -> bool: def is_valid_handshake(handshake_contents: str) -> bool:
""" """
Determine if handshake is valid and should be confirmed Determine if handshake is valid and should be confirmed
:param handshake_contents: contents of handshake message :param handshake_contents: contents of handshake message

View File

@ -33,7 +33,7 @@ class MultiselectClient(IMultiselectClient):
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectClientError(str(error)) raise MultiselectClientError(str(error))
if not validate_handshake(handshake_contents): if not is_valid_handshake(handshake_contents):
raise MultiselectClientError("multiselect protocol ID mismatch") raise MultiselectClientError("multiselect protocol ID mismatch")
async def select_one_of( async def select_one_of(
@ -86,7 +86,7 @@ class MultiselectClient(IMultiselectClient):
raise MultiselectClientError("unrecognized response: " + response) raise MultiselectClientError("unrecognized response: " + response)
def validate_handshake(handshake_contents: str) -> bool: def is_valid_handshake(handshake_contents: str) -> bool:
""" """
Determine if handshake is valid and should be confirmed Determine if handshake is valid and should be confirmed
:param handshake_contents: contents of handshake message :param handshake_contents: contents of handshake message

View File

@ -9,7 +9,7 @@ message RPC {
repeated Message publish = 2; repeated Message publish = 2;
message SubOpts { message SubOpts {
optional bool subscribe = 1; // subscribe or unsubcribe optional bool subscribe = 1; // subscribe or unsubscribe
optional string topicid = 2; optional string topicid = 2;
} }
@ -75,4 +75,4 @@ message TopicDescriptor {
WOT = 2; // web of trust, certificates can allow publisher set to grow WOT = 2; // web of trust, certificates can allow publisher set to grow
} }
} }
} }

View File

@ -20,14 +20,14 @@ class BaseSession(ISecureConn):
self, self,
local_peer: ID, local_peer: ID,
local_private_key: PrivateKey, local_private_key: PrivateKey,
initiator: bool, is_initiator: bool,
peer_id: Optional[ID] = None, peer_id: Optional[ID] = None,
) -> None: ) -> None:
self.local_peer = local_peer self.local_peer = local_peer
self.local_private_key = local_private_key self.local_private_key = local_private_key
self.remote_peer_id = peer_id self.remote_peer_id = peer_id
self.remote_permanent_pubkey = None self.remote_permanent_pubkey = None
self.initiator = initiator self.is_initiator = is_initiator
def get_local_peer(self) -> ID: def get_local_peer(self) -> ID:
return self.local_peer return self.local_peer

View File

@ -29,10 +29,10 @@ class InsecureSession(BaseSession):
local_peer: ID, local_peer: ID,
local_private_key: PrivateKey, local_private_key: PrivateKey,
conn: ReadWriteCloser, conn: ReadWriteCloser,
initiator: bool, is_initiator: bool,
peer_id: Optional[ID] = None, peer_id: Optional[ID] = None,
) -> None: ) -> None:
super().__init__(local_peer, local_private_key, initiator, peer_id) super().__init__(local_peer, local_private_key, is_initiator, peer_id)
self.conn = conn self.conn = conn
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
@ -68,7 +68,7 @@ class InsecureSession(BaseSession):
# Verify if the receive `ID` matches the one we originally initialize the session. # Verify if the receive `ID` matches the one we originally initialize the session.
# We only need to check it when we are the initiator, because only in that condition # We only need to check it when we are the initiator, because only in that condition
# we possibly knows the `ID` of the remote. # we possibly knows the `ID` of the remote.
if self.initiator and self.remote_peer_id != received_peer_id: if self.is_initiator and self.remote_peer_id != received_peer_id:
raise HandshakeFailure( raise HandshakeFailure(
"remote peer sent unexpected peer ID. " "remote peer sent unexpected peer ID. "
f"expected={self.remote_peer_id} received={received_peer_id}" f"expected={self.remote_peer_id} received={received_peer_id}"
@ -97,7 +97,7 @@ class InsecureSession(BaseSession):
self.remote_permanent_pubkey = received_pubkey self.remote_permanent_pubkey = received_pubkey
# Only need to set peer's id when we don't know it before, # Only need to set peer's id when we don't know it before,
# i.e. we are not the connection initiator. # i.e. we are not the connection initiator.
if not self.initiator: if not self.is_initiator:
self.remote_peer_id = received_peer_id self.remote_peer_id = received_peer_id
# TODO: Store `pubkey` and `peer_id` to `PeerStore` # TODO: Store `pubkey` and `peer_id` to `PeerStore`

View File

@ -61,9 +61,9 @@ class SecureSession(BaseSession):
remote_peer: PeerID, remote_peer: PeerID,
remote_encryption_parameters: AuthenticatedEncryptionParameters, remote_encryption_parameters: AuthenticatedEncryptionParameters,
conn: MsgIOReadWriter, conn: MsgIOReadWriter,
initiator: bool, is_initiator: bool,
) -> None: ) -> None:
super().__init__(local_peer, local_private_key, initiator, remote_peer) super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
self.conn = conn self.conn = conn
self.local_encryption_parameters = local_encryption_parameters self.local_encryption_parameters = local_encryption_parameters
@ -371,7 +371,7 @@ def _mk_session_from(
local_private_key: PrivateKey, local_private_key: PrivateKey,
session_parameters: SessionParameters, session_parameters: SessionParameters,
conn: MsgIOReadWriter, conn: MsgIOReadWriter,
initiator: bool, is_initiator: bool,
) -> SecureSession: ) -> SecureSession:
key_set1, key_set2 = initialize_pair_for_encryption( key_set1, key_set2 = initialize_pair_for_encryption(
session_parameters.local_encryption_parameters.cipher_type, session_parameters.local_encryption_parameters.cipher_type,
@ -389,7 +389,7 @@ def _mk_session_from(
session_parameters.remote_peer, session_parameters.remote_peer,
key_set2, key_set2,
conn, conn,
initiator, is_initiator,
) )
return session return session
@ -425,8 +425,10 @@ async def create_secure_session(
except IOException: except IOException:
raise SecioException("connection closed") raise SecioException("connection closed")
initiator = remote_peer is not None is_initiator = remote_peer is not None
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator) session = _mk_session_from(
local_private_key, session_parameters, msg_io, is_initiator
)
try: try:
received_nonce = await _finish_handshake(session, remote_nonce) received_nonce = await _finish_handshake(session, remote_nonce)

View File

@ -76,18 +76,18 @@ class SecurityMultistream(ABC):
return secure_conn return secure_conn
async def select_transport( async def select_transport(
self, conn: IRawConnection, initiator: bool self, conn: IRawConnection, is_initiator: bool
) -> ISecureTransport: ) -> ISecureTransport:
""" """
Select a transport that both us and the node on the Select a transport that both us and the node on the
other end of conn support and agree on other end of conn support and agree on
:param conn: conn to choose a transport over :param conn: conn to choose a transport over
:param initiator: true if we are the initiator, false otherwise :param is_initiator: true if we are the initiator, false otherwise
:return: selected secure transport :return: selected secure transport
""" """
protocol: TProtocol protocol: TProtocol
communicator = MultiselectCommunicator(conn) communicator = MultiselectCommunicator(conn)
if initiator: if is_initiator:
# Select protocol if initiator # Select protocol if initiator
protocol = await self.multiselect_client.select_one_of( protocol = await self.multiselect_client.select_one_of(
list(self.transports.keys()), communicator list(self.transports.keys()), communicator

View File

@ -23,7 +23,7 @@ class IMuxedConn(ABC):
@property @property
@abstractmethod @abstractmethod
def initiator(self) -> bool: def is_initiator(self) -> bool:
pass pass
@abstractmethod @abstractmethod

View File

@ -68,8 +68,8 @@ class Mplex(IMuxedConn):
self._tasks.append(asyncio.ensure_future(self.handle_incoming())) self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
@property @property
def initiator(self) -> bool: def is_initiator(self) -> bool:
return self.secured_conn.initiator return self.secured_conn.is_initiator
async def close(self) -> None: async def close(self) -> None:
""" """

View File

@ -56,7 +56,7 @@ class MuxerMultistream:
""" """
protocol: TProtocol protocol: TProtocol
communicator = MultiselectCommunicator(conn) communicator = MultiselectCommunicator(conn)
if conn.initiator: if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of( protocol = await self.multiselect_client.select_one_of(
tuple(self.transports.keys()), communicator tuple(self.transports.keys()), communicator
) )

View File

@ -33,13 +33,13 @@ class TransportUpgrader:
pass pass
async def upgrade_security( async def upgrade_security(
self, raw_conn: IRawConnection, peer_id: ID, initiator: bool self, raw_conn: IRawConnection, peer_id: ID, is_initiator: bool
) -> ISecureConn: ) -> ISecureConn:
""" """
Upgrade conn to a secured connection Upgrade conn to a secured connection
""" """
try: try:
if initiator: if is_initiator:
return await self.security_multistream.secure_outbound( return await self.security_multistream.secure_outbound(
raw_conn, peer_id raw_conn, peer_id
) )

View File

@ -9,11 +9,11 @@ from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session
class InMemoryConnection(IRawConnection): class InMemoryConnection(IRawConnection):
def __init__(self, peer, initiator=False): def __init__(self, peer, is_initiator=False):
self.peer = peer self.peer = peer
self.recv_queue = asyncio.Queue() self.recv_queue = asyncio.Queue()
self.send_queue = asyncio.Queue() self.send_queue = asyncio.Queue()
self.initiator = initiator self.is_initiator = is_initiator
self.current_msg = None self.current_msg = None
self.current_position = 0 self.current_position = 0
@ -73,7 +73,7 @@ async def test_create_secure_session():
remote_key_pair = create_new_key_pair(b"b") remote_key_pair = create_new_key_pair(b"b")
remote_peer = ID.from_pubkey(remote_key_pair.public_key) remote_peer = ID.from_pubkey(remote_key_pair.public_key)
local_conn = InMemoryConnection(local_peer, initiator=True) local_conn = InMemoryConnection(local_peer, is_initiator=True)
remote_conn = InMemoryConnection(remote_peer) remote_conn = InMemoryConnection(remote_peer)
local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn)) local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn))

View File

@ -10,8 +10,8 @@ async def mplex_conn_pair(is_host_secure):
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_host_secure is_host_secure
) )
assert mplex_conn_0.initiator assert mplex_conn_0.is_initiator
assert not mplex_conn_1.initiator assert not mplex_conn_1.is_initiator
try: try:
yield mplex_conn_0, mplex_conn_1 yield mplex_conn_0, mplex_conn_1
finally: finally: