Merge branch 'master' into fix/detection-of-close

This commit is contained in:
mhchia 2019-09-21 18:05:54 +08:00
commit e44c2145cc
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
22 changed files with 226 additions and 68 deletions

View File

@ -1,13 +1,16 @@
import asyncio import asyncio
import logging
from typing import List, Sequence from typing import List, Sequence
import multiaddr import multiaddr
from libp2p.host.exceptions import StreamFailure
from libp2p.network.network_interface import INetwork from libp2p.network.network_interface import INetwork
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
@ -22,6 +25,10 @@ from .host_interface import IHost
# telling it to listen on the given listen addresses. # telling it to listen on the given listen addresses.
logger = logging.getLogger("libp2p.network.basic_host")
logger.setLevel(logging.DEBUG)
class BasicHost(IHost): class BasicHost(IHost):
""" """
BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation
@ -103,9 +110,14 @@ class BasicHost(IHost):
net_stream = await self._network.new_stream(peer_id, protocol_ids) net_stream = await self._network.new_stream(peer_id, protocol_ids)
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
selected_protocol = await self.multiselect_client.select_one_of( try:
list(protocol_ids), MultiselectCommunicator(net_stream) selected_protocol = await self.multiselect_client.select_one_of(
) list(protocol_ids), MultiselectCommunicator(net_stream)
)
except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
await net_stream.reset()
raise StreamFailure("failt to open a stream to peer %s", peer_id) from error
net_stream.set_protocol(selected_protocol) net_stream.set_protocol(selected_protocol)
return net_stream return net_stream
@ -137,8 +149,12 @@ class BasicHost(IHost):
# Reference: `BasicHost.newStreamHandler` in Go. # Reference: `BasicHost.newStreamHandler` in Go.
async def _swarm_stream_handler(self, net_stream: INetStream) -> None: async def _swarm_stream_handler(self, net_stream: INetStream) -> None:
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
protocol, handler = await self.multiselect.negotiate( try:
MultiselectCommunicator(net_stream) protocol, handler = await self.multiselect.negotiate(
) MultiselectCommunicator(net_stream)
)
except MultiselectError:
await net_stream.reset()
return
net_stream.set_protocol(protocol) net_stream.set_protocol(protocol)
asyncio.ensure_future(handler(net_stream)) asyncio.ensure_future(handler(net_stream))

15
libp2p/host/exceptions.py Normal file
View File

@ -0,0 +1,15 @@
from libp2p.exceptions import BaseLibp2pError
class HostException(BaseLibp2pError):
"""
A generic exception in `IHost`.
"""
class ConnectionFailure(HostException):
pass
class StreamFailure(HostException):
pass

View File

@ -0,0 +1,5 @@
from libp2p.io.exceptions import IOException
class RawConnError(IOException):
pass

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
from .exceptions import RawConnError
from .raw_connection_interface import IRawConnection from .raw_connection_interface import IRawConnection
@ -23,19 +24,33 @@ class RawConnection(IRawConnection):
self._drain_lock = asyncio.Lock() self._drain_lock = asyncio.Lock()
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
self.writer.write(data) """
Raise `RawConnError` if the underlying connection breaks
"""
try:
self.writer.write(data)
except ConnectionResetError as error:
raise RawConnError(error)
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501 # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
# Use a lock to serialize drain() calls. Circumvents this bug: # Use a lock to serialize drain() calls. Circumvents this bug:
# https://bugs.python.org/issue29930 # https://bugs.python.org/issue29930
async with self._drain_lock: async with self._drain_lock:
await self.writer.drain() try:
await self.writer.drain()
except ConnectionResetError as error:
raise RawConnError(error)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
""" """
Read up to ``n`` bytes from the underlying stream. Read up to ``n`` bytes from the underlying stream.
This call is delegated directly to the underlying ``self.reader``. This call is delegated directly to the underlying ``self.reader``.
Raise `RawConnError` if the underlying connection breaks
""" """
return await self.reader.read(n) try:
return await self.reader.read(n)
except ConnectionResetError as error:
raise RawConnError(error)
async def close(self) -> None: async def close(self) -> None:
self.writer.close() self.writer.close()

View File

