PR feedbacks

- Move exceptions to exceptions.py
- Raise `UpgradeFailure` in upgrader
- Refine the try/catch for upgraders in swarm
This commit is contained in:
mhchia 2019-08-21 23:04:59 +08:00
parent 3e04480d62
commit 16a4fd33c1
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
13 changed files with 97 additions and 66 deletions

View File

@ -0,0 +1,5 @@
from libp2p.exceptions import BaseLibp2pError
class SwarmException(BaseLibp2pError):
pass

View File

@ -33,7 +33,7 @@ class INetwork(ABC):
dial_peer try to create a connection to peer_id dial_peer try to create a connection to peer_id
:param peer_id: peer if we want to dial :param peer_id: peer if we want to dial
:raises SwarmException: raised when no address if found for peer_id :raises SwarmException: raised when an error occurs
:return: muxed connection :return: muxed connection
""" """

View File

@ -10,13 +10,14 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import StreamCommunicator from libp2p.protocol_muxer.multiselect_communicator import StreamCommunicator
from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.interfaces import IPeerRouting
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.transport.exceptions import UpgradeFailure from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport from libp2p.transport.transport_interface import ITransport
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
from .connection.raw_connection import RawConnection from .connection.raw_connection import RawConnection
from .exceptions import SwarmException
from .network_interface import INetwork from .network_interface import INetwork
from .notifee_interface import INotifee from .notifee_interface import INotifee
from .stream.net_stream import NetStream from .stream.net_stream import NetStream
@ -85,7 +86,7 @@ class Swarm(INetwork):
""" """
dial_peer try to create a connection to peer_id dial_peer try to create a connection to peer_id
:param peer_id: peer if we want to dial :param peer_id: peer if we want to dial
:raises SwarmException: raised when no address if found for peer_id :raises SwarmException: raised when an error occurs
:return: muxed connection :return: muxed connection
""" """
@ -111,10 +112,26 @@ class Swarm(INetwork):
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
# the conn and then mux the conn # the conn and then mux the conn
secured_conn = await self.upgrader.upgrade_security(raw_conn, peer_id, True) try:
secured_conn = await self.upgrader.upgrade_security(
raw_conn, peer_id, True
)
except SecurityUpgradeFailure as error:
# TODO: Add logging to indicate the failure
raw_conn.close()
raise SwarmException(
f"fail to upgrade the connection to a secured connection from {peer_id}"
) from error
try:
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, self.generic_protocol_handler, peer_id secured_conn, self.generic_protocol_handler, peer_id
) )
except MuxerUpgradeFailure as error:
# TODO: Add logging to indicate the failure
secured_conn.close()
raise SwarmException(
f"fail to upgrade the connection to a muxed connection from {peer_id}"
) from error
# Store muxed connection in connections # Store muxed connection in connections
self.connections[peer_id] = muxed_conn self.connections[peer_id] = muxed_conn
@ -197,19 +214,28 @@ class Swarm(INetwork):
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
# the conn and then mux the conn # the conn and then mux the conn
# FIXME: This dummy `ID(b"")` for the remote peer is useless.
try: try:
# FIXME: This dummy `ID(b"")` for the remote peer is useless.
secured_conn = await self.upgrader.upgrade_security( secured_conn = await self.upgrader.upgrade_security(
raw_conn, ID(b""), False raw_conn, ID(b""), False
) )
except SecurityUpgradeFailure as error:
# TODO: Add logging to indicate the failure
raw_conn.close()
raise SwarmException(
"fail to upgrade the connection to a secured connection"
) from error
peer_id = secured_conn.get_remote_peer() peer_id = secured_conn.get_remote_peer()
try:
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, self.generic_protocol_handler, peer_id secured_conn, self.generic_protocol_handler, peer_id
) )
except UpgradeFailure: except MuxerUpgradeFailure as error:
# TODO: Add logging to indicate the failure # TODO: Add logging to indicate the failure
raw_conn.close() secured_conn.close()
return raise SwarmException(
f"fail to upgrade the connection to a muxed connection from {peer_id}"
) from error
# Store muxed_conn with peer id # Store muxed_conn with peer id
self.connections[peer_id] = muxed_conn self.connections[peer_id] = muxed_conn
@ -283,7 +309,3 @@ def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn:
asyncio.ensure_future(handler(net_stream)) asyncio.ensure_future(handler(net_stream))
return generic_protocol_handler return generic_protocol_handler
class SwarmException(Exception):
pass

