Merge pull request #299 from NIC619/add_more_error_handling

Add more error handling
This commit is contained in:
NIC Lin 2019-09-19 23:45:02 +08:00 committed by GitHub
commit 85457fa308
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 217 additions and 67 deletions

View File

@ -0,0 +1,5 @@
from libp2p.io.exceptions import IOException
class RawConnError(IOException):
pass

View File

@ -1,5 +1,6 @@
import asyncio
from .exceptions import RawConnError
from .raw_connection_interface import IRawConnection
@ -23,19 +24,33 @@ class RawConnection(IRawConnection):
self._drain_lock = asyncio.Lock()
async def write(self, data: bytes) -> None:
self.writer.write(data)
"""
Raise `RawConnError` if the underlying connection breaks
"""
try:
self.writer.write(data)
except ConnectionResetError as error:
raise RawConnError(error)
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
# Use a lock to serialize drain() calls. Circumvents this bug:
# https://bugs.python.org/issue29930
async with self._drain_lock:
await self.writer.drain()
try:
await self.writer.drain()
except ConnectionResetError:
raise RawConnError()
async def read(self, n: int = -1) -> bytes:
"""
Read up to ``n`` bytes from the underlying stream.
This call is delegated directly to the underlying ``self.reader``.
Raise `RawConnError` if the underlying connection breaks
"""
return await self.reader.read(n)
try:
return await self.reader.read(n)
except ConnectionResetError as error:
raise RawConnError(error)
async def close(self) -> None:
self.writer.close()

View File

@ -1,7 +1,7 @@
from libp2p.exceptions import BaseLibp2pError
from libp2p.io.exceptions import IOException
class StreamError(BaseLibp2pError):
class StreamError(IOException):
pass

View File

@ -7,12 +7,17 @@ from multiaddr import Multiaddr
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStoreError
from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.protocol_muxer.exceptions import MultiselectClientError
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.routing.interfaces import IPeerRouting
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
from libp2p.transport.exceptions import (
MuxerUpgradeFailure,
OpenConnectionError,
SecurityUpgradeFailure,
)
from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.upgrader import TransportUpgrader
@ -117,7 +122,13 @@ class Swarm(INetwork):
multiaddr = self.router.find_peer(peer_id)
# Dial peer (connection to peer does not yet exist)
# Transport dials peer (gets back a raw conn)
raw_conn = await self.transport.dial(multiaddr)
try:
raw_conn = await self.transport.dial(multiaddr)
except OpenConnectionError as error:
logger.debug("fail to dial peer %s over base transport", peer_id)
raise SwarmException(
"fail to open connection to peer %s", peer_id
) from error
logger.debug("dialed peer %s over base transport", peer_id)
@ -162,6 +173,7 @@ class Swarm(INetwork):
"""
:param peer_id: peer_id of destination
:param protocol_id: protocol id
:raises SwarmException: raised when an error occurs
:return: net stream instance
"""
logger.debug(
@ -176,9 +188,16 @@ class Swarm(INetwork):
muxed_stream = await muxed_conn.open_stream()
# Perform protocol muxing to determine protocol to use
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), MultiselectCommunicator(muxed_stream)
)
try:
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), MultiselectCommunicator(muxed_stream)
)
except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
await muxed_stream.reset()
raise SwarmException(
"failt to open a stream to peer %s", peer_id
) from error
# Create a net stream with the selected protocol
net_stream = NetStream(muxed_stream)

View File

@ -1,6 +1,10 @@
from libp2p.exceptions import BaseLibp2pError
class MultiselectCommunicatorError(BaseLibp2pError):
"""Raised when an error occurs during read/write via communicator"""
class MultiselectError(BaseLibp2pError):
"""Raised when an error occurs in multiselect process"""

View File