@ -1,7 +1,7 @@
from libp2p.exceptions import BaseLibp2pError from libp2p.io.exceptions import IOException
class StreamError(BaseLibp2pError): class StreamError(IOException):
pass pass

View File

@ -10,7 +10,11 @@ from libp2p.peer.peerstore import PeerStoreError
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.interfaces import IPeerRouting
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure from libp2p.transport.exceptions import (
MuxerUpgradeFailure,
OpenConnectionError,
SecurityUpgradeFailure,
)
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport from libp2p.transport.transport_interface import ITransport
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
@ -99,7 +103,13 @@ class Swarm(INetwork):
multiaddr = self.router.find_peer(peer_id) multiaddr = self.router.find_peer(peer_id)
# Dial peer (connection to peer does not yet exist) # Dial peer (connection to peer does not yet exist)
# Transport dials peer (gets back a raw conn) # Transport dials peer (gets back a raw conn)
raw_conn = await self.transport.dial(multiaddr) try:
raw_conn = await self.transport.dial(multiaddr)
except OpenConnectionError as error:
logger.debug("fail to dial peer %s over base transport", peer_id)
raise SwarmException(
"fail to open connection to peer %s", peer_id
) from error
logger.debug("dialed peer %s over base transport", peer_id) logger.debug("dialed peer %s over base transport", peer_id)
@ -137,6 +147,7 @@ class Swarm(INetwork):
""" """
:param peer_id: peer_id of destination :param peer_id: peer_id of destination
:param protocol_id: protocol id :param protocol_id: protocol id
:raises SwarmException: raised when an error occurs
:return: net stream instance :return: net stream instance
""" """
logger.debug( logger.debug(

View File

@ -1,6 +1,10 @@
from libp2p.exceptions import BaseLibp2pError from libp2p.exceptions import BaseLibp2pError
class MultiselectCommunicatorError(BaseLibp2pError):
"""Raised when an error occurs during read/write via communicator"""
class MultiselectError(BaseLibp2pError): class MultiselectError(BaseLibp2pError):
"""Raised when an error occurs in multiselect process""" """Raised when an error occurs in multiselect process"""

View File

@ -2,7 +2,7 @@ from typing import Dict, Tuple
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
from .exceptions import MultiselectError from .exceptions import MultiselectCommunicatorError, MultiselectError
from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_communicator_interface import IMultiselectCommunicator
from .multiselect_muxer_interface import IMultiselectMuxer from .multiselect_muxer_interface import IMultiselectMuxer
@ -37,7 +37,7 @@ class Multiselect(IMultiselectMuxer):
Negotiate performs protocol selection Negotiate performs protocol selection
:param stream: stream to negotiate on :param stream: stream to negotiate on
:return: selected protocol name, handler function :return: selected protocol name, handler function
:raise Exception: negotiation failed exception :raise MultiselectError: raised when negotiation failed
""" """
# Perform handshake to ensure multiselect protocol IDs match # Perform handshake to ensure multiselect protocol IDs match
@ -46,7 +46,10 @@ class Multiselect(IMultiselectMuxer):
# Read and respond to commands until a valid protocol ID is sent # Read and respond to commands until a valid protocol ID is sent
while True: while True:
# Read message # Read message
command = await communicator.read() try:
command = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
# Command is ls or a protocol # Command is ls or a protocol
if command == "ls": if command == "ls":
@ -56,27 +59,39 @@ class Multiselect(IMultiselectMuxer):
protocol = TProtocol(command) protocol = TProtocol(command)
if protocol in self.handlers: if protocol in self.handlers:
# Tell counterparty we have decided on a protocol # Tell counterparty we have decided on a protocol
await communicator.write(protocol) try:
await communicator.write(protocol)
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
# Return the decided on protocol # Return the decided on protocol
return protocol, self.handlers[protocol] return protocol, self.handlers[protocol]
# Tell counterparty this protocol was not found # Tell counterparty this protocol was not found
await communicator.write(PROTOCOL_NOT_FOUND_MSG) try:
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
async def handshake(self, communicator: IMultiselectCommunicator) -> None: async def handshake(self, communicator: IMultiselectCommunicator) -> None:
""" """
Perform handshake to agree on multiselect protocol Perform handshake to agree on multiselect protocol
:param communicator: communicator to use :param communicator: communicator to use
:raise Exception: error in handshake :raise MultiselectError: raised when handshake failed
""" """
# TODO: Use format used by go repo for messages # TODO: Use format used by go repo for messages
# Send our MULTISELECT_PROTOCOL_ID to other party # Send our MULTISELECT_PROTOCOL_ID to other party
await communicator.write(MULTISELECT_PROTOCOL_ID) try:
await communicator.write(MULTISELECT_PROTOCOL_ID)
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
# Read in the protocol ID from other party # Read in the protocol ID from other party
handshake_contents = await communicator.read() try:
handshake_contents = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectError(error)
# Confirm that the protocols are the same # Confirm that the protocols are the same
if not validate_handshake(handshake_contents): if not validate_handshake(handshake_contents):

View File

@ -2,7 +2,7 @@ from typing import Sequence
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .exceptions import MultiselectClientError from .exceptions import MultiselectClientError, MultiselectCommunicatorError
from .multiselect_client_interface import IMultiselectClient from .multiselect_client_interface import IMultiselectClient
from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_communicator_interface import IMultiselectCommunicator
@ -21,16 +21,22 @@ class MultiselectClient(IMultiselectClient):
Ensure that the client and multiselect Ensure that the client and multiselect
are both using the same multiselect protocol are both using the same multiselect protocol
:param stream: stream to communicate with multiselect over :param stream: stream to communicate with multiselect over
:raise Exception: multiselect protocol ID mismatch :raise MultiselectClientError: raised when handshake failed
""" """
# TODO: Use format used by go repo for messages # TODO: Use format used by go repo for messages
# Send our MULTISELECT_PROTOCOL_ID to counterparty # Send our MULTISELECT_PROTOCOL_ID to counterparty
await communicator.write(MULTISELECT_PROTOCOL_ID) try:
await communicator.write(MULTISELECT_PROTOCOL_ID)
except MultiselectCommunicatorError as error:
raise MultiselectClientError(error)
# Read in the protocol ID from other party # Read in the protocol ID from other party
handshake_contents = await communicator.read() try:
handshake_contents = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectClientError(str(error))
# Confirm that the protocols are the same # Confirm that the protocols are the same
if not validate_handshake(handshake_contents): if not validate_handshake(handshake_contents):
@ -48,6 +54,7 @@ class MultiselectClient(IMultiselectClient):
:param protocol: protocol to select :param protocol: protocol to select
:param stream: stream to communicate with multiselect over :param stream: stream to communicate with multiselect over
:return: selected protocol :return: selected protocol
:raise MultiselectClientError: raised when protocol negotiation failed
""" """
# Perform handshake to ensure multiselect protocol IDs match # Perform handshake to ensure multiselect protocol IDs match
await self.handshake(communicator) await self.handshake(communicator)
@ -71,15 +78,21 @@ class MultiselectClient(IMultiselectClient):
Try to select the given protocol or raise exception if fails Try to select the given protocol or raise exception if fails
:param communicator: communicator to use to communicate with counterparty :param communicator: communicator to use to communicate with counterparty
:param protocol: protocol to select :param protocol: protocol to select
:raise Exception: error in protocol selection :raise MultiselectClientError: raised when protocol negotiation failed
:return: selected protocol :return: selected protocol
""" """
# Tell counterparty we want to use protocol # Tell counterparty we want to use protocol
await communicator.write(protocol) try:
await communicator.write(protocol)
except MultiselectCommunicatorError as error:
raise MultiselectClientError(error)
# Get what counterparty says in response # Get what counterparty says in response
response = await communicator.read() try:
response = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectClientError(str(error))
# Return protocol if response is equal to protocol or raise error # Return protocol if response is equal to protocol or raise error
if response == protocol: if response == protocol:

View File

@ -1,6 +1,9 @@
from libp2p.exceptions import ParseError
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
from libp2p.utils import encode_delim, read_delim from libp2p.utils import encode_delim, read_delim
from .exceptions import MultiselectCommunicatorError
from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_communicator_interface import IMultiselectCommunicator
@ -11,9 +14,26 @@ class MultiselectCommunicator(IMultiselectCommunicator):
self.read_writer = read_writer self.read_writer = read_writer
async def write(self, msg_str: str) -> None: async def write(self, msg_str: str) -> None:
"""
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
"""
msg_bytes = encode_delim(msg_str.encode()) msg_bytes = encode_delim(msg_str.encode())
await self.read_writer.write(msg_bytes) try:
await self.read_writer.write(msg_bytes)
except IOException:
raise MultiselectCommunicatorError(
"fail to write to multiselect communicator"
)
async def read(self) -> str: async def read(self) -> str:
data = await read_delim(self.read_writer) """
:raise MultiselectCommunicatorError: raised when failed to read from underlying reader
"""
try:
data = await read_delim(self.read_writer)
# `IOException` includes `IncompleteReadError` and `StreamError`
except (ParseError, IOException, ValueError):
raise MultiselectCommunicatorError(
"fail to read from multiselect communicator"
)
return data.decode() return data.decode()

View File

@ -16,8 +16,11 @@ from typing import (
import base58 import base58
from lru import LRU from lru import LRU
from libp2p.exceptions import ValidationError from libp2p.exceptions import ParseError, ValidationError
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.io.exceptions import IncompleteReadError
from libp2p.network.exceptions import SwarmException
from libp2p.network.stream.exceptions import StreamEOF, StreamReset
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -154,7 +157,13 @@ class Pubsub:
peer_id = stream.mplex_conn.peer_id peer_id = stream.mplex_conn.peer_id
while True: while True:
incoming: bytes = await read_varint_prefixed_bytes(stream) try:
incoming: bytes = await read_varint_prefixed_bytes(stream)
except (ParseError, IncompleteReadError) as error:
logger.debug(
"read corrupted data from peer %s, error=%s", peer_id, error
)
continue
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming) rpc_incoming.ParseFromString(incoming)
if rpc_incoming.publish: if rpc_incoming.publish:
@ -228,10 +237,20 @@ class Pubsub:
on one of the supported pubsub protocols. on one of the supported pubsub protocols.
:param stream: newly created stream :param stream: newly created stream
""" """
await self.continuously_read_stream(stream) try:
await self.continuously_read_stream(stream)
except (StreamEOF, StreamReset) as error:
logger.debug("fail to read from stream, error=%s", error)
stream.reset()
# TODO: what to do when the stream is terminated?
# disconnect the peer?
async def _handle_new_peer(self, peer_id: ID) -> None: async def _handle_new_peer(self, peer_id: ID) -> None:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols) try:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
except SwarmException as error:
logger.debug("fail to add new peer %s, error %s", peer_id, error)
return
self.peers[peer_id] = stream self.peers[peer_id] = stream

View File

@ -0,0 +1,5 @@
from libp2p.exceptions import BaseLibp2pError
class HandshakeFailure(BaseLibp2pError):
pass

View File

@ -4,12 +4,13 @@ from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.pb import crypto_pb2 from libp2p.crypto.pb import crypto_pb2
from libp2p.crypto.utils import pubkey_from_protobuf from libp2p.crypto.utils import pubkey_from_protobuf
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.exceptions import RawConnError
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.base_session import BaseSession from libp2p.security.base_session import BaseSession
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.exceptions import HandshakeFailure
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.transport.exceptions import HandshakeFailure
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
@ -44,12 +45,21 @@ class InsecureSession(BaseSession):
await self.conn.close() await self.conn.close()
async def run_handshake(self) -> None: async def run_handshake(self) -> None:
"""
Raise `HandshakeFailure` when handshake failed
"""
msg = make_exchange_message(self.local_private_key.get_public_key()) msg = make_exchange_message(self.local_private_key.get_public_key())
msg_bytes = msg.SerializeToString() msg_bytes = msg.SerializeToString()
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
await self.write(encoded_msg_bytes) try:
await self.write(encoded_msg_bytes)
except RawConnError:
raise HandshakeFailure("connection closed")
remote_msg_bytes = await read_fixedint_prefixed(self.conn) try:
remote_msg_bytes = await read_fixedint_prefixed(self.conn)
except RawConnError:
raise HandshakeFailure("connection closed")
remote_msg = plaintext_pb2.Exchange() remote_msg = plaintext_pb2.Exchange()
remote_msg.ParseFromString(remote_msg_bytes) remote_msg.ParseFromString(remote_msg_bytes)
received_peer_id = ID(remote_msg.id) received_peer_id = ID(remote_msg.id)

View File

@ -1,4 +1,7 @@
class SecioException(Exception): from libp2p.security.exceptions import HandshakeFailure
class SecioException(HandshakeFailure):
pass pass
@ -19,9 +22,9 @@ class InvalidSignatureOnExchange(SecioException):
pass pass
class HandshakeFailed(SecioException):
pass
class IncompatibleChoices(SecioException): class IncompatibleChoices(SecioException):
pass pass
class InconsistentNonce(SecioException):
pass

View File

@ -16,6 +16,7 @@ from libp2p.crypto.ecc import ECCPublicKey
from libp2p.crypto.key_exchange import create_ephemeral_key_pair from libp2p.crypto.key_exchange import create_ephemeral_key_pair
from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.serialization import deserialize_public_key from libp2p.crypto.serialization import deserialize_public_key
from libp2p.io.exceptions import IOException
from libp2p.io.msgio import MsgIOReadWriter from libp2p.io.msgio import MsgIOReadWriter
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID as PeerID from libp2p.peer.id import ID as PeerID
@ -24,8 +25,8 @@ from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from .exceptions import ( from .exceptions import (
HandshakeFailed,
IncompatibleChoices, IncompatibleChoices,
InconsistentNonce,
InvalidSignatureOnExchange, InvalidSignatureOnExchange,
PeerMismatchException, PeerMismatchException,
SecioException, SecioException,
@ -399,6 +400,8 @@ async def create_secure_session(
Attempt the initial `secio` handshake with the remote peer. Attempt the initial `secio` handshake with the remote peer.
If successful, return an object that provides secure communication to the If successful, return an object that provides secure communication to the
``remote_peer``. ``remote_peer``.
Raise `SecioException` when `conn` closed.
Raise `InconsistentNonce` when handshake failed
""" """
msg_io = MsgIOReadWriter(conn) msg_io = MsgIOReadWriter(conn)
try: try:
@ -408,14 +411,21 @@ async def create_secure_session(
except SecioException as e: except SecioException as e:
await conn.close() await conn.close()
raise e raise e
# `IOException` includes errors raised while read from/write to raw connection
except IOException:
raise SecioException("connection closed")
initiator = remote_peer is not None initiator = remote_peer is not None
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator) session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
received_nonce = await _finish_handshake(session, remote_nonce) try:
received_nonce = await _finish_handshake(session, remote_nonce)
# `IOException` includes errors raised while read from/write to raw connection
except IOException:
raise SecioException("connection closed")
if received_nonce != local_nonce: if received_nonce != local_nonce:
await conn.close() await conn.close()
raise HandshakeFailed() raise InconsistentNonce()
return session return session

View File

@ -2,6 +2,7 @@ import asyncio
from typing import Any # noqa: F401 from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple from typing import Awaitable, Dict, List, Optional, Tuple
from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError from libp2p.io.exceptions import IncompleteReadError
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
@ -197,14 +198,13 @@ class Mplex(IMuxedConn):
""" """
# FIXME: No timeout is used in Go implementation. # FIXME: No timeout is used in Go implementation.
# Timeout is set to a relatively small value to alleviate wait time to exit
# loop in handle_incoming
try: try:
header = await decode_uvarint_from_stream(self.secured_conn) header = await decode_uvarint_from_stream(self.secured_conn)
message = await asyncio.wait_for( message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5 read_varint_prefixed_bytes(self.secured_conn), timeout=5
) )
except (ConnectionResetError, IncompleteReadError) as error: # TODO: Catch RawConnError?
except (ParseError, IncompleteReadError) as error:
raise MplexUnavailable( raise MplexUnavailable(
"failed to read messages correctly from the underlying connection" "failed to read messages correctly from the underlying connection"
) from error ) from error

View File

@ -1,7 +1,10 @@
from libp2p.exceptions import BaseLibp2pError from libp2p.exceptions import BaseLibp2pError
# TODO: Add `BaseLibp2pError` and `UpgradeFailure` can inherit from it? class OpenConnectionError(BaseLibp2pError):
pass
class UpgradeFailure(BaseLibp2pError): class UpgradeFailure(BaseLibp2pError):
pass pass
@ -12,7 +15,3 @@ class SecurityUpgradeFailure(UpgradeFailure):
class MuxerUpgradeFailure(UpgradeFailure): class MuxerUpgradeFailure(UpgradeFailure):
pass pass
class HandshakeFailure(BaseLibp2pError):
pass

View File

@ -6,6 +6,7 @@ from multiaddr import Multiaddr
from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler from libp2p.transport.typing import THandler
@ -62,11 +63,15 @@ class TCP(ITransport):
dial a transport to peer listening on multiaddr dial a transport to peer listening on multiaddr
:param maddr: multiaddr of peer :param maddr: multiaddr of peer
:return: `RawConnection` if successful :return: `RawConnection` if successful
:raise OpenConnectionError: raised when failed to open connection
""" """
self.host = maddr.value_for_protocol("ip4") self.host = maddr.value_for_protocol("ip4")
self.port = int(maddr.value_for_protocol("tcp")) self.port = int(maddr.value_for_protocol("tcp"))
reader, writer = await asyncio.open_connection(self.host, self.port) try:
reader, writer = await asyncio.open_connection(self.host, self.port)
except (ConnectionAbortedError, ConnectionRefusedError) as error:
raise OpenConnectionError(error)
return RawConnection(reader, writer, True) return RawConnection(reader, writer, True)

View File

@ -3,16 +3,13 @@ from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.security.exceptions import HandshakeFailure
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.security.security_multistream import SecurityMultistream from libp2p.security.security_multistream import SecurityMultistream
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream
from libp2p.transport.exceptions import ( from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
HandshakeFailure,
MuxerUpgradeFailure,
SecurityUpgradeFailure,
)
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .listener_interface import IListener from .listener_interface import IListener

View File

@ -59,10 +59,6 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
async def read_varint_prefixed_bytes(reader: Reader) -> bytes: async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
len_msg = await decode_uvarint_from_stream(reader) len_msg = await decode_uvarint_from_stream(reader)
data = await read_exactly(reader, len_msg) data = await read_exactly(reader, len_msg)
if len(data) != len_msg:
raise ValueError(
f"failed to read enough bytes: len_msg={len_msg}, data={data!r}"
)
return data return data

View File

@ -2,8 +2,8 @@ import asyncio
import pytest import pytest
from libp2p.host.exceptions import StreamFailure
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.protocol_muxer.exceptions import MultiselectClientError
from tests.utils import set_up_nodes_by_transport_opt from tests.utils import set_up_nodes_by_transport_opt
PROTOCOL_ID = "/chat/1.0.0" PROTOCOL_ID = "/chat/1.0.0"
@ -84,7 +84,7 @@ async def no_common_protocol(host_a, host_b):
host_a.set_stream_handler(PROTOCOL_ID, stream_handler) host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
# try to creates a new new with a procotol not known by the other host # try to creates a new new with a procotol not known by the other host
with pytest.raises(MultiselectClientError): with pytest.raises(StreamFailure):
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"]) await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])

View File

@ -1,6 +1,6 @@
import pytest import pytest
from libp2p.protocol_muxer.exceptions import MultiselectClientError from libp2p.host.exceptions import StreamFailure
from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt
# TODO: Add tests for multiple streams being opened on different # TODO: Add tests for multiple streams being opened on different
@ -47,7 +47,7 @@ async def test_single_protocol_succeeds():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_single_protocol_fails(): async def test_single_protocol_fails():
with pytest.raises(MultiselectClientError): with pytest.raises(StreamFailure):
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"])
# Cleanup not reached on error # Cleanup not reached on error
@ -77,7 +77,7 @@ async def test_multiple_protocol_second_is_valid_succeeds():
async def test_multiple_protocol_fails(): async def test_multiple_protocol_fails():
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"] protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"]
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
with pytest.raises(MultiselectClientError): with pytest.raises(StreamFailure):
await perform_simple_test("", protocols_for_client, protocols_for_listener) await perform_simple_test("", protocols_for_client, protocols_for_listener)
# Cleanup not reached on error # Cleanup not reached on error