Merge pull request #332 from dmuhs/refactor/bool-params
Refactor initiator -> is_initiator and other flags/functions
This commit is contained in:
commit
eaa800c356
|
@ -7,7 +7,7 @@ from .raw_connection_interface import IRawConnection
|
||||||
class RawConnection(IRawConnection):
|
class RawConnection(IRawConnection):
|
||||||
reader: asyncio.StreamReader
|
reader: asyncio.StreamReader
|
||||||
writer: asyncio.StreamWriter
|
writer: asyncio.StreamWriter
|
||||||
initiator: bool
|
is_initiator: bool
|
||||||
|
|
||||||
_drain_lock: asyncio.Lock
|
_drain_lock: asyncio.Lock
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ class RawConnection(IRawConnection):
|
||||||
) -> None:
|
) -> None:
|
||||||
self.reader = reader
|
self.reader = reader
|
||||||
self.writer = writer
|
self.writer = writer
|
||||||
self.initiator = initiator
|
self.is_initiator = initiator
|
||||||
|
|
||||||
self._drain_lock = asyncio.Lock()
|
self._drain_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
|
@ -6,4 +6,4 @@ class IRawConnection(ReadWriteCloser):
|
||||||
A Raw Connection provides a Reader and a Writer
|
A Raw Connection provides a Reader and a Writer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
initiator: bool
|
is_initiator: bool
|
||||||
|
|
|
@ -84,14 +84,14 @@ class Multiselect(IMultiselectMuxer):
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectError(error)
|
raise MultiselectError(error)
|
||||||
|
|
||||||
if not validate_handshake(handshake_contents):
|
if not is_valid_handshake(handshake_contents):
|
||||||
raise MultiselectError(
|
raise MultiselectError(
|
||||||
"multiselect protocol ID mismatch: "
|
"multiselect protocol ID mismatch: "
|
||||||
f"received handshake_contents={handshake_contents}"
|
f"received handshake_contents={handshake_contents}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_handshake(handshake_contents: str) -> bool:
|
def is_valid_handshake(handshake_contents: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Determine if handshake is valid and should be confirmed
|
Determine if handshake is valid and should be confirmed
|
||||||
:param handshake_contents: contents of handshake message
|
:param handshake_contents: contents of handshake message
|
||||||
|
|
|
@ -33,7 +33,7 @@ class MultiselectClient(IMultiselectClient):
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectClientError(str(error))
|
raise MultiselectClientError(str(error))
|
||||||
|
|
||||||
if not validate_handshake(handshake_contents):
|
if not is_valid_handshake(handshake_contents):
|
||||||
raise MultiselectClientError("multiselect protocol ID mismatch")
|
raise MultiselectClientError("multiselect protocol ID mismatch")
|
||||||
|
|
||||||
async def select_one_of(
|
async def select_one_of(
|
||||||
|
@ -86,7 +86,7 @@ class MultiselectClient(IMultiselectClient):
|
||||||
raise MultiselectClientError("unrecognized response: " + response)
|
raise MultiselectClientError("unrecognized response: " + response)
|
||||||
|
|
||||||
|
|
||||||
def validate_handshake(handshake_contents: str) -> bool:
|
def is_valid_handshake(handshake_contents: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Determine if handshake is valid and should be confirmed
|
Determine if handshake is valid and should be confirmed
|
||||||
:param handshake_contents: contents of handshake message
|
:param handshake_contents: contents of handshake message
|
||||||
|
|
|
@ -9,7 +9,7 @@ message RPC {
|
||||||
repeated Message publish = 2;
|
repeated Message publish = 2;
|
||||||
|
|
||||||
message SubOpts {
|
message SubOpts {
|
||||||
optional bool subscribe = 1; // subscribe or unsubcribe
|
optional bool subscribe = 1; // subscribe or unsubscribe
|
||||||
optional string topicid = 2;
|
optional string topicid = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,14 +20,14 @@ class BaseSession(ISecureConn):
|
||||||
self,
|
self,
|
||||||
local_peer: ID,
|
local_peer: ID,
|
||||||
local_private_key: PrivateKey,
|
local_private_key: PrivateKey,
|
||||||
initiator: bool,
|
is_initiator: bool,
|
||||||
peer_id: Optional[ID] = None,
|
peer_id: Optional[ID] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.local_peer = local_peer
|
self.local_peer = local_peer
|
||||||
self.local_private_key = local_private_key
|
self.local_private_key = local_private_key
|
||||||
self.remote_peer_id = peer_id
|
self.remote_peer_id = peer_id
|
||||||
self.remote_permanent_pubkey = None
|
self.remote_permanent_pubkey = None
|
||||||
self.initiator = initiator
|
self.is_initiator = is_initiator
|
||||||
|
|
||||||
def get_local_peer(self) -> ID:
|
def get_local_peer(self) -> ID:
|
||||||
return self.local_peer
|
return self.local_peer
|
||||||
|
|
|
@ -29,10 +29,10 @@ class InsecureSession(BaseSession):
|
||||||
local_peer: ID,
|
local_peer: ID,
|
||||||
local_private_key: PrivateKey,
|
local_private_key: PrivateKey,
|
||||||
conn: ReadWriteCloser,
|
conn: ReadWriteCloser,
|
||||||
initiator: bool,
|
is_initiator: bool,
|
||||||
peer_id: Optional[ID] = None,
|
peer_id: Optional[ID] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(local_peer, local_private_key, initiator, peer_id)
|
super().__init__(local_peer, local_private_key, is_initiator, peer_id)
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
|
|
||||||
async def write(self, data: bytes) -> int:
|
async def write(self, data: bytes) -> int:
|
||||||
|
@ -68,7 +68,7 @@ class InsecureSession(BaseSession):
|
||||||
# Verify if the receive `ID` matches the one we originally initialize the session.
|
# Verify if the receive `ID` matches the one we originally initialize the session.
|
||||||
# We only need to check it when we are the initiator, because only in that condition
|
# We only need to check it when we are the initiator, because only in that condition
|
||||||
# we possibly knows the `ID` of the remote.
|
# we possibly knows the `ID` of the remote.
|
||||||
if self.initiator and self.remote_peer_id != received_peer_id:
|
if self.is_initiator and self.remote_peer_id != received_peer_id:
|
||||||
raise HandshakeFailure(
|
raise HandshakeFailure(
|
||||||
"remote peer sent unexpected peer ID. "
|
"remote peer sent unexpected peer ID. "
|
||||||
f"expected={self.remote_peer_id} received={received_peer_id}"
|
f"expected={self.remote_peer_id} received={received_peer_id}"
|
||||||
|
@ -97,7 +97,7 @@ class InsecureSession(BaseSession):
|
||||||
self.remote_permanent_pubkey = received_pubkey
|
self.remote_permanent_pubkey = received_pubkey
|
||||||
# Only need to set peer's id when we don't know it before,
|
# Only need to set peer's id when we don't know it before,
|
||||||
# i.e. we are not the connection initiator.
|
# i.e. we are not the connection initiator.
|
||||||
if not self.initiator:
|
if not self.is_initiator:
|
||||||
self.remote_peer_id = received_peer_id
|
self.remote_peer_id = received_peer_id
|
||||||
|
|
||||||
# TODO: Store `pubkey` and `peer_id` to `PeerStore`
|
# TODO: Store `pubkey` and `peer_id` to `PeerStore`
|
||||||
|
|
|
@ -61,9 +61,9 @@ class SecureSession(BaseSession):
|
||||||
remote_peer: PeerID,
|
remote_peer: PeerID,
|
||||||
remote_encryption_parameters: AuthenticatedEncryptionParameters,
|
remote_encryption_parameters: AuthenticatedEncryptionParameters,
|
||||||
conn: MsgIOReadWriter,
|
conn: MsgIOReadWriter,
|
||||||
initiator: bool,
|
is_initiator: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(local_peer, local_private_key, initiator, remote_peer)
|
super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
|
|
||||||
self.local_encryption_parameters = local_encryption_parameters
|
self.local_encryption_parameters = local_encryption_parameters
|
||||||
|
@ -371,7 +371,7 @@ def _mk_session_from(
|
||||||
local_private_key: PrivateKey,
|
local_private_key: PrivateKey,
|
||||||
session_parameters: SessionParameters,
|
session_parameters: SessionParameters,
|
||||||
conn: MsgIOReadWriter,
|
conn: MsgIOReadWriter,
|
||||||
initiator: bool,
|
is_initiator: bool,
|
||||||
) -> SecureSession:
|
) -> SecureSession:
|
||||||
key_set1, key_set2 = initialize_pair_for_encryption(
|
key_set1, key_set2 = initialize_pair_for_encryption(
|
||||||
session_parameters.local_encryption_parameters.cipher_type,
|
session_parameters.local_encryption_parameters.cipher_type,
|
||||||
|
@ -389,7 +389,7 @@ def _mk_session_from(
|
||||||
session_parameters.remote_peer,
|
session_parameters.remote_peer,
|
||||||
key_set2,
|
key_set2,
|
||||||
conn,
|
conn,
|
||||||
initiator,
|
is_initiator,
|
||||||
)
|
)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
@ -425,8 +425,10 @@ async def create_secure_session(
|
||||||
except IOException:
|
except IOException:
|
||||||
raise SecioException("connection closed")
|
raise SecioException("connection closed")
|
||||||
|
|
||||||
initiator = remote_peer is not None
|
is_initiator = remote_peer is not None
|
||||||
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
|
session = _mk_session_from(
|
||||||
|
local_private_key, session_parameters, msg_io, is_initiator
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
received_nonce = await _finish_handshake(session, remote_nonce)
|
received_nonce = await _finish_handshake(session, remote_nonce)
|
||||||
|
|
|
@ -76,18 +76,18 @@ class SecurityMultistream(ABC):
|
||||||
return secure_conn
|
return secure_conn
|
||||||
|
|
||||||
async def select_transport(
|
async def select_transport(
|
||||||
self, conn: IRawConnection, initiator: bool
|
self, conn: IRawConnection, is_initiator: bool
|
||||||
) -> ISecureTransport:
|
) -> ISecureTransport:
|
||||||
"""
|
"""
|
||||||
Select a transport that both us and the node on the
|
Select a transport that both us and the node on the
|
||||||
other end of conn support and agree on
|
other end of conn support and agree on
|
||||||
:param conn: conn to choose a transport over
|
:param conn: conn to choose a transport over
|
||||||
:param initiator: true if we are the initiator, false otherwise
|
:param is_initiator: true if we are the initiator, false otherwise
|
||||||
:return: selected secure transport
|
:return: selected secure transport
|
||||||
"""
|
"""
|
||||||
protocol: TProtocol
|
protocol: TProtocol
|
||||||
communicator = MultiselectCommunicator(conn)
|
communicator = MultiselectCommunicator(conn)
|
||||||
if initiator:
|
if is_initiator:
|
||||||
# Select protocol if initiator
|
# Select protocol if initiator
|
||||||
protocol = await self.multiselect_client.select_one_of(
|
protocol = await self.multiselect_client.select_one_of(
|
||||||
list(self.transports.keys()), communicator
|
list(self.transports.keys()), communicator
|
||||||
|
|
|
@ -23,7 +23,7 @@ class IMuxedConn(ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def initiator(self) -> bool:
|
def is_initiator(self) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
@ -68,8 +68,8 @@ class Mplex(IMuxedConn):
|
||||||
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
|
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def initiator(self) -> bool:
|
def is_initiator(self) -> bool:
|
||||||
return self.secured_conn.initiator
|
return self.secured_conn.is_initiator
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -56,7 +56,7 @@ class MuxerMultistream:
|
||||||
"""
|
"""
|
||||||
protocol: TProtocol
|
protocol: TProtocol
|
||||||
communicator = MultiselectCommunicator(conn)
|
communicator = MultiselectCommunicator(conn)
|
||||||
if conn.initiator:
|
if conn.is_initiator:
|
||||||
protocol = await self.multiselect_client.select_one_of(
|
protocol = await self.multiselect_client.select_one_of(
|
||||||
tuple(self.transports.keys()), communicator
|
tuple(self.transports.keys()), communicator
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,13 +33,13 @@ class TransportUpgrader:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def upgrade_security(
|
async def upgrade_security(
|
||||||
self, raw_conn: IRawConnection, peer_id: ID, initiator: bool
|
self, raw_conn: IRawConnection, peer_id: ID, is_initiator: bool
|
||||||
) -> ISecureConn:
|
) -> ISecureConn:
|
||||||
"""
|
"""
|
||||||
Upgrade conn to a secured connection
|
Upgrade conn to a secured connection
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if initiator:
|
if is_initiator:
|
||||||
return await self.security_multistream.secure_outbound(
|
return await self.security_multistream.secure_outbound(
|
||||||
raw_conn, peer_id
|
raw_conn, peer_id
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,11 +9,11 @@ from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session
|
||||||
|
|
||||||
|
|
||||||
class InMemoryConnection(IRawConnection):
|
class InMemoryConnection(IRawConnection):
|
||||||
def __init__(self, peer, initiator=False):
|
def __init__(self, peer, is_initiator=False):
|
||||||
self.peer = peer
|
self.peer = peer
|
||||||
self.recv_queue = asyncio.Queue()
|
self.recv_queue = asyncio.Queue()
|
||||||
self.send_queue = asyncio.Queue()
|
self.send_queue = asyncio.Queue()
|
||||||
self.initiator = initiator
|
self.is_initiator = is_initiator
|
||||||
|
|
||||||
self.current_msg = None
|
self.current_msg = None
|
||||||
self.current_position = 0
|
self.current_position = 0
|
||||||
|
@ -73,7 +73,7 @@ async def test_create_secure_session():
|
||||||
remote_key_pair = create_new_key_pair(b"b")
|
remote_key_pair = create_new_key_pair(b"b")
|
||||||
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
|
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
|
||||||
|
|
||||||
local_conn = InMemoryConnection(local_peer, initiator=True)
|
local_conn = InMemoryConnection(local_peer, is_initiator=True)
|
||||||
remote_conn = InMemoryConnection(remote_peer)
|
remote_conn = InMemoryConnection(remote_peer)
|
||||||
|
|
||||||
local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn))
|
local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn))
|
||||||
|
|
|
@ -10,8 +10,8 @@ async def mplex_conn_pair(is_host_secure):
|
||||||
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
|
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
|
||||||
is_host_secure
|
is_host_secure
|
||||||
)
|
)
|
||||||
assert mplex_conn_0.initiator
|
assert mplex_conn_0.is_initiator
|
||||||
assert not mplex_conn_1.initiator
|
assert not mplex_conn_1.is_initiator
|
||||||
try:
|
try:
|
||||||
yield mplex_conn_0, mplex_conn_1
|
yield mplex_conn_0, mplex_conn_1
|
||||||
finally:
|
finally:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user