View File

@ -0,0 +1,9 @@
from libp2p.exceptions import BaseLibp2pError
class MultiselectError(BaseLibp2pError):
"""Raised when an error occurs in multiselect process"""
class MultiselectClientError(BaseLibp2pError):
"""Raised when an error occurs in protocol selection process"""

View File

@ -2,6 +2,7 @@ from typing import Dict, Tuple
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
from .exceptions import MultiselectError
from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_communicator_interface import IMultiselectCommunicator
from .multiselect_muxer_interface import IMultiselectMuxer from .multiselect_muxer_interface import IMultiselectMuxer
@ -97,7 +98,3 @@ def validate_handshake(handshake_contents: str) -> bool:
# TODO: Modify this when format used by go repo for messages # TODO: Modify this when format used by go repo for messages
# is added # is added
return handshake_contents == MULTISELECT_PROTOCOL_ID return handshake_contents == MULTISELECT_PROTOCOL_ID
class MultiselectError(ValueError):
"""Raised when an error occurs in multiselect process"""

View File

@ -2,6 +2,7 @@ from typing import Sequence
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .exceptions import MultiselectClientError
from .multiselect_client_interface import IMultiselectClient from .multiselect_client_interface import IMultiselectClient
from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_communicator_interface import IMultiselectCommunicator
@ -116,7 +117,3 @@ def validate_handshake(handshake_contents: str) -> bool:
# TODO: Modify this when format used by go repo for messages # TODO: Modify this when format used by go repo for messages
# is added # is added
return handshake_contents == MULTISELECT_PROTOCOL_ID return handshake_contents == MULTISELECT_PROTOCOL_ID
class MultiselectClientError(ValueError):
"""Raised when an error occurs in protocol selection process"""

View File

@ -19,7 +19,7 @@ PLAINTEXT_PROTOCOL_ID = TProtocol("/plaintext/2.0.0")
class InsecureSession(BaseSession): class InsecureSession(BaseSession):
async def run_handshake(self): async def run_handshake(self) -> None:
msg = make_exchange_message(self.local_private_key.get_public_key()) msg = make_exchange_message(self.local_private_key.get_public_key())
msg_bytes = msg.SerializeToString() msg_bytes = msg.SerializeToString()
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)

View File

