diff --git a/libp2p/network/connection/exceptions.py b/libp2p/network/connection/exceptions.py new file mode 100644 index 0000000..ecbf3fa --- /dev/null +++ b/libp2p/network/connection/exceptions.py @@ -0,0 +1,5 @@ +from libp2p.io.exceptions import IOException + + +class RawConnError(IOException): + pass diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index fe09c6f..144c1a8 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,5 +1,6 @@ import asyncio +from .exceptions import RawConnError from .raw_connection_interface import IRawConnection @@ -23,19 +24,33 @@ class RawConnection(IRawConnection): self._drain_lock = asyncio.Lock() async def write(self, data: bytes) -> None: - self.writer.write(data) + """ + Raise `RawConnError` if the underlying connection breaks + """ + try: + self.writer.write(data) + except ConnectionResetError as error: + raise RawConnError(error) # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501 # Use a lock to serialize drain() calls. Circumvents this bug: # https://bugs.python.org/issue29930 async with self._drain_lock: - await self.writer.drain() + try: + await self.writer.drain() + except ConnectionResetError: + raise RawConnError() async def read(self, n: int = -1) -> bytes: """ Read up to ``n`` bytes from the underlying stream. This call is delegated directly to the underlying ``self.reader``. + + Raise `RawConnError` if the underlying connection breaks """ - return await self.reader.read(n) + try: + return await self.reader.read(n) + except ConnectionResetError as error: + raise RawConnError(error) async def close(self) -> None: self.writer.close() diff --git a/libp2p/network/stream/exceptions.py b/libp2p/network/stream/exceptions.py index 58f3ddf..7af28ec 100644 --- a/libp2p/network/stream/exceptions.py +++ b/libp2p/network/stream/exceptions.py @@ -1,7 +1,7 @@ -from libp2p.exceptions import BaseLibp2pError +from libp2p.io.exceptions import IOException -class StreamError(BaseLibp2pError): +class StreamError(IOException): pass diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 4d53b80..da32a02 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -7,12 +7,17 @@ from multiaddr import Multiaddr from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStoreError from libp2p.peer.peerstore_interface import IPeerStore +from libp2p.protocol_muxer.exceptions import MultiselectClientError from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.routing.interfaces import IPeerRouting from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream -from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure +from libp2p.transport.exceptions import ( + MuxerUpgradeFailure, + OpenConnectionError, + SecurityUpgradeFailure, +) from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.upgrader import TransportUpgrader @@ -117,7 +122,13 @@ class Swarm(INetwork): multiaddr = self.router.find_peer(peer_id) # Dial peer (connection to peer does not yet exist) # Transport dials peer (gets back a raw conn) - raw_conn = await self.transport.dial(multiaddr) + try: + raw_conn = await self.transport.dial(multiaddr) + except OpenConnectionError as error: + logger.debug("fail to dial peer %s over base transport", peer_id) + raise SwarmException( + "fail to open connection to peer %s", peer_id + ) from error logger.debug("dialed peer %s over base transport", peer_id) @@ -162,6 +173,7 @@ class Swarm(INetwork): """ :param peer_id: peer_id of destination :param protocol_id: protocol id + :raises SwarmException: raised when an error occurs :return: net stream instance """ logger.debug( @@ -176,9 +188,16 @@ class Swarm(INetwork): muxed_stream = await muxed_conn.open_stream() # Perform protocol muxing to determine protocol to use - selected_protocol = await self.multiselect_client.select_one_of( - list(protocol_ids), MultiselectCommunicator(muxed_stream) - ) + try: + selected_protocol = await self.multiselect_client.select_one_of( + list(protocol_ids), MultiselectCommunicator(muxed_stream) + ) + except MultiselectClientError as error: + logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error) + await muxed_stream.reset() + raise SwarmException( + "failt to open a stream to peer %s", peer_id + ) from error # Create a net stream with the selected protocol net_stream = NetStream(muxed_stream) diff --git a/libp2p/protocol_muxer/exceptions.py b/libp2p/protocol_muxer/exceptions.py index cf47aca..a34e318 100644 --- a/libp2p/protocol_muxer/exceptions.py +++ b/libp2p/protocol_muxer/exceptions.py @@ -1,6 +1,10 @@ from libp2p.exceptions import BaseLibp2pError +class MultiselectCommunicatorError(BaseLibp2pError): + """Raised when an error occurs during read/write via communicator""" + + class MultiselectError(BaseLibp2pError): """Raised when an error occurs in multiselect process""" diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 0c3dc72..a0fa91f 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -2,7 +2,7 @@ from typing import Dict, Tuple from libp2p.typing import StreamHandlerFn, TProtocol -from .exceptions import MultiselectError +from .exceptions import MultiselectCommunicatorError, MultiselectError from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_muxer_interface import IMultiselectMuxer @@ -37,7 +37,7 @@ class Multiselect(IMultiselectMuxer): Negotiate performs protocol selection :param stream: stream to negotiate on :return: selected protocol name, handler function - :raise Exception: negotiation failed exception + :raise MultiselectError: raised when negotiation failed """ # Perform handshake to ensure multiselect protocol IDs match @@ -46,7 +46,10 @@ class Multiselect(IMultiselectMuxer): # Read and respond to commands until a valid protocol ID is sent while True: # Read message - command = await communicator.read() + try: + command = await communicator.read() + except MultiselectCommunicatorError as error: + raise MultiselectError(error) # Command is ls or a protocol if command == "ls": @@ -56,27 +59,39 @@ class Multiselect(IMultiselectMuxer): protocol = TProtocol(command) if protocol in self.handlers: # Tell counterparty we have decided on a protocol - await communicator.write(protocol) + try: + await communicator.write(protocol) + except MultiselectCommunicatorError as error: + raise MultiselectError(error) # Return the decided on protocol return protocol, self.handlers[protocol] # Tell counterparty this protocol was not found - await communicator.write(PROTOCOL_NOT_FOUND_MSG) + try: + await communicator.write(PROTOCOL_NOT_FOUND_MSG) + except MultiselectCommunicatorError as error: + raise MultiselectError(error) async def handshake(self, communicator: IMultiselectCommunicator) -> None: """ Perform handshake to agree on multiselect protocol :param communicator: communicator to use - :raise Exception: error in handshake + :raise MultiselectError: raised when handshake failed """ # TODO: Use format used by go repo for messages # Send our MULTISELECT_PROTOCOL_ID to other party - await communicator.write(MULTISELECT_PROTOCOL_ID) + try: + await communicator.write(MULTISELECT_PROTOCOL_ID) + except MultiselectCommunicatorError as error: + raise MultiselectError(error) # Read in the protocol ID from other party - handshake_contents = await communicator.read() + try: + handshake_contents = await communicator.read() + except MultiselectCommunicatorError as error: + raise MultiselectError(error) # Confirm that the protocols are the same if not validate_handshake(handshake_contents): diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index fcd55d0..24db70a 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -2,7 +2,7 @@ from typing import Sequence from libp2p.typing import TProtocol -from .exceptions import MultiselectClientError +from .exceptions import MultiselectClientError, MultiselectCommunicatorError from .multiselect_client_interface import IMultiselectClient from .multiselect_communicator_interface import IMultiselectCommunicator @@ -21,16 +21,22 @@ class MultiselectClient(IMultiselectClient): Ensure that the client and multiselect are both using the same multiselect protocol :param stream: stream to communicate with multiselect over - :raise Exception: multiselect protocol ID mismatch + :raise MultiselectClientError: raised when handshake failed """ # TODO: Use format used by go repo for messages # Send our MULTISELECT_PROTOCOL_ID to counterparty - await communicator.write(MULTISELECT_PROTOCOL_ID) + try: + await communicator.write(MULTISELECT_PROTOCOL_ID) + except MultiselectCommunicatorError as error: + raise MultiselectClientError(error) # Read in the protocol ID from other party - handshake_contents = await communicator.read() + try: + handshake_contents = await communicator.read() + except MultiselectCommunicatorError as error: + raise MultiselectClientError(str(error)) # Confirm that the protocols are the same if not validate_handshake(handshake_contents): @@ -48,6 +54,7 @@ class MultiselectClient(IMultiselectClient): :param protocol: protocol to select :param stream: stream to communicate with multiselect over :return: selected protocol + :raise MultiselectClientError: raised when protocol negotiation failed """ # Perform handshake to ensure multiselect protocol IDs match await self.handshake(communicator) @@ -71,15 +78,21 @@ class MultiselectClient(IMultiselectClient): Try to select the given protocol or raise exception if fails :param communicator: communicator to use to communicate with counterparty :param protocol: protocol to select - :raise Exception: error in protocol selection + :raise MultiselectClientError: raised when protocol negotiation failed :return: selected protocol """ # Tell counterparty we want to use protocol - await communicator.write(protocol) + try: + await communicator.write(protocol) + except MultiselectCommunicatorError as error: + raise MultiselectClientError(error) # Get what counterparty says in response - response = await communicator.read() + try: + response = await communicator.read() + except MultiselectCommunicatorError as error: + raise MultiselectClientError(str(error)) # Return protocol if response is equal to protocol or raise error if response == protocol: diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index 59252c5..a66a564 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -1,6 +1,9 @@ +from libp2p.exceptions import ParseError from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException from libp2p.utils import encode_delim, read_delim +from .exceptions import MultiselectCommunicatorError from .multiselect_communicator_interface import IMultiselectCommunicator @@ -11,9 +14,26 @@ class MultiselectCommunicator(IMultiselectCommunicator): self.read_writer = read_writer async def write(self, msg_str: str) -> None: + """ + :raise MultiselectCommunicatorError: raised when failed to write to underlying reader + """ msg_bytes = encode_delim(msg_str.encode()) - await self.read_writer.write(msg_bytes) + try: + await self.read_writer.write(msg_bytes) + except IOException: + raise MultiselectCommunicatorError( + "fail to write to multiselect communicator" + ) async def read(self) -> str: - data = await read_delim(self.read_writer) + """ + :raise MultiselectCommunicatorError: raised when failed to read from underlying reader + """ + try: + data = await read_delim(self.read_writer) + # `IOException` includes `IncompleteReadError` and `StreamError` + except (ParseError, IOException, ValueError): + raise MultiselectCommunicatorError( + "fail to read from multiselect communicator" + ) return data.decode() diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index b162b89..e413b28 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -16,8 +16,11 @@ from typing import ( import base58 from lru import LRU -from libp2p.exceptions import ValidationError +from libp2p.exceptions import ParseError, ValidationError from libp2p.host.host_interface import IHost +from libp2p.io.exceptions import IncompleteReadError +from libp2p.network.exceptions import SwarmException +from libp2p.network.stream.exceptions import StreamEOF, StreamReset from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID from libp2p.typing import TProtocol @@ -154,7 +157,13 @@ class Pubsub: peer_id = stream.mplex_conn.peer_id while True: - incoming: bytes = await read_varint_prefixed_bytes(stream) + try: + incoming: bytes = await read_varint_prefixed_bytes(stream) + except (ParseError, IncompleteReadError) as error: + logger.debug( + "read corrupted data from peer %s, error=%s", peer_id, error + ) + continue rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming.ParseFromString(incoming) if rpc_incoming.publish: @@ -228,10 +237,20 @@ class Pubsub: on one of the supported pubsub protocols. :param stream: newly created stream """ - await self.continuously_read_stream(stream) + try: + await self.continuously_read_stream(stream) + except (StreamEOF, StreamReset) as error: + logger.debug("fail to read from stream, error=%s", error) + stream.reset() + # TODO: what to do when the stream is terminated? + # disconnect the peer? async def _handle_new_peer(self, peer_id: ID) -> None: - stream: INetStream = await self.host.new_stream(peer_id, self.protocols) + try: + stream: INetStream = await self.host.new_stream(peer_id, self.protocols) + except SwarmException as error: + logger.debug("fail to add new peer %s, error %s", peer_id, error) + return self.peers[peer_id] = stream diff --git a/libp2p/security/exceptions.py b/libp2p/security/exceptions.py new file mode 100644 index 0000000..269b2cb --- /dev/null +++ b/libp2p/security/exceptions.py @@ -0,0 +1,5 @@ +from libp2p.exceptions import BaseLibp2pError + + +class HandshakeFailure(BaseLibp2pError): + pass diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 27efc86..7df0575 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -4,12 +4,13 @@ from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.pb import crypto_pb2 from libp2p.crypto.utils import pubkey_from_protobuf from libp2p.io.abc import ReadWriteCloser +from libp2p.network.connection.exceptions import RawConnError from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.base_session import BaseSession from libp2p.security.base_transport import BaseSecureTransport +from libp2p.security.exceptions import HandshakeFailure from libp2p.security.secure_conn_interface import ISecureConn -from libp2p.transport.exceptions import HandshakeFailure from libp2p.typing import TProtocol from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed @@ -44,12 +45,21 @@ class InsecureSession(BaseSession): await self.conn.close() async def run_handshake(self) -> None: + """ + Raise `HandshakeFailure` when handshake failed + """ msg = make_exchange_message(self.local_private_key.get_public_key()) msg_bytes = msg.SerializeToString() encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) - await self.write(encoded_msg_bytes) + try: + await self.write(encoded_msg_bytes) + except RawConnError: + raise HandshakeFailure("connection closed") - remote_msg_bytes = await read_fixedint_prefixed(self.conn) + try: + remote_msg_bytes = await read_fixedint_prefixed(self.conn) + except RawConnError: + raise HandshakeFailure("connection closed") remote_msg = plaintext_pb2.Exchange() remote_msg.ParseFromString(remote_msg_bytes) received_peer_id = ID(remote_msg.id) diff --git a/libp2p/security/secio/exceptions.py b/libp2p/security/secio/exceptions.py index f9ea8cf..c03fda4 100644 --- a/libp2p/security/secio/exceptions.py +++ b/libp2p/security/secio/exceptions.py @@ -1,4 +1,7 @@ -class SecioException(Exception): +from libp2p.security.exceptions import HandshakeFailure + + +class SecioException(HandshakeFailure): pass @@ -19,9 +22,9 @@ class InvalidSignatureOnExchange(SecioException): pass -class HandshakeFailed(SecioException): - pass - - class IncompatibleChoices(SecioException): pass + + +class InconsistentNonce(SecioException): + pass diff --git a/libp2p/security/secio/transport.py b/libp2p/security/secio/transport.py index e223a94..e1aa022 100644 --- a/libp2p/security/secio/transport.py +++ b/libp2p/security/secio/transport.py @@ -16,6 +16,7 @@ from libp2p.crypto.ecc import ECCPublicKey from libp2p.crypto.key_exchange import create_ephemeral_key_pair from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.serialization import deserialize_public_key +from libp2p.io.exceptions import IOException from libp2p.io.msgio import MsgIOReadWriter from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID as PeerID @@ -24,8 +25,8 @@ from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.secure_conn_interface import ISecureConn from .exceptions import ( - HandshakeFailed, IncompatibleChoices, + InconsistentNonce, InvalidSignatureOnExchange, PeerMismatchException, SecioException, @@ -399,6 +400,8 @@ async def create_secure_session( Attempt the initial `secio` handshake with the remote peer. If successful, return an object that provides secure communication to the ``remote_peer``. + Raise `SecioException` when `conn` closed. + Raise `InconsistentNonce` when handshake failed """ msg_io = MsgIOReadWriter(conn) try: @@ -408,14 +411,21 @@ async def create_secure_session( except SecioException as e: await conn.close() raise e + # `IOException` includes errors raised while read from/write to raw connection + except IOException: + raise SecioException("connection closed") initiator = remote_peer is not None session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator) - received_nonce = await _finish_handshake(session, remote_nonce) + try: + received_nonce = await _finish_handshake(session, remote_nonce) + # `IOException` includes errors raised while read from/write to raw connection + except IOException: + raise SecioException("connection closed") if received_nonce != local_nonce: await conn.close() - raise HandshakeFailed() + raise InconsistentNonce() return session diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index c75000d..7a5323d 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -2,8 +2,11 @@ import asyncio from typing import Any # noqa: F401 from typing import Dict, List, Optional, Tuple +from libp2p.exceptions import ParseError +from libp2p.io.exceptions import IncompleteReadError from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.peer.id import ID +from libp2p.protocol_muxer.exceptions import MultiselectError from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.typing import TProtocol @@ -125,7 +128,13 @@ class Mplex(IMuxedConn): """ stream = await self._initialize_stream(stream_id, name) # Perform protocol negotiation for the stream. - self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream))) + try: + await self.generic_protocol_handler(stream) + except MultiselectError: + # Un-register and reset the stream + del self.streams[stream_id] + await stream.reset() + return async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -178,7 +187,11 @@ class Mplex(IMuxedConn): # `NewStream` for the same id is received twice... # TODO: Shutdown pass - await self.accept_stream(stream_id, message.decode()) + self._tasks.append( + asyncio.ensure_future( + self.accept_stream(stream_id, message.decode()) + ) + ) elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, @@ -248,13 +261,15 @@ class Mplex(IMuxedConn): # FIXME: No timeout is used in Go implementation. # Timeout is set to a relatively small value to alleviate wait time to exit # loop in handle_incoming - header = await decode_uvarint_from_stream(self.secured_conn) - # TODO: Handle the case of EOF and other exceptions? + try: + header = await decode_uvarint_from_stream(self.secured_conn) + except ParseError: + return None, None, None try: message = await asyncio.wait_for( read_varint_prefixed_bytes(self.secured_conn), timeout=5 ) - except asyncio.TimeoutError: + except (ParseError, IncompleteReadError, asyncio.TimeoutError): # TODO: Investigate what we should do if time is out. return None, None, None diff --git a/libp2p/transport/exceptions.py b/libp2p/transport/exceptions.py index b10cfc9..d935b3a 100644 --- a/libp2p/transport/exceptions.py +++ b/libp2p/transport/exceptions.py @@ -1,7 +1,10 @@ from libp2p.exceptions import BaseLibp2pError -# TODO: Add `BaseLibp2pError` and `UpgradeFailure` can inherit from it? +class OpenConnectionError(BaseLibp2pError): + pass + + class UpgradeFailure(BaseLibp2pError): pass @@ -12,7 +15,3 @@ class SecurityUpgradeFailure(UpgradeFailure): class MuxerUpgradeFailure(UpgradeFailure): pass - - -class HandshakeFailure(BaseLibp2pError): - pass diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index a63dbd0..5ee2428 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -6,6 +6,7 @@ from multiaddr import Multiaddr from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection +from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler @@ -62,11 +63,15 @@ class TCP(ITransport): dial a transport to peer listening on multiaddr :param maddr: multiaddr of peer :return: `RawConnection` if successful + :raise OpenConnectionError: raised when failed to open connection """ self.host = maddr.value_for_protocol("ip4") self.port = int(maddr.value_for_protocol("tcp")) - reader, writer = await asyncio.open_connection(self.host, self.port) + try: + reader, writer = await asyncio.open_connection(self.host, self.port) + except (ConnectionAbortedError, ConnectionRefusedError) as error: + raise OpenConnectionError(error) return RawConnection(reader, writer, True) diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 762a811..96234c6 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -4,16 +4,13 @@ from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.peer.id import ID from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError +from libp2p.security.exceptions import HandshakeFailure from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.security_multistream import SecurityMultistream from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream -from libp2p.transport.exceptions import ( - HandshakeFailure, - MuxerUpgradeFailure, - SecurityUpgradeFailure, -) +from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure from libp2p.typing import TProtocol from .listener_interface import IListener diff --git a/libp2p/utils.py b/libp2p/utils.py index c69f61b..8362a5a 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -65,10 +65,6 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes: async def read_varint_prefixed_bytes(reader: Reader) -> bytes: len_msg = await decode_uvarint_from_stream(reader) data = await read_exactly(reader, len_msg) - if len(data) != len_msg: - raise ValueError( - f"failed to read enough bytes: len_msg={len_msg}, data={data!r}" - ) return data diff --git a/tests/examples/test_chat.py b/tests/examples/test_chat.py index 18a172c..e2aa71d 100644 --- a/tests/examples/test_chat.py +++ b/tests/examples/test_chat.py @@ -2,8 +2,8 @@ import asyncio import pytest +from libp2p.network.exceptions import SwarmException from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.protocol_muxer.exceptions import MultiselectClientError from tests.utils import set_up_nodes_by_transport_opt PROTOCOL_ID = "/chat/1.0.0" @@ -84,7 +84,7 @@ async def no_common_protocol(host_a, host_b): host_a.set_stream_handler(PROTOCOL_ID, stream_handler) # try to creates a new new with a procotol not known by the other host - with pytest.raises(MultiselectClientError): + with pytest.raises(SwarmException): await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"]) diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index d7523ac..4e58e5b 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,6 +1,6 @@ import pytest -from libp2p.protocol_muxer.exceptions import MultiselectClientError +from libp2p.network.exceptions import SwarmException from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt # TODO: Add tests for multiple streams being opened on different @@ -47,7 +47,7 @@ async def test_single_protocol_succeeds(): @pytest.mark.asyncio async def test_single_protocol_fails(): - with pytest.raises(MultiselectClientError): + with pytest.raises(SwarmException): await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) # Cleanup not reached on error @@ -77,7 +77,7 @@ async def test_multiple_protocol_second_is_valid_succeeds(): async def test_multiple_protocol_fails(): protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"] protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] - with pytest.raises(MultiselectClientError): + with pytest.raises(SwarmException): await perform_simple_test("", protocols_for_client, protocols_for_listener) # Cleanup not reached on error