Merge pull request #299 from NIC619/add_more_error_handling
Add more error handling
This commit is contained in:
commit
85457fa308
5
libp2p/network/connection/exceptions.py
Normal file
5
libp2p/network/connection/exceptions.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
from libp2p.io.exceptions import IOException
|
||||||
|
|
||||||
|
|
||||||
|
class RawConnError(IOException):
|
||||||
|
pass
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from .exceptions import RawConnError
|
||||||
from .raw_connection_interface import IRawConnection
|
from .raw_connection_interface import IRawConnection
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,19 +24,33 @@ class RawConnection(IRawConnection):
|
||||||
self._drain_lock = asyncio.Lock()
|
self._drain_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
|
"""
|
||||||
|
Raise `RawConnError` if the underlying connection breaks
|
||||||
|
"""
|
||||||
|
try:
|
||||||
self.writer.write(data)
|
self.writer.write(data)
|
||||||
|
except ConnectionResetError as error:
|
||||||
|
raise RawConnError(error)
|
||||||
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
|
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
|
||||||
# Use a lock to serialize drain() calls. Circumvents this bug:
|
# Use a lock to serialize drain() calls. Circumvents this bug:
|
||||||
# https://bugs.python.org/issue29930
|
# https://bugs.python.org/issue29930
|
||||||
async with self._drain_lock:
|
async with self._drain_lock:
|
||||||
|
try:
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
|
except ConnectionResetError:
|
||||||
|
raise RawConnError()
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
"""
|
"""
|
||||||
Read up to ``n`` bytes from the underlying stream.
|
Read up to ``n`` bytes from the underlying stream.
|
||||||
This call is delegated directly to the underlying ``self.reader``.
|
This call is delegated directly to the underlying ``self.reader``.
|
||||||
|
|
||||||
|
Raise `RawConnError` if the underlying connection breaks
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
return await self.reader.read(n)
|
return await self.reader.read(n)
|
||||||
|
except ConnectionResetError as error:
|
||||||
|
raise RawConnError(error)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from libp2p.exceptions import BaseLibp2pError
|
from libp2p.io.exceptions import IOException
|
||||||
|
|
||||||
|
|
||||||
class StreamError(BaseLibp2pError):
|
class StreamError(IOException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,17 @@ from multiaddr import Multiaddr
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.peer.peerstore import PeerStoreError
|
from libp2p.peer.peerstore import PeerStoreError
|
||||||
from libp2p.peer.peerstore_interface import IPeerStore
|
from libp2p.peer.peerstore_interface import IPeerStore
|
||||||
|
from libp2p.protocol_muxer.exceptions import MultiselectClientError
|
||||||
from libp2p.protocol_muxer.multiselect import Multiselect
|
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||||
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
|
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
|
||||||
from libp2p.routing.interfaces import IPeerRouting
|
from libp2p.routing.interfaces import IPeerRouting
|
||||||
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
||||||
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
|
from libp2p.transport.exceptions import (
|
||||||
|
MuxerUpgradeFailure,
|
||||||
|
OpenConnectionError,
|
||||||
|
SecurityUpgradeFailure,
|
||||||
|
)
|
||||||
from libp2p.transport.listener_interface import IListener
|
from libp2p.transport.listener_interface import IListener
|
||||||
from libp2p.transport.transport_interface import ITransport
|
from libp2p.transport.transport_interface import ITransport
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
@ -117,7 +122,13 @@ class Swarm(INetwork):
|
||||||
multiaddr = self.router.find_peer(peer_id)
|
multiaddr = self.router.find_peer(peer_id)
|
||||||
# Dial peer (connection to peer does not yet exist)
|
# Dial peer (connection to peer does not yet exist)
|
||||||
# Transport dials peer (gets back a raw conn)
|
# Transport dials peer (gets back a raw conn)
|
||||||
|
try:
|
||||||
raw_conn = await self.transport.dial(multiaddr)
|
raw_conn = await self.transport.dial(multiaddr)
|
||||||
|
except OpenConnectionError as error:
|
||||||
|
logger.debug("fail to dial peer %s over base transport", peer_id)
|
||||||
|
raise SwarmException(
|
||||||
|
"fail to open connection to peer %s", peer_id
|
||||||
|
) from error
|
||||||
|
|
||||||
logger.debug("dialed peer %s over base transport", peer_id)
|
logger.debug("dialed peer %s over base transport", peer_id)
|
||||||
|
|
||||||
|
@ -162,6 +173,7 @@ class Swarm(INetwork):
|
||||||
"""
|
"""
|
||||||
:param peer_id: peer_id of destination
|
:param peer_id: peer_id of destination
|
||||||
:param protocol_id: protocol id
|
:param protocol_id: protocol id
|
||||||
|
:raises SwarmException: raised when an error occurs
|
||||||
:return: net stream instance
|
:return: net stream instance
|
||||||
"""
|
"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -176,9 +188,16 @@ class Swarm(INetwork):
|
||||||
muxed_stream = await muxed_conn.open_stream()
|
muxed_stream = await muxed_conn.open_stream()
|
||||||
|
|
||||||
# Perform protocol muxing to determine protocol to use
|
# Perform protocol muxing to determine protocol to use
|
||||||
|
try:
|
||||||
selected_protocol = await self.multiselect_client.select_one_of(
|
selected_protocol = await self.multiselect_client.select_one_of(
|
||||||
list(protocol_ids), MultiselectCommunicator(muxed_stream)
|
list(protocol_ids), MultiselectCommunicator(muxed_stream)
|
||||||
)
|
)
|
||||||
|
except MultiselectClientError as error:
|
||||||
|
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||||
|
await muxed_stream.reset()
|
||||||
|
raise SwarmException(
|
||||||
|
"failt to open a stream to peer %s", peer_id
|
||||||
|
) from error
|
||||||
|
|
||||||
# Create a net stream with the selected protocol
|
# Create a net stream with the selected protocol
|
||||||
net_stream = NetStream(muxed_stream)
|
net_stream = NetStream(muxed_stream)
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
from libp2p.exceptions import BaseLibp2pError
|
from libp2p.exceptions import BaseLibp2pError
|
||||||
|
|
||||||
|
|
||||||
|
class MultiselectCommunicatorError(BaseLibp2pError):
|
||||||
|
"""Raised when an error occurs during read/write via communicator"""
|
||||||
|
|
||||||
|
|
||||||
class MultiselectError(BaseLibp2pError):
|
class MultiselectError(BaseLibp2pError):
|
||||||
"""Raised when an error occurs in multiselect process"""
|
"""Raised when an error occurs in multiselect process"""
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from typing import Dict, Tuple
|
||||||
|
|
||||||
from libp2p.typing import StreamHandlerFn, TProtocol
|
from libp2p.typing import StreamHandlerFn, TProtocol
|
||||||
|
|
||||||
from .exceptions import MultiselectError
|
from .exceptions import MultiselectCommunicatorError, MultiselectError
|
||||||
from .multiselect_communicator_interface import IMultiselectCommunicator
|
from .multiselect_communicator_interface import IMultiselectCommunicator
|
||||||
from .multiselect_muxer_interface import IMultiselectMuxer
|
from .multiselect_muxer_interface import IMultiselectMuxer
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ class Multiselect(IMultiselectMuxer):
|
||||||
Negotiate performs protocol selection
|
Negotiate performs protocol selection
|
||||||
:param stream: stream to negotiate on
|
:param stream: stream to negotiate on
|
||||||
:return: selected protocol name, handler function
|
:return: selected protocol name, handler function
|
||||||
:raise Exception: negotiation failed exception
|
:raise MultiselectError: raised when negotiation failed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Perform handshake to ensure multiselect protocol IDs match
|
# Perform handshake to ensure multiselect protocol IDs match
|
||||||
|
@ -46,7 +46,10 @@ class Multiselect(IMultiselectMuxer):
|
||||||
# Read and respond to commands until a valid protocol ID is sent
|
# Read and respond to commands until a valid protocol ID is sent
|
||||||
while True:
|
while True:
|
||||||
# Read message
|
# Read message
|
||||||
|
try:
|
||||||
command = await communicator.read()
|
command = await communicator.read()
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectError(error)
|
||||||
|
|
||||||
# Command is ls or a protocol
|
# Command is ls or a protocol
|
||||||
if command == "ls":
|
if command == "ls":
|
||||||
|
@ -56,27 +59,39 @@ class Multiselect(IMultiselectMuxer):
|
||||||
protocol = TProtocol(command)
|
protocol = TProtocol(command)
|
||||||
if protocol in self.handlers:
|
if protocol in self.handlers:
|
||||||
# Tell counterparty we have decided on a protocol
|
# Tell counterparty we have decided on a protocol
|
||||||
|
try:
|
||||||
await communicator.write(protocol)
|
await communicator.write(protocol)
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectError(error)
|
||||||
|
|
||||||
# Return the decided on protocol
|
# Return the decided on protocol
|
||||||
return protocol, self.handlers[protocol]
|
return protocol, self.handlers[protocol]
|
||||||
# Tell counterparty this protocol was not found
|
# Tell counterparty this protocol was not found
|
||||||
|
try:
|
||||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectError(error)
|
||||||
|
|
||||||
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
|
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
|
||||||
"""
|
"""
|
||||||
Perform handshake to agree on multiselect protocol
|
Perform handshake to agree on multiselect protocol
|
||||||
:param communicator: communicator to use
|
:param communicator: communicator to use
|
||||||
:raise Exception: error in handshake
|
:raise MultiselectError: raised when handshake failed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Use format used by go repo for messages
|
# TODO: Use format used by go repo for messages
|
||||||
|
|
||||||
# Send our MULTISELECT_PROTOCOL_ID to other party
|
# Send our MULTISELECT_PROTOCOL_ID to other party
|
||||||
|
try:
|
||||||
await communicator.write(MULTISELECT_PROTOCOL_ID)
|
await communicator.write(MULTISELECT_PROTOCOL_ID)
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectError(error)
|
||||||
|
|
||||||
# Read in the protocol ID from other party
|
# Read in the protocol ID from other party
|
||||||
|
try:
|
||||||
handshake_contents = await communicator.read()
|
handshake_contents = await communicator.read()
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectError(error)
|
||||||
|
|
||||||
# Confirm that the protocols are the same
|
# Confirm that the protocols are the same
|
||||||
if not validate_handshake(handshake_contents):
|
if not validate_handshake(handshake_contents):
|
||||||
|
|
|
@ -2,7 +2,7 @@ from typing import Sequence
|
||||||
|
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
|
|
||||||
from .exceptions import MultiselectClientError
|
from .exceptions import MultiselectClientError, MultiselectCommunicatorError
|
||||||
from .multiselect_client_interface import IMultiselectClient
|
from .multiselect_client_interface import IMultiselectClient
|
||||||
from .multiselect_communicator_interface import IMultiselectCommunicator
|
from .multiselect_communicator_interface import IMultiselectCommunicator
|
||||||
|
|
||||||
|
@ -21,16 +21,22 @@ class MultiselectClient(IMultiselectClient):
|
||||||
Ensure that the client and multiselect
|
Ensure that the client and multiselect
|
||||||
are both using the same multiselect protocol
|
are both using the same multiselect protocol
|
||||||
:param stream: stream to communicate with multiselect over
|
:param stream: stream to communicate with multiselect over
|
||||||
:raise Exception: multiselect protocol ID mismatch
|
:raise MultiselectClientError: raised when handshake failed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Use format used by go repo for messages
|
# TODO: Use format used by go repo for messages
|
||||||
|
|
||||||
# Send our MULTISELECT_PROTOCOL_ID to counterparty
|
# Send our MULTISELECT_PROTOCOL_ID to counterparty
|
||||||
|
try:
|
||||||
await communicator.write(MULTISELECT_PROTOCOL_ID)
|
await communicator.write(MULTISELECT_PROTOCOL_ID)
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectClientError(error)
|
||||||
|
|
||||||
# Read in the protocol ID from other party
|
# Read in the protocol ID from other party
|
||||||
|
try:
|
||||||
handshake_contents = await communicator.read()
|
handshake_contents = await communicator.read()
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectClientError(str(error))
|
||||||
|
|
||||||
# Confirm that the protocols are the same
|
# Confirm that the protocols are the same
|
||||||
if not validate_handshake(handshake_contents):
|
if not validate_handshake(handshake_contents):
|
||||||
|
@ -48,6 +54,7 @@ class MultiselectClient(IMultiselectClient):
|
||||||
:param protocol: protocol to select
|
:param protocol: protocol to select
|
||||||
:param stream: stream to communicate with multiselect over
|
:param stream: stream to communicate with multiselect over
|
||||||
:return: selected protocol
|
:return: selected protocol
|
||||||
|
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||||
"""
|
"""
|
||||||
# Perform handshake to ensure multiselect protocol IDs match
|
# Perform handshake to ensure multiselect protocol IDs match
|
||||||
await self.handshake(communicator)
|
await self.handshake(communicator)
|
||||||
|
@ -71,15 +78,21 @@ class MultiselectClient(IMultiselectClient):
|
||||||
Try to select the given protocol or raise exception if fails
|
Try to select the given protocol or raise exception if fails
|
||||||
:param communicator: communicator to use to communicate with counterparty
|
:param communicator: communicator to use to communicate with counterparty
|
||||||
:param protocol: protocol to select
|
:param protocol: protocol to select
|
||||||
:raise Exception: error in protocol selection
|
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||||
:return: selected protocol
|
:return: selected protocol
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Tell counterparty we want to use protocol
|
# Tell counterparty we want to use protocol
|
||||||
|
try:
|
||||||
await communicator.write(protocol)
|
await communicator.write(protocol)
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectClientError(error)
|
||||||
|
|
||||||
# Get what counterparty says in response
|
# Get what counterparty says in response
|
||||||
|
try:
|
||||||
response = await communicator.read()
|
response = await communicator.read()
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectClientError(str(error))
|
||||||
|
|
||||||
# Return protocol if response is equal to protocol or raise error
|
# Return protocol if response is equal to protocol or raise error
|
||||||
if response == protocol:
|
if response == protocol:
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
|
from libp2p.exceptions import ParseError
|
||||||
from libp2p.io.abc import ReadWriteCloser
|
from libp2p.io.abc import ReadWriteCloser
|
||||||
|
from libp2p.io.exceptions import IOException
|
||||||
from libp2p.utils import encode_delim, read_delim
|
from libp2p.utils import encode_delim, read_delim
|
||||||
|
|
||||||
|
from .exceptions import MultiselectCommunicatorError
|
||||||
from .multiselect_communicator_interface import IMultiselectCommunicator
|
from .multiselect_communicator_interface import IMultiselectCommunicator
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,9 +14,26 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
||||||
self.read_writer = read_writer
|
self.read_writer = read_writer
|
||||||
|
|
||||||
async def write(self, msg_str: str) -> None:
|
async def write(self, msg_str: str) -> None:
|
||||||
|
"""
|
||||||
|
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
|
||||||
|
"""
|
||||||
msg_bytes = encode_delim(msg_str.encode())
|
msg_bytes = encode_delim(msg_str.encode())
|
||||||
|
try:
|
||||||
await self.read_writer.write(msg_bytes)
|
await self.read_writer.write(msg_bytes)
|
||||||
|
except IOException:
|
||||||
|
raise MultiselectCommunicatorError(
|
||||||
|
"fail to write to multiselect communicator"
|
||||||
|
)
|
||||||
|
|
||||||
async def read(self) -> str:
|
async def read(self) -> str:
|
||||||
|
"""
|
||||||
|
:raise MultiselectCommunicatorError: raised when failed to read from underlying reader
|
||||||
|
"""
|
||||||
|
try:
|
||||||
data = await read_delim(self.read_writer)
|
data = await read_delim(self.read_writer)
|
||||||
|
# `IOException` includes `IncompleteReadError` and `StreamError`
|
||||||
|
except (ParseError, IOException, ValueError):
|
||||||
|
raise MultiselectCommunicatorError(
|
||||||
|
"fail to read from multiselect communicator"
|
||||||
|
)
|
||||||
return data.decode()
|
return data.decode()
|
||||||
|
|
|
@ -16,8 +16,11 @@ from typing import (
|
||||||
import base58
|
import base58
|
||||||
from lru import LRU
|
from lru import LRU
|
||||||
|
|
||||||
from libp2p.exceptions import ValidationError
|
from libp2p.exceptions import ParseError, ValidationError
|
||||||
from libp2p.host.host_interface import IHost
|
from libp2p.host.host_interface import IHost
|
||||||
|
from libp2p.io.exceptions import IncompleteReadError
|
||||||
|
from libp2p.network.exceptions import SwarmException
|
||||||
|
from libp2p.network.stream.exceptions import StreamEOF, StreamReset
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
|
@ -154,7 +157,13 @@ class Pubsub:
|
||||||
peer_id = stream.mplex_conn.peer_id
|
peer_id = stream.mplex_conn.peer_id
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
try:
|
||||||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||||
|
except (ParseError, IncompleteReadError) as error:
|
||||||
|
logger.debug(
|
||||||
|
"read corrupted data from peer %s, error=%s", peer_id, error
|
||||||
|
)
|
||||||
|
continue
|
||||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||||
rpc_incoming.ParseFromString(incoming)
|
rpc_incoming.ParseFromString(incoming)
|
||||||
if rpc_incoming.publish:
|
if rpc_incoming.publish:
|
||||||
|
@ -228,10 +237,20 @@ class Pubsub:
|
||||||
on one of the supported pubsub protocols.
|
on one of the supported pubsub protocols.
|
||||||
:param stream: newly created stream
|
:param stream: newly created stream
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
await self.continuously_read_stream(stream)
|
await self.continuously_read_stream(stream)
|
||||||
|
except (StreamEOF, StreamReset) as error:
|
||||||
|
logger.debug("fail to read from stream, error=%s", error)
|
||||||
|
stream.reset()
|
||||||
|
# TODO: what to do when the stream is terminated?
|
||||||
|
# disconnect the peer?
|
||||||
|
|
||||||
async def _handle_new_peer(self, peer_id: ID) -> None:
|
async def _handle_new_peer(self, peer_id: ID) -> None:
|
||||||
|
try:
|
||||||
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
|
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
|
||||||
|
except SwarmException as error:
|
||||||
|
logger.debug("fail to add new peer %s, error %s", peer_id, error)
|
||||||
|
return
|
||||||
|
|
||||||
self.peers[peer_id] = stream
|
self.peers[peer_id] = stream
|
||||||
|
|
||||||
|
|
5
libp2p/security/exceptions.py
Normal file
5
libp2p/security/exceptions.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
from libp2p.exceptions import BaseLibp2pError
|
||||||
|
|
||||||
|
|
||||||
|
class HandshakeFailure(BaseLibp2pError):
|
||||||
|
pass
|
|
@ -4,12 +4,13 @@ from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||||
from libp2p.crypto.pb import crypto_pb2
|
from libp2p.crypto.pb import crypto_pb2
|
||||||
from libp2p.crypto.utils import pubkey_from_protobuf
|
from libp2p.crypto.utils import pubkey_from_protobuf
|
||||||
from libp2p.io.abc import ReadWriteCloser
|
from libp2p.io.abc import ReadWriteCloser
|
||||||
|
from libp2p.network.connection.exceptions import RawConnError
|
||||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.security.base_session import BaseSession
|
from libp2p.security.base_session import BaseSession
|
||||||
from libp2p.security.base_transport import BaseSecureTransport
|
from libp2p.security.base_transport import BaseSecureTransport
|
||||||
|
from libp2p.security.exceptions import HandshakeFailure
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
from libp2p.transport.exceptions import HandshakeFailure
|
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
|
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
|
||||||
|
|
||||||
|
@ -44,12 +45,21 @@ class InsecureSession(BaseSession):
|
||||||
await self.conn.close()
|
await self.conn.close()
|
||||||
|
|
||||||
async def run_handshake(self) -> None:
|
async def run_handshake(self) -> None:
|
||||||
|
"""
|
||||||
|
Raise `HandshakeFailure` when handshake failed
|
||||||
|
"""
|
||||||
msg = make_exchange_message(self.local_private_key.get_public_key())
|
msg = make_exchange_message(self.local_private_key.get_public_key())
|
||||||
msg_bytes = msg.SerializeToString()
|
msg_bytes = msg.SerializeToString()
|
||||||
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
|
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
|
||||||
|
try:
|
||||||
await self.write(encoded_msg_bytes)
|
await self.write(encoded_msg_bytes)
|
||||||
|
except RawConnError:
|
||||||
|
raise HandshakeFailure("connection closed")
|
||||||
|
|
||||||
|
try:
|
||||||
remote_msg_bytes = await read_fixedint_prefixed(self.conn)
|
remote_msg_bytes = await read_fixedint_prefixed(self.conn)
|
||||||
|
except RawConnError:
|
||||||
|
raise HandshakeFailure("connection closed")
|
||||||
remote_msg = plaintext_pb2.Exchange()
|
remote_msg = plaintext_pb2.Exchange()
|
||||||
remote_msg.ParseFromString(remote_msg_bytes)
|
remote_msg.ParseFromString(remote_msg_bytes)
|
||||||
received_peer_id = ID(remote_msg.id)
|
received_peer_id = ID(remote_msg.id)
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
class SecioException(Exception):
|
from libp2p.security.exceptions import HandshakeFailure
|
||||||
|
|
||||||
|
|
||||||
|
class SecioException(HandshakeFailure):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,9 +22,9 @@ class InvalidSignatureOnExchange(SecioException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HandshakeFailed(SecioException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class IncompatibleChoices(SecioException):
|
class IncompatibleChoices(SecioException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InconsistentNonce(SecioException):
|
||||||
|
pass
|
||||||
|
|
|
@ -16,6 +16,7 @@ from libp2p.crypto.ecc import ECCPublicKey
|
||||||
from libp2p.crypto.key_exchange import create_ephemeral_key_pair
|
from libp2p.crypto.key_exchange import create_ephemeral_key_pair
|
||||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||||
from libp2p.crypto.serialization import deserialize_public_key
|
from libp2p.crypto.serialization import deserialize_public_key
|
||||||
|
from libp2p.io.exceptions import IOException
|
||||||
from libp2p.io.msgio import MsgIOReadWriter
|
from libp2p.io.msgio import MsgIOReadWriter
|
||||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
from libp2p.peer.id import ID as PeerID
|
from libp2p.peer.id import ID as PeerID
|
||||||
|
@ -24,8 +25,8 @@ from libp2p.security.base_transport import BaseSecureTransport
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
HandshakeFailed,
|
|
||||||
IncompatibleChoices,
|
IncompatibleChoices,
|
||||||
|
InconsistentNonce,
|
||||||
InvalidSignatureOnExchange,
|
InvalidSignatureOnExchange,
|
||||||
PeerMismatchException,
|
PeerMismatchException,
|
||||||
SecioException,
|
SecioException,
|
||||||
|
@ -399,6 +400,8 @@ async def create_secure_session(
|
||||||
Attempt the initial `secio` handshake with the remote peer.
|
Attempt the initial `secio` handshake with the remote peer.
|
||||||
If successful, return an object that provides secure communication to the
|
If successful, return an object that provides secure communication to the
|
||||||
``remote_peer``.
|
``remote_peer``.
|
||||||
|
Raise `SecioException` when `conn` closed.
|
||||||
|
Raise `InconsistentNonce` when handshake failed
|
||||||
"""
|
"""
|
||||||
msg_io = MsgIOReadWriter(conn)
|
msg_io = MsgIOReadWriter(conn)
|
||||||
try:
|
try:
|
||||||
|
@ -408,14 +411,21 @@ async def create_secure_session(
|
||||||
except SecioException as e:
|
except SecioException as e:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
raise e
|
raise e
|
||||||
|
# `IOException` includes errors raised while read from/write to raw connection
|
||||||
|
except IOException:
|
||||||
|
raise SecioException("connection closed")
|
||||||
|
|
||||||
initiator = remote_peer is not None
|
initiator = remote_peer is not None
|
||||||
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
|
session = _mk_session_from(local_private_key, session_parameters, msg_io, initiator)
|
||||||
|
|
||||||
|
try:
|
||||||
received_nonce = await _finish_handshake(session, remote_nonce)
|
received_nonce = await _finish_handshake(session, remote_nonce)
|
||||||
|
# `IOException` includes errors raised while read from/write to raw connection
|
||||||
|
except IOException:
|
||||||
|
raise SecioException("connection closed")
|
||||||
if received_nonce != local_nonce:
|
if received_nonce != local_nonce:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
raise HandshakeFailed()
|
raise InconsistentNonce()
|
||||||
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,11 @@ import asyncio
|
||||||
from typing import Any # noqa: F401
|
from typing import Any # noqa: F401
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from libp2p.exceptions import ParseError
|
||||||
|
from libp2p.io.exceptions import IncompleteReadError
|
||||||
from libp2p.network.typing import GenericProtocolHandlerFn
|
from libp2p.network.typing import GenericProtocolHandlerFn
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
|
from libp2p.protocol_muxer.exceptions import MultiselectError
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
|
@ -125,7 +128,13 @@ class Mplex(IMuxedConn):
|
||||||
"""
|
"""
|
||||||
stream = await self._initialize_stream(stream_id, name)
|
stream = await self._initialize_stream(stream_id, name)
|
||||||
# Perform protocol negotiation for the stream.
|
# Perform protocol negotiation for the stream.
|
||||||
self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream)))
|
try:
|
||||||
|
await self.generic_protocol_handler(stream)
|
||||||
|
except MultiselectError:
|
||||||
|
# Un-register and reset the stream
|
||||||
|
del self.streams[stream_id]
|
||||||
|
await stream.reset()
|
||||||
|
return
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
||||||
|
@ -178,7 +187,11 @@ class Mplex(IMuxedConn):
|
||||||
# `NewStream` for the same id is received twice...
|
# `NewStream` for the same id is received twice...
|
||||||
# TODO: Shutdown
|
# TODO: Shutdown
|
||||||
pass
|
pass
|
||||||
await self.accept_stream(stream_id, message.decode())
|
self._tasks.append(
|
||||||
|
asyncio.ensure_future(
|
||||||
|
self.accept_stream(stream_id, message.decode())
|
||||||
|
)
|
||||||
|
)
|
||||||
elif flag in (
|
elif flag in (
|
||||||
HeaderTags.MessageInitiator.value,
|
HeaderTags.MessageInitiator.value,
|
||||||
HeaderTags.MessageReceiver.value,
|
HeaderTags.MessageReceiver.value,
|
||||||
|
@ -248,13 +261,15 @@ class Mplex(IMuxedConn):
|
||||||
# FIXME: No timeout is used in Go implementation.
|
# FIXME: No timeout is used in Go implementation.
|
||||||
# Timeout is set to a relatively small value to alleviate wait time to exit
|
# Timeout is set to a relatively small value to alleviate wait time to exit
|
||||||
# loop in handle_incoming
|
# loop in handle_incoming
|
||||||
|
try:
|
||||||
header = await decode_uvarint_from_stream(self.secured_conn)
|
header = await decode_uvarint_from_stream(self.secured_conn)
|
||||||
# TODO: Handle the case of EOF and other exceptions?
|
except ParseError:
|
||||||
|
return None, None, None
|
||||||
try:
|
try:
|
||||||
message = await asyncio.wait_for(
|
message = await asyncio.wait_for(
|
||||||
read_varint_prefixed_bytes(self.secured_conn), timeout=5
|
read_varint_prefixed_bytes(self.secured_conn), timeout=5
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except (ParseError, IncompleteReadError, asyncio.TimeoutError):
|
||||||
# TODO: Investigate what we should do if time is out.
|
# TODO: Investigate what we should do if time is out.
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
from libp2p.exceptions import BaseLibp2pError
|
from libp2p.exceptions import BaseLibp2pError
|
||||||
|
|
||||||
|
|
||||||
# TODO: Add `BaseLibp2pError` and `UpgradeFailure` can inherit from it?
|
class OpenConnectionError(BaseLibp2pError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UpgradeFailure(BaseLibp2pError):
|
class UpgradeFailure(BaseLibp2pError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -12,7 +15,3 @@ class SecurityUpgradeFailure(UpgradeFailure):
|
||||||
|
|
||||||
class MuxerUpgradeFailure(UpgradeFailure):
|
class MuxerUpgradeFailure(UpgradeFailure):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HandshakeFailure(BaseLibp2pError):
|
|
||||||
pass
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ from multiaddr import Multiaddr
|
||||||
|
|
||||||
from libp2p.network.connection.raw_connection import RawConnection
|
from libp2p.network.connection.raw_connection import RawConnection
|
||||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
|
from libp2p.transport.exceptions import OpenConnectionError
|
||||||
from libp2p.transport.listener_interface import IListener
|
from libp2p.transport.listener_interface import IListener
|
||||||
from libp2p.transport.transport_interface import ITransport
|
from libp2p.transport.transport_interface import ITransport
|
||||||
from libp2p.transport.typing import THandler
|
from libp2p.transport.typing import THandler
|
||||||
|
@ -62,11 +63,15 @@ class TCP(ITransport):
|
||||||
dial a transport to peer listening on multiaddr
|
dial a transport to peer listening on multiaddr
|
||||||
:param maddr: multiaddr of peer
|
:param maddr: multiaddr of peer
|
||||||
:return: `RawConnection` if successful
|
:return: `RawConnection` if successful
|
||||||
|
:raise OpenConnectionError: raised when failed to open connection
|
||||||
"""
|
"""
|
||||||
self.host = maddr.value_for_protocol("ip4")
|
self.host = maddr.value_for_protocol("ip4")
|
||||||
self.port = int(maddr.value_for_protocol("tcp"))
|
self.port = int(maddr.value_for_protocol("tcp"))
|
||||||
|
|
||||||
|
try:
|
||||||
reader, writer = await asyncio.open_connection(self.host, self.port)
|
reader, writer = await asyncio.open_connection(self.host, self.port)
|
||||||
|
except (ConnectionAbortedError, ConnectionRefusedError) as error:
|
||||||
|
raise OpenConnectionError(error)
|
||||||
|
|
||||||
return RawConnection(reader, writer, True)
|
return RawConnection(reader, writer, True)
|
||||||
|
|
||||||
|
|
|
@ -4,16 +4,13 @@ from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||||
from libp2p.network.typing import GenericProtocolHandlerFn
|
from libp2p.network.typing import GenericProtocolHandlerFn
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
|
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
|
||||||
|
from libp2p.security.exceptions import HandshakeFailure
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||||
from libp2p.security.security_multistream import SecurityMultistream
|
from libp2p.security.security_multistream import SecurityMultistream
|
||||||
from libp2p.stream_muxer.abc import IMuxedConn
|
from libp2p.stream_muxer.abc import IMuxedConn
|
||||||
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream
|
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream
|
||||||
from libp2p.transport.exceptions import (
|
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
|
||||||
HandshakeFailure,
|
|
||||||
MuxerUpgradeFailure,
|
|
||||||
SecurityUpgradeFailure,
|
|
||||||
)
|
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
|
|
||||||
from .listener_interface import IListener
|
from .listener_interface import IListener
|
||||||
|
|
|
@ -65,10 +65,6 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
|
||||||
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
|
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
|
||||||
len_msg = await decode_uvarint_from_stream(reader)
|
len_msg = await decode_uvarint_from_stream(reader)
|
||||||
data = await read_exactly(reader, len_msg)
|
data = await read_exactly(reader, len_msg)
|
||||||
if len(data) != len_msg:
|
|
||||||
raise ValueError(
|
|
||||||
f"failed to read enough bytes: len_msg={len_msg}, data={data!r}"
|
|
||||||
)
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,8 @@ import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.network.exceptions import SwarmException
|
||||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
from libp2p.protocol_muxer.exceptions import MultiselectClientError
|
|
||||||
from tests.utils import set_up_nodes_by_transport_opt
|
from tests.utils import set_up_nodes_by_transport_opt
|
||||||
|
|
||||||
PROTOCOL_ID = "/chat/1.0.0"
|
PROTOCOL_ID = "/chat/1.0.0"
|
||||||
|
@ -84,7 +84,7 @@ async def no_common_protocol(host_a, host_b):
|
||||||
host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
|
host_a.set_stream_handler(PROTOCOL_ID, stream_handler)
|
||||||
|
|
||||||
# try to creates a new new with a procotol not known by the other host
|
# try to creates a new new with a procotol not known by the other host
|
||||||
with pytest.raises(MultiselectClientError):
|
with pytest.raises(SwarmException):
|
||||||
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])
|
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from libp2p.protocol_muxer.exceptions import MultiselectClientError
|
from libp2p.network.exceptions import SwarmException
|
||||||
from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt
|
from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt
|
||||||
|
|
||||||
# TODO: Add tests for multiple streams being opened on different
|
# TODO: Add tests for multiple streams being opened on different
|
||||||
|
@ -47,7 +47,7 @@ async def test_single_protocol_succeeds():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_single_protocol_fails():
|
async def test_single_protocol_fails():
|
||||||
with pytest.raises(MultiselectClientError):
|
with pytest.raises(SwarmException):
|
||||||
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"])
|
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"])
|
||||||
|
|
||||||
# Cleanup not reached on error
|
# Cleanup not reached on error
|
||||||
|
@ -77,7 +77,7 @@ async def test_multiple_protocol_second_is_valid_succeeds():
|
||||||
async def test_multiple_protocol_fails():
|
async def test_multiple_protocol_fails():
|
||||||
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"]
|
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"]
|
||||||
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
|
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
|
||||||
with pytest.raises(MultiselectClientError):
|
with pytest.raises(SwarmException):
|
||||||
await perform_simple_test("", protocols_for_client, protocols_for_listener)
|
await perform_simple_test("", protocols_for_client, protocols_for_listener)
|
||||||
|
|
||||||
# Cleanup not reached on error
|
# Cleanup not reached on error
|
||||||
|
|
Loading…
Reference in New Issue
Block a user