@ -2,7 +2,7 @@ from typing import Dict, Tuple
from libp2p.typing import StreamHandlerFn, TProtocol
from .exceptions import MultiselectError
from .exceptions import MultiselectCommunicatorError, MultiselectError
from .multiselect_communicator_interface import IMultiselectCommunicator
from .multiselect_muxer_interface import IMultiselectMuxer
@ -37,7 +37,7 @@ class Multiselect(IMultiselectMuxer):
Negotiate performs protocol selection
:param stream: stream to negotiate on
:return: selected protocol name, handler function
:raise Exception: negotiation failed exception
:raise MultiselectError: raised when negotiation failed
"""
# Perform handshake to ensure multiselect protocol IDs match
@ -46,7 +46,10 @@ class Multiselect(IMultiselectMuxer):
# Read and respond to commands until a valid protocol ID is sent
while True:
# Read message
command = await communicator.read()
try:
command = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
# Command is ls or a protocol
if command == "ls":
@ -56,27 +59,39 @@ class Multiselect(IMultiselectMuxer):
protocol = TProtocol(command)
if protocol in self.handlers:
# Tell counterparty we have decided on a protocol
await communicator.write(protocol)
try:
await communicator.write(protocol)
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
# Return the decided on protocol
return protocol, self.handlers[protocol]
# Tell counterparty this protocol was not found
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
try:
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
"""
Perform handshake to agree on multiselect protocol
:param communicator: communicator to use
:raise Exception: error in handshake
:raise MultiselectError: raised when handshake failed
"""
# TODO: Use format used by go repo for messages
# Send our MULTISELECT_PROTOCOL_ID to other party
await communicator.write(MULTISELECT_PROTOCOL_ID)
try:
await communicator.write(MULTISELECT_PROTOCOL_ID)
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
# Read in the protocol ID from other party
handshake_contents = await communicator.read()
try:
handshake_contents = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
# Confirm that the protocols are the same
if not validate_handshake(handshake_contents):

View File

@ -2,7 +2,7 @@ from typing import Sequence
from libp2p.typing import TProtocol
from .exceptions import MultiselectClientError
from .exceptions import MultiselectClientError, MultiselectCommunicatorError
from .multiselect_client_interface import IMultiselectClient
from .multiselect_communicator_interface import IMultiselectCommunicator
@ -21,16 +21,22 @@ class MultiselectClient(IMultiselectClient):
Ensure that the client and multiselect
are both using the same multiselect protocol
:param stream: stream to communicate with multiselect over
:raise Exception: multiselect protocol ID mismatch
:raise MultiselectClientError: raised when handshake failed
"""
# TODO: Use format used by go repo for messages
# Send our MULTISELECT_PROTOCOL_ID to counterparty
await communicator.write(MULTISELECT_PROTOCOL_ID)
try:
await communicator.write(MULTISELECT_PROTOCOL_ID)
except MultiselectCommunicatorError as error:
raise MultiselectClientError(error)
# Read in the protocol ID from other party
handshake_contents = await communicator.read()
try:
handshake_contents = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectClientError(str(error))
# Confirm that the protocols are the same
if not validate_handshake(handshake_contents):
@ -48,6 +54,7 @@ class MultiselectClient(IMultiselectClient):
:param protocol: protocol to select
:param stream: stream to communicate with multiselect over
:return: selected protocol
:raise MultiselectClientError: raised when protocol negotiation failed
"""
# Perform handshake to ensure multiselect protocol IDs match
await self.handshake(communicator)
@ -71,15 +78,21 @@ class MultiselectClient(IMultiselectClient):
Try to select the given protocol or raise exception if fails
:param communicator: communicator to use to communicate with counterparty
:param protocol: protocol to select
:raise Exception: error in protocol selection
:raise MultiselectClientError: raised when protocol negotiation failed
:return: selected protocol
"""
# Tell counterparty we want to use protocol
await communicator.write(protocol)
try:
await communicator.write(protocol)
except MultiselectCommunicatorError as error:
raise MultiselectClientError(error)
# Get what counterparty says in response
response = await communicator.read()
try:
response = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectClientError(str(error))
# Return protocol if response is equal to protocol or raise error
if response == protocol:

View File

@ -1,6 +1,9 @@
from libp2p.exceptions import ParseError
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
from libp2p.utils import encode_delim, read_delim
from .exceptions import MultiselectCommunicatorError
from .multiselect_communicator_interface import IMultiselectCommunicator
@ -11,9 +14,26 @@ class MultiselectCommunicator(IMultiselectCommunicator):
self.read_writer = read_writer
async def write(self, msg_str: str) -> None:
"""
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
"""
msg_bytes = encode_delim(msg_str.encode())
await self.read_writer.write(msg_bytes)
try:
await self.read_writer.write(msg_bytes)
except IOException:
raise MultiselectCommunicatorError(
"fail to write to multiselect communicator"
)
async def read(self) -> str:
data = await read_delim(self.read_writer)
"""
:raise MultiselectCommunicatorError: raised when failed to read from underlying reader
"""
try:
data = await read_delim(self.read_writer)
# `IOException` includes `IncompleteReadError` and `StreamError`
except (ParseError, IOException, ValueError):
raise MultiselectCommunicatorError(
"fail to read from multiselect communicator"
)
return data.decode()

View File