@ -4,15 +4,11 @@ from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.multiselect import Multiselect, MultiselectError from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import ( from libp2p.protocol_muxer.multiselect_client import MultiselectClient
MultiselectClient,
MultiselectClientError,
)
from libp2p.protocol_muxer.multiselect_communicator import RawConnectionCommunicator from libp2p.protocol_muxer.multiselect_communicator import RawConnectionCommunicator
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.transport.exceptions import HandshakeFailure, SecurityUpgradeFailure
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -67,18 +63,8 @@ class SecurityMultistream(ABC):
for an inbound connection (i.e. we are not the initiator) for an inbound connection (i.e. we are not the initiator)
:return: secure connection object (that implements secure_conn_interface) :return: secure connection object (that implements secure_conn_interface)
""" """
try:
transport = await self.select_transport(conn, False) transport = await self.select_transport(conn, False)
except MultiselectError as error:
raise SecurityUpgradeFailure(
"failed to negotiate the secure protocol"
) from error
try:
secure_conn = await transport.secure_inbound(conn) secure_conn = await transport.secure_inbound(conn)
except HandshakeFailure as error:
raise SecurityUpgradeFailure(
"failed to secure the inbound transport"
) from error
return secure_conn return secure_conn
async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn: async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn:
@ -87,18 +73,8 @@ class SecurityMultistream(ABC):
for an inbound connection (i.e. we are the initiator) for an inbound connection (i.e. we are the initiator)
:return: secure connection object (that implements secure_conn_interface) :return: secure connection object (that implements secure_conn_interface)
""" """
try:
transport = await self.select_transport(conn, True) transport = await self.select_transport(conn, True)
except MultiselectClientError as error:
raise SecurityUpgradeFailure(
"failed to negotiate the secure protocol"
) from error
try:
secure_conn = await transport.secure_outbound(conn, peer_id) secure_conn = await transport.secure_outbound(conn, peer_id)
except HandshakeFailure as error:
raise SecurityUpgradeFailure(
"failed to secure the outbound transport"
) from error
return secure_conn return secure_conn
async def select_transport( async def select_transport(

View File

@ -10,5 +10,9 @@ class SecurityUpgradeFailure(UpgradeFailure):
pass pass
class MuxerUpgradeFailure(UpgradeFailure):
pass
class HandshakeFailure(BaseLibp2pError): class HandshakeFailure(BaseLibp2pError):
pass pass

View File

@ -3,11 +3,17 @@ from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.security.security_multistream import SecurityMultistream from libp2p.security.security_multistream import SecurityMultistream
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream
from libp2p.transport.exceptions import (
HandshakeFailure,
MuxerUpgradeFailure,
SecurityUpgradeFailure,
)
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .listener_interface import IListener from .listener_interface import IListener
@ -39,10 +45,20 @@ class TransportUpgrader:
""" """
Upgrade conn to a secured connection Upgrade conn to a secured connection
""" """
try:
if initiator: if initiator:
return await self.security_multistream.secure_outbound(raw_conn, peer_id) return await self.security_multistream.secure_outbound(
raw_conn, peer_id
)
return await self.security_multistream.secure_inbound(raw_conn) return await self.security_multistream.secure_inbound(raw_conn)
except (MultiselectError, MultiselectClientError) as error:
raise SecurityUpgradeFailure(
"failed to negotiate the secure protocol"
) from error
except HandshakeFailure as error:
raise SecurityUpgradeFailure(
"handshake failed when upgrading to secure connection"
) from error
async def upgrade_connection( async def upgrade_connection(
self, self,
@ -53,6 +69,11 @@ class TransportUpgrader:
""" """
Upgrade secured connection to a muxed connection Upgrade secured connection to a muxed connection
""" """
try:
return await self.muxer_multistream.new_conn( return await self.muxer_multistream.new_conn(
conn, generic_protocol_handler, peer_id conn, generic_protocol_handler, peer_id
) )
except (MultiselectError, MultiselectClientError) as error:
raise MuxerUpgradeFailure(
"failed to negotiate the multiplexer protocol"
) from error

View File

@ -3,7 +3,7 @@ import asyncio
import pytest import pytest
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.protocol_muxer.multiselect_client import MultiselectClientError from libp2p.protocol_muxer.exceptions import MultiselectClientError
from tests.utils import cleanup, set_up_nodes_by_transport_opt from tests.utils import cleanup, set_up_nodes_by_transport_opt
PROTOCOL_ID = "/chat/1.0.0" PROTOCOL_ID = "/chat/1.0.0"

View File

@ -1,6 +1,6 @@
import pytest import pytest
from libp2p.protocol_muxer.multiselect_client import MultiselectClientError from libp2p.protocol_muxer.exceptions import MultiselectClientError
from tests.utils import cleanup, set_up_nodes_by_transport_opt from tests.utils import cleanup, set_up_nodes_by_transport_opt
# TODO: Add tests for multiple streams being opened on different # TODO: Add tests for multiple streams being opened on different

View File

@ -4,9 +4,9 @@ import pytest
from libp2p import new_node from libp2p import new_node
from libp2p.crypto.rsa import create_new_key_pair from libp2p.crypto.rsa import create_new_key_pair
from libp2p.network.exceptions import SwarmException
from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from libp2p.security.insecure.transport import InsecureSession, InsecureTransport
from libp2p.security.simple.transport import SimpleSecurityTransport from libp2p.security.simple.transport import SimpleSecurityTransport
from libp2p.transport.exceptions import SecurityUpgradeFailure
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.utils import cleanup, connect from tests.utils import cleanup, connect
@ -161,7 +161,7 @@ async def test_multiple_security_none_the_same_fails():
def assertion_func(_): def assertion_func(_):
assert False assert False
with pytest.raises(SecurityUpgradeFailure): with pytest.raises(SwarmException):
await perform_simple_test( await perform_simple_test(
assertion_func, transports_for_initiator, transports_for_noninitiator assertion_func, transports_for_initiator, transports_for_noninitiator
) )