diff --git a/libp2p/network/exceptions.py b/libp2p/network/exceptions.py new file mode 100644 index 0000000..92be9b8 --- /dev/null +++ b/libp2p/network/exceptions.py @@ -0,0 +1,5 @@ +from libp2p.exceptions import BaseLibp2pError + + +class SwarmException(BaseLibp2pError): + pass diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index 83c3b20..d9cdf48 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -33,7 +33,7 @@ class INetwork(ABC): dial_peer try to create a connection to peer_id :param peer_id: peer if we want to dial - :raises SwarmException: raised when no address if found for peer_id + :raises SwarmException: raised when an error occurs :return: muxed connection """ diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 19d1b76..3ddc4ab 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -10,13 +10,14 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import StreamCommunicator from libp2p.routing.interfaces import IPeerRouting from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream -from libp2p.transport.exceptions import UpgradeFailure +from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import StreamHandlerFn, TProtocol from .connection.raw_connection import RawConnection +from .exceptions import SwarmException from .network_interface import INetwork from .notifee_interface import INotifee from .stream.net_stream import NetStream @@ -85,7 +86,7 @@ class Swarm(INetwork): """ dial_peer try to create a connection to peer_id :param peer_id: peer if we want to dial - :raises SwarmException: raised when no address if found for peer_id + :raises SwarmException: raised when an error occurs :return: muxed connection """ @@ -111,10 +112,26 @@ class Swarm(INetwork): # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # the conn and then mux the conn - secured_conn = await self.upgrader.upgrade_security(raw_conn, peer_id, True) - muxed_conn = await self.upgrader.upgrade_connection( - secured_conn, self.generic_protocol_handler, peer_id - ) + try: + secured_conn = await self.upgrader.upgrade_security( + raw_conn, peer_id, True + ) + except SecurityUpgradeFailure as error: + # TODO: Add logging to indicate the failure + raw_conn.close() + raise SwarmException( + f"fail to upgrade the connection to a secured connection from {peer_id}" + ) from error + try: + muxed_conn = await self.upgrader.upgrade_connection( + secured_conn, self.generic_protocol_handler, peer_id + ) + except MuxerUpgradeFailure as error: + # TODO: Add logging to indicate the failure + secured_conn.close() + raise SwarmException( + f"fail to upgrade the connection to a muxed connection from {peer_id}" + ) from error # Store muxed connection in connections self.connections[peer_id] = muxed_conn @@ -197,19 +214,28 @@ class Swarm(INetwork): # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure # the conn and then mux the conn - # FIXME: This dummy `ID(b"")` for the remote peer is useless. try: + # FIXME: This dummy `ID(b"")` for the remote peer is useless. secured_conn = await self.upgrader.upgrade_security( raw_conn, ID(b""), False ) - peer_id = secured_conn.get_remote_peer() + except SecurityUpgradeFailure as error: + # TODO: Add logging to indicate the failure + raw_conn.close() + raise SwarmException( + "fail to upgrade the connection to a secured connection" + ) from error + peer_id = secured_conn.get_remote_peer() + try: muxed_conn = await self.upgrader.upgrade_connection( secured_conn, self.generic_protocol_handler, peer_id ) - except UpgradeFailure: + except MuxerUpgradeFailure as error: # TODO: Add logging to indicate the failure - raw_conn.close() - return + secured_conn.close() + raise SwarmException( + f"fail to upgrade the connection to a muxed connection from {peer_id}" + ) from error # Store muxed_conn with peer id self.connections[peer_id] = muxed_conn @@ -283,7 +309,3 @@ def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn: asyncio.ensure_future(handler(net_stream)) return generic_protocol_handler - - -class SwarmException(Exception): - pass diff --git a/libp2p/protocol_muxer/exceptions.py b/libp2p/protocol_muxer/exceptions.py new file mode 100644 index 0000000..cf47aca --- /dev/null +++ b/libp2p/protocol_muxer/exceptions.py @@ -0,0 +1,9 @@ +from libp2p.exceptions import BaseLibp2pError + + +class MultiselectError(BaseLibp2pError): + """Raised when an error occurs in multiselect process""" + + +class MultiselectClientError(BaseLibp2pError): + """Raised when an error occurs in protocol selection process""" diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 9fd3de8..0c3dc72 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -2,6 +2,7 @@ from typing import Dict, Tuple from libp2p.typing import StreamHandlerFn, TProtocol +from .exceptions import MultiselectError from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_muxer_interface import IMultiselectMuxer @@ -97,7 +98,3 @@ def validate_handshake(handshake_contents: str) -> bool: # TODO: Modify this when format used by go repo for messages # is added return handshake_contents == MULTISELECT_PROTOCOL_ID - - -class MultiselectError(ValueError): - """Raised when an error occurs in multiselect process""" diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index 062aedc..5fcfc45 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -2,6 +2,7 @@ from typing import Sequence from libp2p.typing import TProtocol +from .exceptions import MultiselectClientError from .multiselect_client_interface import IMultiselectClient from .multiselect_communicator_interface import IMultiselectCommunicator @@ -116,7 +117,3 @@ def validate_handshake(handshake_contents: str) -> bool: # TODO: Modify this when format used by go repo for messages # is added return handshake_contents == MULTISELECT_PROTOCOL_ID - - -class MultiselectClientError(ValueError): - """Raised when an error occurs in protocol selection process""" diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 8ce6f41..6cb882a 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -19,7 +19,7 @@ PLAINTEXT_PROTOCOL_ID = TProtocol("/plaintext/2.0.0") class InsecureSession(BaseSession): - async def run_handshake(self): + async def run_handshake(self) -> None: msg = make_exchange_message(self.local_private_key.get_public_key()) msg_bytes = msg.SerializeToString() encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index f52e54a..6e69d7a 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -4,15 +4,11 @@ from typing import Mapping from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID -from libp2p.protocol_muxer.multiselect import Multiselect, MultiselectError -from libp2p.protocol_muxer.multiselect_client import ( - MultiselectClient, - MultiselectClientError, -) +from libp2p.protocol_muxer.multiselect import Multiselect +from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import RawConnectionCommunicator from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_transport_interface import ISecureTransport -from libp2p.transport.exceptions import HandshakeFailure, SecurityUpgradeFailure from libp2p.typing import TProtocol @@ -67,18 +63,8 @@ class SecurityMultistream(ABC): for an inbound connection (i.e. we are not the initiator) :return: secure connection object (that implements secure_conn_interface) """ - try: - transport = await self.select_transport(conn, False) - except MultiselectError as error: - raise SecurityUpgradeFailure( - "failed to negotiate the secure protocol" - ) from error - try: - secure_conn = await transport.secure_inbound(conn) - except HandshakeFailure as error: - raise SecurityUpgradeFailure( - "failed to secure the inbound transport" - ) from error + transport = await self.select_transport(conn, False) + secure_conn = await transport.secure_inbound(conn) return secure_conn async def secure_outbound(self, conn: IRawConnection, peer_id: ID) -> ISecureConn: @@ -87,18 +73,8 @@ class SecurityMultistream(ABC): for an inbound connection (i.e. we are the initiator) :return: secure connection object (that implements secure_conn_interface) """ - try: - transport = await self.select_transport(conn, True) - except MultiselectClientError as error: - raise SecurityUpgradeFailure( - "failed to negotiate the secure protocol" - ) from error - try: - secure_conn = await transport.secure_outbound(conn, peer_id) - except HandshakeFailure as error: - raise SecurityUpgradeFailure( - "failed to secure the outbound transport" - ) from error + transport = await self.select_transport(conn, True) + secure_conn = await transport.secure_outbound(conn, peer_id) return secure_conn async def select_transport( diff --git a/libp2p/transport/exceptions.py b/libp2p/transport/exceptions.py index 2a85bec..b10cfc9 100644 --- a/libp2p/transport/exceptions.py +++ b/libp2p/transport/exceptions.py @@ -10,5 +10,9 @@ class SecurityUpgradeFailure(UpgradeFailure): pass +class MuxerUpgradeFailure(UpgradeFailure): + pass + + class HandshakeFailure(BaseLibp2pError): pass diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index b0373ec..762a811 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -3,11 +3,17 @@ from typing import Mapping from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.peer.id import ID +from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.security_multistream import SecurityMultistream from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream +from libp2p.transport.exceptions import ( + HandshakeFailure, + MuxerUpgradeFailure, + SecurityUpgradeFailure, +) from libp2p.typing import TProtocol from .listener_interface import IListener @@ -39,10 +45,20 @@ class TransportUpgrader: """ Upgrade conn to a secured connection """ - if initiator: - return await self.security_multistream.secure_outbound(raw_conn, peer_id) - - return await self.security_multistream.secure_inbound(raw_conn) + try: + if initiator: + return await self.security_multistream.secure_outbound( + raw_conn, peer_id + ) + return await self.security_multistream.secure_inbound(raw_conn) + except (MultiselectError, MultiselectClientError) as error: + raise SecurityUpgradeFailure( + "failed to negotiate the secure protocol" + ) from error + except HandshakeFailure as error: + raise SecurityUpgradeFailure( + "handshake failed when upgrading to secure connection" + ) from error async def upgrade_connection( self, @@ -53,6 +69,11 @@ class TransportUpgrader: """ Upgrade secured connection to a muxed connection """ - return await self.muxer_multistream.new_conn( - conn, generic_protocol_handler, peer_id - ) + try: + return await self.muxer_multistream.new_conn( + conn, generic_protocol_handler, peer_id + ) + except (MultiselectError, MultiselectClientError) as error: + raise MuxerUpgradeFailure( + "failed to negotiate the multiplexer protocol" + ) from error diff --git a/tests/examples/test_chat.py b/tests/examples/test_chat.py index 0422c95..f461d9d 100644 --- a/tests/examples/test_chat.py +++ b/tests/examples/test_chat.py @@ -3,7 +3,7 @@ import asyncio import pytest from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.protocol_muxer.multiselect_client import MultiselectClientError +from libp2p.protocol_muxer.exceptions import MultiselectClientError from tests.utils import cleanup, set_up_nodes_by_transport_opt PROTOCOL_ID = "/chat/1.0.0" diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 775c460..02f08bd 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,6 +1,6 @@ import pytest -from libp2p.protocol_muxer.multiselect_client import MultiselectClientError +from libp2p.protocol_muxer.exceptions import MultiselectClientError from tests.utils import cleanup, set_up_nodes_by_transport_opt # TODO: Add tests for multiple streams being opened on different diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index ef32b3b..ea78d1f 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -4,9 +4,9 @@ import pytest from libp2p import new_node from libp2p.crypto.rsa import create_new_key_pair +from libp2p.network.exceptions import SwarmException from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from libp2p.security.simple.transport import SimpleSecurityTransport -from libp2p.transport.exceptions import SecurityUpgradeFailure from tests.configs import LISTEN_MADDR from tests.utils import cleanup, connect @@ -161,7 +161,7 @@ async def test_multiple_security_none_the_same_fails(): def assertion_func(_): assert False - with pytest.raises(SecurityUpgradeFailure): + with pytest.raises(SwarmException): await perform_simple_test( assertion_func, transports_for_initiator, transports_for_noninitiator )