@ -16,8 +16,11 @@ from typing import (
import base58
from lru import LRU
from libp2p.exceptions import ValidationError
from libp2p.exceptions import ParseError, ValidationError
from libp2p.host.host_interface import IHost
from libp2p.io.exceptions import IncompleteReadError
from libp2p.network.exceptions import SwarmException
from libp2p.network.stream.exceptions import StreamEOF, StreamReset
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
@ -154,7 +157,13 @@ class Pubsub:
peer_id = stream.mplex_conn.peer_id
while True:
incoming: bytes = await read_varint_prefixed_bytes(stream)
try:
incoming: bytes = await read_varint_prefixed_bytes(stream)
except (ParseError, IncompleteReadError) as error:
logger.debug(
"read corrupted data from peer %s, error=%s", peer_id, error
)
continue
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
if rpc_incoming.publish:
@ -228,10 +237,20 @@ class Pubsub:
on one of the supported pubsub protocols.
:param stream: newly created stream
"""
await self.continuously_read_stream(stream)
try:
await self.continuously_read_stream(stream)
except (StreamEOF, StreamReset) as error:
logger.debug("fail to read from stream, error=%s", error)
stream.reset()
# TODO: what to do when the stream is terminated?
# disconnect the peer?
async def _handle_new_peer(self, peer_id: ID) -> None:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
try:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
except SwarmException as error:
logger.debug("fail to add new peer %s, error %s", peer_id, error)
return
self.peers[peer_id] = stream

View File

@ -0,0 +1,5 @@
from libp2p.exceptions import BaseLibp2pError
class HandshakeFailure(BaseLibp2pError):
pass

View File

@ -4,12 +4,13 @@ from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.pb import crypto_pb2
from libp2p.crypto.utils import pubkey_from_protobuf
from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.exceptions import RawConnError
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
from libp2p.security.base_session import BaseSession
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.exceptions import HandshakeFailure
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.transport.exceptions import HandshakeFailure
from libp2p.typing import TProtocol
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
@ -44,12 +45,21 @@ class InsecureSession(BaseSession):
await self.conn.close()
async def run_handshake(self) -> None:
"""
Raise `HandshakeFailure` when handshake failed
"""
msg = make_exchange_message(self.local_private_key.get_public_key())
msg_bytes = msg.SerializeToString()
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
await self.write(encoded_msg_bytes)
try:
await self.write(encoded_msg_bytes)
except RawConnError:
raise HandshakeFailure("connection closed")
remote_msg_bytes = await read_fixedint_prefixed(self.conn)
try:
remote_msg_bytes = await read_fixedint_prefixed(self.conn)
except RawConnError:
raise HandshakeFailure("connection closed")
remote_msg = plaintext_pb2.Exchange()
remote_msg.ParseFromString(remote_msg_bytes)
received_peer_id = ID(remote_msg.id)

View File

@ -1,4 +1,7 @@
class SecioException(Exception):
from libp2p.security.exceptions import HandshakeFailure
class SecioException(HandshakeFailure):
pass
@ -19,9 +22,9 @@ class InvalidSignatureOnExchange(SecioException):
pass
class HandshakeFailed(SecioException):
pass
class IncompatibleChoices(SecioException):
pass
class InconsistentNonce(SecioException):
pass

View File

@ -16,6 +16,7 @@ from libp2p.crypto.ecc import ECCPublicKey
from libp2p.crypto.key_exchange import create_ephemeral_key_pair
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.serialization import deserialize_public_key
from libp2p.io.exceptions import IOException
from libp2p.io.msgio import MsgIOReadWriter
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID as PeerID
@ -24,8 +25,8 @@ from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.secure_conn_interface import ISecureConn
from .exceptions import (
HandshakeFailed,
IncompatibleChoices,
InconsistentNonce,
InvalidSignatureOnExchange,
PeerMismatchException,
SecioException,
@ -399,6 +400,8 @@ async def create_secure_session(
Attempt the initial `secio` handshake with the remote peer.
If successful, return an object that provides secure communication to the
``remote_peer``.
Raise `SecioException` when `conn` closed.
Raise `InconsistentNonce` when handshake failed
"""
msg_io = MsgIOReadWriter(conn)
try:
@ -408,14 +411,21 @@ async def create_secure_session(
except SecioException as e:
await conn.close()
raise e
# `IOException` includes errors raised while read from/write to raw connection
except IOException:
raise SecioException("connection closed")
initiator = remote_peer is not None
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
received_nonce = await _finish_handshake(session, remote_nonce)
try:
received_nonce = await _finish_handshake(session, remote_nonce)
# `IOException` includes errors raised while read from/write to raw connection
except IOException:
raise SecioException("connection closed")
if received_nonce != local_nonce:
await conn.close()
raise HandshakeFailed()
raise InconsistentNonce()
return session

View File

@ -2,8 +2,11 @@ import asyncio
from typing import Any # noqa: F401
from typing import Dict, List, Optional, Tuple
from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError
from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectError
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.typing import TProtocol
@ -125,7 +128,13 @@ class Mplex(IMuxedConn):
"""
stream = await self._initialize_stream(stream_id, name)
# Perform protocol negotiation for the stream.
self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream)))
try:
await self.generic_protocol_handler(stream)
except MultiselectError:
# Un-register and reset the stream
del self.streams[stream_id]
await stream.reset()
return
async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -178,7 +187,11 @@ class Mplex(IMuxedConn):
# `NewStream` for the same id is received twice...
# TODO: Shutdown
pass
await self.accept_stream(stream_id, message.decode())
self._tasks.append(
asyncio.ensure_future(
self.accept_stream(stream_id, message.decode())
)
)
elif flag in (
HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value,
@ -248,13 +261,15 @@ class Mplex(IMuxedConn):
# FIXME: No timeout is used in Go implementation.
# Timeout is set to a relatively small value to alleviate wait time to exit
# loop in handle_incoming
header = await decode_uvarint_from_stream(self.secured_conn)
# TODO: Handle the case of EOF and other exceptions?
try:
header = await decode_uvarint_from_stream(self.secured_conn)
except ParseError:
return None, None, None
try:
message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5
)
except asyncio.TimeoutError:
except (ParseError, IncompleteReadError, asyncio.TimeoutError):
# TODO: Investigate what we should do if time is out.
return None, None, None

View File

@ -1,7 +1,10 @@
from libp2p.exceptions import BaseLibp2pError
# TODO: Add `BaseLibp2pError` and `UpgradeFailure` can inherit from it?
class OpenConnectionError(BaseLibp2pError):
pass
class UpgradeFailure(BaseLibp2pError):
pass
@ -12,7 +15,3 @@ class SecurityUpgradeFailure(UpgradeFailure):
class MuxerUpgradeFailure(UpgradeFailure):
pass
class HandshakeFailure(BaseLibp2pError):
pass

View File

@ -6,6 +6,7 @@ from multiaddr import Multiaddr
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler
@ -62,11 +63,15 @@ class TCP(ITransport):
dial a transport to peer listening on multiaddr
:param maddr: multiaddr of peer
:return: `RawConnection` if successful
:raise OpenConnectionError: raised when failed to open connection
"""
self.host = maddr.value_for_protocol("ip4")
self.port = int(maddr.value_for_protocol("tcp"))
reader, writer = await asyncio.open_connection(self.host, self.port)
try:
reader, writer = await asyncio.open_connection(self.host, self.port)
except (ConnectionAbortedError, ConnectionRefusedError) as error:
raise OpenConnectionError(error)
return RawConnection(reader, writer, True)

View File

@ -4,16 +4,13 @@ from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.security.exceptions import HandshakeFailure
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.security.security_multistream import SecurityMultistream
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream
from libp2p.transport.exceptions import (
HandshakeFailure,
MuxerUpgradeFailure,
SecurityUpgradeFailure,
)
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
from libp2p.typing import TProtocol
from .listener_interface import IListener

View File

@ -65,10 +65,6 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
len_msg = await decode_uvarint_from_stream(reader)
data = await read_exactly(reader, len_msg)
if len(data) != len_msg:
raise ValueError(
f"failed to read enough bytes: len_msg={len_msg}, data={data!r}"
)
return data

View File

@ -2,8 +2,8 @@ import asyncio
import pytest
from libp2p.network.exceptions import SwarmException
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.protocol_muxer.exceptions import MultiselectClientError
from tests.utils import set_up_nodes_by_transport_opt
PROTOCOL_ID = "/chat/1.0.0"
@ -84,7 +84,7 @@ async def no_common_protocol(host_a, host_b):
host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
# try to creates a new new with a procotol not known by the other host
with pytest.raises(MultiselectClientError):
with pytest.raises(SwarmException):
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])

View File

@ -1,6 +1,6 @@
import pytest
from libp2p.protocol_muxer.exceptions import MultiselectClientError
from libp2p.network.exceptions import SwarmException
from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt
# TODO: Add tests for multiple streams being opened on different
@ -47,7 +47,7 @@ async def test_single_protocol_succeeds():
@pytest.mark.asyncio
async def test_single_protocol_fails():
with pytest.raises(MultiselectClientError):
with pytest.raises(SwarmException):
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"])
# Cleanup not reached on error
@ -77,7 +77,7 @@ async def test_multiple_protocol_second_is_valid_succeeds():
async def test_multiple_protocol_fails():
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"]
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
with pytest.raises(MultiselectClientError):
with pytest.raises(SwarmException):
await perform_simple_test("", protocols_for_client, protocols_for_listener)
# Cleanup not reached on error