Merge pull request #305 from mhchia/fix/change-notifee-and-add-tests-for-swarm-conn-and-mplex

Change `Notifee`, add tests for `SwarmConn` and `Mplex`
This commit is contained in:
Kevin Mai-Husan Chia 2019-09-24 14:02:58 +08:00 committed by GitHub
commit b53ca5708f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 635 additions and 499 deletions

View File

@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import Mapping, Sequence from typing import Sequence
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.crypto.rsa import create_new_key_pair from libp2p.crypto.rsa import create_new_key_pair
@ -15,10 +15,9 @@ from libp2p.routing.interfaces import IPeerRouting
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.muxer_multistream import MuxerClassType
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerOptions, TSecurityOptions
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -74,8 +73,8 @@ def initialize_default_swarm(
key_pair: KeyPair, key_pair: KeyPair,
id_opt: ID = None, id_opt: ID = None,
transport_opt: Sequence[str] = None, transport_opt: Sequence[str] = None,
muxer_opt: Mapping[TProtocol, MuxerClassType] = None, muxer_opt: TMuxerOptions = None,
sec_opt: Mapping[TProtocol, ISecureTransport] = None, sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None, peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None, disc_opt: IPeerRouting = None,
) -> Swarm: ) -> Swarm:
@ -114,8 +113,8 @@ async def new_node(
key_pair: KeyPair = None, key_pair: KeyPair = None,
swarm_opt: INetwork = None, swarm_opt: INetwork = None,
transport_opt: Sequence[str] = None, transport_opt: Sequence[str] = None,
muxer_opt: Mapping[TProtocol, MuxerClassType] = None, muxer_opt: TMuxerOptions = None,
sec_opt: Mapping[TProtocol, ISecureTransport] = None, sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None, peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None, disc_opt: IPeerRouting = None,
) -> BasicHost: ) -> BasicHost:

View File

@ -1,4 +1,3 @@
import asyncio
import logging import logging
from typing import List, Sequence from typing import List, Sequence
@ -107,7 +106,7 @@ class BasicHost(IHost):
:return: stream: new stream created :return: stream: new stream created
""" """
net_stream = await self._network.new_stream(peer_id, protocol_ids) net_stream = await self._network.new_stream(peer_id)
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
try: try:
@ -157,4 +156,4 @@ class BasicHost(IHost):
await net_stream.reset() await net_stream.reset()
return return
net_stream.set_protocol(protocol) net_stream.set_protocol(protocol)
asyncio.ensure_future(handler(net_stream)) await handler(net_stream)

View File

@ -7,7 +7,7 @@ from libp2p.stream_muxer.abc import IMuxedConn
class INetConn(Closer): class INetConn(Closer):
conn: IMuxedConn muxed_conn: IMuxedConn
@abstractmethod @abstractmethod
async def new_stream(self) -> INetStream: async def new_stream(self) -> INetStream:

View File

@ -16,15 +16,15 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee
class SwarmConn(INetConn): class SwarmConn(INetConn):
conn: IMuxedConn muxed_conn: IMuxedConn
swarm: "Swarm" swarm: "Swarm"
streams: Set[NetStream] streams: Set[NetStream]
event_closed: asyncio.Event event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"] _tasks: List["asyncio.Future[Any]"]
def __init__(self, conn: IMuxedConn, swarm: "Swarm") -> None: def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
self.conn = conn self.muxed_conn = muxed_conn
self.swarm = swarm self.swarm = swarm
self.streams = set() self.streams = set()
self.event_closed = asyncio.Event() self.event_closed = asyncio.Event()
@ -37,22 +37,26 @@ class SwarmConn(INetConn):
self.event_closed.set() self.event_closed.set()
self.swarm.remove_conn(self) self.swarm.remove_conn(self)
await self.conn.close() await self.muxed_conn.close()
# This is just for cleaning up state. The connection has already been closed. # This is just for cleaning up state. The connection has already been closed.
# We *could* optimize this but it really isn't worth it. # We *could* optimize this but it really isn't worth it.
for stream in self.streams: for stream in self.streams:
await stream.reset() await stream.reset()
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
asyncio.ensure_future(self._notify_disconnected())
for task in self._tasks: for task in self._tasks:
task.cancel() task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
self._notify_disconnected()
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
while True: while True:
try: try:
stream = await self.conn.accept_stream() stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable: except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn, # If there is anything wrong in the MuxedConn,
# we should break the loop and close the connection. # we should break the loop and close the connection.
@ -62,22 +66,28 @@ class SwarmConn(INetConn):
await self.close() await self.close()
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: async def _call_stream_handler(self, net_stream: NetStream) -> None:
net_stream = await self._add_stream(muxed_stream) try:
if self.swarm.common_stream_handler is not None: await self.swarm.common_stream_handler(net_stream)
await self.run_task(self.swarm.common_stream_handler(net_stream)) # TODO: More exact exceptions
except Exception:
# TODO: Emit logs.
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
self.remove_stream(net_stream)
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None:
await self.run_task(self._call_stream_handler(net_stream))
def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)
self.streams.add(net_stream) self.streams.add(net_stream)
# Call notifiers since event occurred self.swarm.notify_opened_stream(net_stream)
for notifee in self.swarm.notifees:
await notifee.opened_stream(self.swarm, net_stream)
return net_stream return net_stream
async def _notify_disconnected(self) -> None: def _notify_disconnected(self) -> None:
for notifee in self.swarm.notifees: self.swarm.notify_disconnected(self)
await notifee.disconnected(self.swarm, self.conn)
async def start(self) -> None: async def start(self) -> None:
await self.run_task(self._handle_new_streams()) await self.run_task(self._handle_new_streams())
@ -86,8 +96,13 @@ class SwarmConn(INetConn):
self._tasks.append(asyncio.ensure_future(coro)) self._tasks.append(asyncio.ensure_future(coro))
async def new_stream(self) -> NetStream: async def new_stream(self) -> NetStream:
muxed_stream = await self.conn.open_stream() muxed_stream = await self.muxed_conn.open_stream()
return await self._add_stream(muxed_stream) return self._add_stream(muxed_stream)
async def get_streams(self) -> Tuple[NetStream, ...]: async def get_streams(self) -> Tuple[NetStream, ...]:
return tuple(self.streams) return tuple(self.streams)
def remove_stream(self, stream: NetStream) -> None:
if stream not in self.streams:
return
self.streams.remove(stream)

View File

@ -7,7 +7,7 @@ from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn
from .stream.net_stream_interface import INetStream from .stream.net_stream_interface import INetStream
@ -38,9 +38,7 @@ class INetwork(ABC):
""" """
@abstractmethod @abstractmethod
async def new_stream( async def new_stream(self, peer_id: ID) -> INetStream:
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
) -> INetStream:
""" """
:param peer_id: peer_id of destination :param peer_id: peer_id of destination
:param protocol_ids: available protocol ids to use for stream :param protocol_ids: available protocol ids to use for stream
@ -61,7 +59,7 @@ class INetwork(ABC):
""" """
@abstractmethod @abstractmethod
def notify(self, notifee: "INotifee") -> bool: def register_notifee(self, notifee: "INotifee") -> None:
""" """
:param notifee: object implementing Notifee interface :param notifee: object implementing Notifee interface
:return: true if notifee registered successfully, false otherwise :return: true if notifee registered successfully, false otherwise

View File

@ -3,8 +3,8 @@ from typing import TYPE_CHECKING
from multiaddr import Multiaddr from multiaddr import Multiaddr
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.stream_muxer.abc import IMuxedConn
if TYPE_CHECKING: if TYPE_CHECKING:
from .network_interface import INetwork # noqa: F401 from .network_interface import INetwork # noqa: F401
@ -26,14 +26,14 @@ class INotifee(ABC):
""" """
@abstractmethod @abstractmethod
async def connected(self, network: "INetwork", conn: IMuxedConn) -> None: async def connected(self, network: "INetwork", conn: INetConn) -> None:
""" """
:param network: network the connection was opened on :param network: network the connection was opened on
:param conn: connection that was opened :param conn: connection that was opened
""" """
@abstractmethod @abstractmethod
async def disconnected(self, network: "INetwork", conn: IMuxedConn) -> None: async def disconnected(self, network: "INetwork", conn: INetConn) -> None:
""" """
:param network: network the connection was closed on :param network: network the connection was closed on
:param conn: connection that was closed :param conn: connection that was closed

View File

@ -1,4 +1,6 @@
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from typing import Optional
from libp2p.stream_muxer.abc import IMuxedStream
from libp2p.stream_muxer.exceptions import ( from libp2p.stream_muxer.exceptions import (
MuxedStreamClosed, MuxedStreamClosed,
MuxedStreamEOF, MuxedStreamEOF,
@ -16,13 +18,11 @@ from .net_stream_interface import INetStream
class NetStream(INetStream): class NetStream(INetStream):
muxed_stream: IMuxedStream muxed_stream: IMuxedStream
# TODO: Why we expose `mplex_conn` here? protocol_id: Optional[TProtocol]
mplex_conn: IMuxedConn
protocol_id: TProtocol
def __init__(self, muxed_stream: IMuxedStream) -> None: def __init__(self, muxed_stream: IMuxedStream) -> None:
self.muxed_stream = muxed_stream self.muxed_stream = muxed_stream
self.mplex_conn = muxed_stream.mplex_conn self.muxed_conn = muxed_stream.muxed_conn
self.protocol_id = None self.protocol_id = None
def get_protocol(self) -> TProtocol: def get_protocol(self) -> TProtocol:
@ -68,3 +68,7 @@ class NetStream(INetStream):
async def reset(self) -> None: async def reset(self) -> None:
await self.muxed_stream.reset() await self.muxed_stream.reset()
# TODO: `remove`: Called by close and write when the stream is in specific states.
# It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501

View File

@ -7,7 +7,7 @@ from libp2p.typing import TProtocol
class INetStream(ReadWriteCloser): class INetStream(ReadWriteCloser):
mplex_conn: IMuxedConn muxed_conn: IMuxedConn
@abstractmethod @abstractmethod
def get_protocol(self) -> TProtocol: def get_protocol(self) -> TProtocol:

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import Dict, List, Optional, Sequence from typing import Dict, List, Optional
from multiaddr import Multiaddr from multiaddr import Multiaddr
@ -18,7 +18,7 @@ from libp2p.transport.exceptions import (
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport from libp2p.transport.transport_interface import ITransport
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn
from .connection.raw_connection import RawConnection from .connection.raw_connection import RawConnection
from .connection.swarm_connection import SwarmConn from .connection.swarm_connection import SwarmConn
@ -141,20 +141,14 @@ class Swarm(INetwork):
return swarm_conn return swarm_conn
async def new_stream( async def new_stream(self, peer_id: ID) -> INetStream:
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
) -> INetStream:
""" """
:param peer_id: peer_id of destination :param peer_id: peer_id of destination
:param protocol_id: protocol id :param protocol_id: protocol id
:raises SwarmException: raised when an error occurs :raises SwarmException: raised when an error occurs
:return: net stream instance :return: net stream instance
""" """
logger.debug( logger.debug("attempting to open a stream to peer %s", peer_id)
"attempting to open a stream to peer %s, over one of the protocols %s",
peer_id,
protocol_ids,
)
swarm_conn = await self.dial_peer(peer_id) swarm_conn = await self.dial_peer(peer_id)
@ -229,8 +223,7 @@ class Swarm(INetwork):
await listener.listen(maddr) await listener.listen(maddr)
# Call notifiers since event occurred # Call notifiers since event occurred
for notifee in self.notifees: self.notify_listen(maddr)
await notifee.listen(self, maddr)
return True return True
except IOError: except IOError:
@ -240,16 +233,6 @@ class Swarm(INetwork):
# No maddr succeeded # No maddr succeeded
return False return False
def notify(self, notifee: INotifee) -> bool:
"""
:param notifee: object implementing Notifee interface
:return: true if notifee registered successfully, false otherwise
"""
if isinstance(notifee, INotifee):
self.notifees.append(notifee)
return True
return False
def add_router(self, router: IPeerRouting) -> None: def add_router(self, router: IPeerRouting) -> None:
self.router = router self.router = router
@ -288,9 +271,7 @@ class Swarm(INetwork):
# Store muxed_conn with peer id # Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn self.connections[muxed_conn.peer_id] = swarm_conn
# Call notifiers since event occurred # Call notifiers since event occurred
for notifee in self.notifees: self.notify_connected(swarm_conn)
# TODO: Call with other type of conn?
await notifee.connected(self, muxed_conn)
await swarm_conn.start() await swarm_conn.start()
return swarm_conn return swarm_conn
@ -298,9 +279,38 @@ class Swarm(INetwork):
""" """
Simply remove the connection from Swarm's records, without closing the connection. Simply remove the connection from Swarm's records, without closing the connection.
""" """
peer_id = swarm_conn.conn.peer_id peer_id = swarm_conn.muxed_conn.peer_id
if peer_id not in self.connections: if peer_id not in self.connections:
return return
# TODO: Should be changed to remove the exact connection, # TODO: Should be changed to remove the exact connection,
# if we have several connections per peer in the future. # if we have several connections per peer in the future.
del self.connections[peer_id] del self.connections[peer_id]
# Notifee
# TODO: Remeber the spawn notifying tasks and clean them up when closing.
def register_notifee(self, notifee: INotifee) -> None:
"""
:param notifee: object implementing Notifee interface
:return: true if notifee registered successfully, false otherwise
"""
self.notifees.append(notifee)
def notify_opened_stream(self, stream: INetStream) -> None:
asyncio.gather(
*[notifee.opened_stream(self, stream) for notifee in self.notifees]
)
# TODO: `notify_closed_stream`
def notify_connected(self, conn: INetConn) -> None:
asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees])
def notify_disconnected(self, conn: INetConn) -> None:
asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees])
def notify_listen(self, multiaddr: Multiaddr) -> None:
asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees])
# TODO: `notify_listen_close`

View File

@ -20,10 +20,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
msg_bytes = encode_delim(msg_str.encode()) msg_bytes = encode_delim(msg_str.encode())
try: try:
await self.read_writer.write(msg_bytes) await self.read_writer.write(msg_bytes)
except IOException: except IOException as error:
raise MultiselectCommunicatorError( raise MultiselectCommunicatorError(
"fail to write to multiselect communicator" "fail to write to multiselect communicator"
) ) from error
async def read(self) -> str: async def read(self) -> str:
""" """
@ -32,8 +32,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
try: try:
data = await read_delim(self.read_writer) data = await read_delim(self.read_writer)
# `IOException` includes `IncompleteReadError` and `StreamError` # `IOException` includes `IncompleteReadError` and `StreamError`
except (ParseError, IOException, ValueError): except (ParseError, IOException) as error:
raise MultiselectCommunicatorError( raise MultiselectCommunicatorError(
"fail to read from multiselect communicator" "fail to read from multiselect communicator"
) ) from error
return data.decode() return data.decode()

View File

@ -98,7 +98,7 @@ class Pubsub:
# Register a notifee # Register a notifee
self.peer_queue = asyncio.Queue() self.peer_queue = asyncio.Queue()
self.host.get_network().notify(PubsubNotifee(self.peer_queue)) self.host.get_network().register_notifee(PubsubNotifee(self.peer_queue))
# Register stream handlers for each pubsub router protocol to handle # Register stream handlers for each pubsub router protocol to handle
# the pubsub streams opened on those protocols # the pubsub streams opened on those protocols
@ -154,7 +154,7 @@ class Pubsub:
messages from other nodes messages from other nodes
:param stream: stream to continously read from :param stream: stream to continously read from
""" """
peer_id = stream.mplex_conn.peer_id peer_id = stream.muxed_conn.peer_id
while True: while True:
try: try:

View File

@ -2,10 +2,10 @@ from typing import TYPE_CHECKING
from multiaddr import Multiaddr from multiaddr import Multiaddr
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.network_interface import INetwork from libp2p.network.network_interface import INetwork
from libp2p.network.notifee_interface import INotifee from libp2p.network.notifee_interface import INotifee
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.stream_muxer.abc import IMuxedConn
if TYPE_CHECKING: if TYPE_CHECKING:
import asyncio # noqa: F401 import asyncio # noqa: F401
@ -29,16 +29,16 @@ class PubsubNotifee(INotifee):
async def closed_stream(self, network: INetwork, stream: INetStream) -> None: async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass pass
async def connected(self, network: INetwork, conn: IMuxedConn) -> None: async def connected(self, network: INetwork, conn: INetConn) -> None:
""" """
Add peer_id to initiator_peers_queue, so that this peer_id can be used to Add peer_id to initiator_peers_queue, so that this peer_id can be used to
create a stream and we only want to have one pubsub stream with each peer. create a stream and we only want to have one pubsub stream with each peer.
:param network: network the connection was opened on :param network: network the connection was opened on
:param conn: connection that was opened :param conn: connection that was opened
""" """
await self.initiator_peers_queue.put(conn.peer_id) await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
async def disconnected(self, network: INetwork, conn: IMuxedConn) -> None: async def disconnected(self, network: INetwork, conn: INetConn) -> None:
pass pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:

View File

@ -1,6 +1,5 @@
from abc import ABC from abc import ABC
from collections import OrderedDict from collections import OrderedDict
from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -9,6 +8,7 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.transport.typing import TSecurityOptions
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -31,9 +31,7 @@ class SecurityMultistream(ABC):
multiselect: Multiselect multiselect: Multiselect
multiselect_client: MultiselectClient multiselect_client: MultiselectClient
def __init__( def __init__(self, secure_transports_by_protocol: TSecurityOptions) -> None:
self, secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport]
) -> None:
self.transports = OrderedDict() self.transports = OrderedDict()
self.multiselect = Multiselect() self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient() self.multiselect_client = MultiselectClient()

View File

@ -55,7 +55,7 @@ class IMuxedConn(ABC):
class IMuxedStream(ReadWriteCloser): class IMuxedStream(ReadWriteCloser):
mplex_conn: IMuxedConn muxed_conn: IMuxedConn
@abstractmethod @abstractmethod
async def reset(self) -> None: async def reset(self) -> None:

View File

@ -31,9 +31,6 @@ class Mplex(IMuxedConn):
secured_conn: ISecureConn secured_conn: ISecureConn
peer_id: ID peer_id: ID
# TODO: `dataIn` in go implementation. Should be size of 8.
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
# to let the `MplexStream`s know that EOF arrived (#235).
next_channel_id: int next_channel_id: int
streams: Dict[StreamID, MplexStream] streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock streams_lock: asyncio.Lock
@ -43,7 +40,6 @@ class Mplex(IMuxedConn):
_tasks: List["asyncio.Future[Any]"] _tasks: List["asyncio.Future[Any]"]
# TODO: `generic_protocol_handler` should be refactored out of mplex conn.
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
""" """
create a new muxed connection create a new muxed connection

View File

@ -18,12 +18,13 @@ class MplexStream(IMuxedStream):
name: str name: str
stream_id: StreamID stream_id: StreamID
mplex_conn: "Mplex" muxed_conn: "Mplex"
read_deadline: int read_deadline: int
write_deadline: int write_deadline: int
close_lock: asyncio.Lock close_lock: asyncio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: "asyncio.Queue[bytes]" incoming_data: "asyncio.Queue[bytes]"
event_local_closed: asyncio.Event event_local_closed: asyncio.Event
@ -32,15 +33,15 @@ class MplexStream(IMuxedStream):
_buf: bytearray _buf: bytearray
def __init__(self, name: str, stream_id: StreamID, mplex_conn: "Mplex") -> None: def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None:
""" """
create new MuxedStream in muxer create new MuxedStream in muxer
:param stream_id: stream id of this stream :param stream_id: stream id of this stream
:param mplex_conn: muxed connection of this muxed_stream :param muxed_conn: muxed connection of this muxed_stream
""" """
self.name = name self.name = name
self.stream_id = stream_id self.stream_id = stream_id
self.mplex_conn = mplex_conn self.muxed_conn = muxed_conn
self.read_deadline = None self.read_deadline = None
self.write_deadline = None self.write_deadline = None
self.event_local_closed = asyncio.Event() self.event_local_closed = asyncio.Event()
@ -147,7 +148,7 @@ class MplexStream(IMuxedStream):
if self.is_initiator if self.is_initiator
else HeaderTags.MessageReceiver else HeaderTags.MessageReceiver
) )
return await self.mplex_conn.send_message(flag, data, self.stream_id) return await self.muxed_conn.send_message(flag, data, self.stream_id)
async def close(self) -> None: async def close(self) -> None:
""" """
@ -163,8 +164,8 @@ class MplexStream(IMuxedStream):
flag = ( flag = (
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
) )
# TODO: Raise when `mplex_conn.send_message` fails and `Mplex` isn't shutdown. # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
await self.mplex_conn.send_message(flag, None, self.stream_id) await self.muxed_conn.send_message(flag, None, self.stream_id)
_is_remote_closed: bool _is_remote_closed: bool
async with self.close_lock: async with self.close_lock:
@ -173,8 +174,8 @@ class MplexStream(IMuxedStream):
if _is_remote_closed: if _is_remote_closed:
# Both sides are closed, we can safely remove the buffer from the dict. # Both sides are closed, we can safely remove the buffer from the dict.
async with self.mplex_conn.streams_lock: async with self.muxed_conn.streams_lock:
del self.mplex_conn.streams[self.stream_id] del self.muxed_conn.streams[self.stream_id]
async def reset(self) -> None: async def reset(self) -> None:
""" """
@ -196,19 +197,19 @@ class MplexStream(IMuxedStream):
else HeaderTags.ResetReceiver else HeaderTags.ResetReceiver
) )
asyncio.ensure_future( asyncio.ensure_future(
self.mplex_conn.send_message(flag, None, self.stream_id) self.muxed_conn.send_message(flag, None, self.stream_id)
) )
await asyncio.sleep(0) await asyncio.sleep(0)
self.event_local_closed.set() self.event_local_closed.set()
self.event_remote_closed.set() self.event_remote_closed.set()
async with self.mplex_conn.streams_lock: async with self.muxed_conn.streams_lock:
if ( if (
self.mplex_conn.streams is not None self.muxed_conn.streams is not None
and self.stream_id in self.mplex_conn.streams and self.stream_id in self.muxed_conn.streams
): ):
del self.mplex_conn.streams[self.stream_id] del self.muxed_conn.streams[self.stream_id]
# TODO deadline not in use # TODO deadline not in use
def set_deadline(self, ttl: int) -> bool: def set_deadline(self, ttl: int) -> bool:

View File

@ -1,5 +1,4 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Mapping, Type
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -7,12 +6,11 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.transport.typing import TMuxerClass, TMuxerOptions
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .abc import IMuxedConn from .abc import IMuxedConn
MuxerClassType = Type[IMuxedConn]
# FIXME: add negotiate timeout to `MuxerMultistream` # FIXME: add negotiate timeout to `MuxerMultistream`
DEFAULT_NEGOTIATE_TIMEOUT = 60 DEFAULT_NEGOTIATE_TIMEOUT = 60
@ -24,20 +22,18 @@ class MuxerMultistream:
""" """
# NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2. # NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2.
transports: "OrderedDict[TProtocol, MuxerClassType]" transports: "OrderedDict[TProtocol, TMuxerClass]"
multiselect: Multiselect multiselect: Multiselect
multiselect_client: MultiselectClient multiselect_client: MultiselectClient
def __init__( def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
self, muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType]
) -> None:
self.transports = OrderedDict() self.transports = OrderedDict()
self.multiselect = Multiselect() self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient() self.multiselect_client = MultiselectClient()
for protocol, transport in muxer_transports_by_protocol.items(): for protocol, transport in muxer_transports_by_protocol.items():
self.add_transport(protocol, transport) self.add_transport(protocol, transport)
def add_transport(self, protocol: TProtocol, transport: MuxerClassType) -> None: def add_transport(self, protocol: TProtocol, transport: TMuxerClass) -> None:
""" """
Add a protocol and its corresponding transport to multistream-select(multiselect). Add a protocol and its corresponding transport to multistream-select(multiselect).
The order that a protocol is added is exactly the precedence it is negotiated in The order that a protocol is added is exactly the precedence it is negotiated in
@ -51,7 +47,7 @@ class MuxerMultistream:
self.transports[protocol] = transport self.transports[protocol] = transport
self.multiselect.add_handler(protocol, None) self.multiselect.add_handler(protocol, None)
async def select_transport(self, conn: IRawConnection) -> MuxerClassType: async def select_transport(self, conn: IRawConnection) -> TMuxerClass:
""" """
Select a transport that both us and the node on the Select a transport that both us and the node on the
other end of conn support and agree on other end of conn support and agree on

View File

@ -1,4 +1,11 @@
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from typing import Awaitable, Callable from typing import Awaitable, Callable, Mapping, Type
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.typing import TProtocol
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = Type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass]

View File

@ -1,16 +1,13 @@
from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.security.exceptions import HandshakeFailure from libp2p.security.exceptions import HandshakeFailure
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.security.security_multistream import SecurityMultistream from libp2p.security.security_multistream import SecurityMultistream
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream from libp2p.stream_muxer.muxer_multistream import MuxerMultistream
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
from libp2p.typing import TProtocol from libp2p.transport.typing import TMuxerOptions, TSecurityOptions
from .listener_interface import IListener from .listener_interface import IListener
from .transport_interface import ITransport from .transport_interface import ITransport
@ -22,8 +19,8 @@ class TransportUpgrader:
def __init__( def __init__(
self, self,
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport], secure_transports_by_protocol: TSecurityOptions,
muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType], muxer_transports_by_protocol: TMuxerOptions,
): ):
self.security_multistream = SecurityMultistream(secure_transports_by_protocol) self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)

View File

@ -73,9 +73,12 @@ def encode_delim(msg: bytes) -> bytes:
async def read_delim(reader: Reader) -> bytes: async def read_delim(reader: Reader) -> bytes:
msg_bytes = await read_varint_prefixed_bytes(reader) msg_bytes = await read_varint_prefixed_bytes(reader)
# TODO: Investigate if it is possible to have empty `msg_bytes` if len(msg_bytes) == 0:
if len(msg_bytes) != 0 and msg_bytes[-1:] != b"\n": raise ParseError(f"`len(msg_bytes)` should not be 0")
raise ValueError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}') if msg_bytes[-1:] != b"\n":
raise ParseError(
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes}'
)
return msg_bytes[:-1] return msg_bytes[:-1]

View File

@ -6,6 +6,7 @@ import factory
from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p import generate_new_rsa_identity, initialize_default_swarm
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
@ -14,6 +15,9 @@ from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.transport.typing import TMuxerOptions
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import ( from tests.pubsub.configs import (
@ -33,10 +37,10 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)} return {secio.ID: secio.Transport(key_pair)}
def SwarmFactory(is_secure: bool) -> Swarm: def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm:
key_pair = generate_new_rsa_identity() key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(False, key_pair) sec_opt = security_transport_factory(is_secure, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt) return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt)
class ListeningSwarmFactory(factory.Factory): class ListeningSwarmFactory(factory.Factory):
@ -44,17 +48,22 @@ class ListeningSwarmFactory(factory.Factory):
model = Swarm model = Swarm
@classmethod @classmethod
async def create_and_listen(cls, is_secure: bool) -> Swarm: async def create_and_listen(
swarm = SwarmFactory(is_secure) cls, is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Swarm:
swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt)
await swarm.listen(LISTEN_MADDR) await swarm.listen(LISTEN_MADDR)
return swarm return swarm
@classmethod @classmethod
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]: ) -> Tuple[Swarm, ...]:
return await asyncio.gather( return await asyncio.gather(
*[cls.create_and_listen(is_secure) for _ in range(number)] *[
cls.create_and_listen(is_secure, muxer_opt=muxer_opt)
for _ in range(number)
]
) )
@ -111,8 +120,12 @@ class PubsubFactory(factory.Factory):
cache_size = None cache_size = None
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: async def swarm_pair_factory(
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2) is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt
)
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1] return swarms[0], swarms[1]
@ -128,11 +141,37 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
return hosts[0], hosts[1] return hosts[0], hosts[1]
# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: async def swarm_conn_pair_factory(
# host_0, host_1 = await host_pair_factory() is_secure: bool, muxer_opt: TMuxerOptions = None
# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] ) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] swarms = await swarm_pair_factory(is_secure)
# return mplex_conn_0, host_0, mplex_conn_1, host_1 conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
return conn_0, swarms[0], conn_1, swarms[1]
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(
is_secure, muxer_opt=muxer_opt
)
return conn_0.muxed_conn, swarm_0, conn_1.muxed_conn, swarm_1
async def mplex_stream_pair_factory(
is_secure: bool
) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]:
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_secure
)
stream_0 = await mplex_conn_0.open_stream()
await asyncio.sleep(0.01)
stream_1: MplexStream
async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any stream upon connection")
stream_1 = tuple(mplex_conn_1.streams.values())[0]
return stream_0, swarm_0, stream_1, swarm_1
async def net_stream_pair_factory( async def net_stream_pair_factory(

View File

@ -1,351 +0,0 @@
"""
Test Notify and Notifee by ensuring that the proper events get
called, and that the stream passed into opened_stream is correct
Note: Listen event does not get hit because MyNotifee is passed
into network after network has already started listening
TODO: Add tests for closed_stream disconnected, listen_close when those
features are implemented in swarm
"""
import multiaddr
import pytest
from libp2p import initialize_default_swarm, new_node
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.host.basic_host import BasicHost
from libp2p.network.notifee_interface import INotifee
from tests.constants import MAX_READ_LEN
from tests.utils import perform_two_host_set_up
ACK = "ack:"
class MyNotifee(INotifee):
def __init__(self, events, val_to_append_to_event):
self.events = events
self.val_to_append_to_event = val_to_append_to_event
async def opened_stream(self, network, stream):
self.events.append(["opened_stream" + self.val_to_append_to_event, stream])
async def closed_stream(self, network, stream):
pass
async def connected(self, network, conn):
self.events.append(["connected" + self.val_to_append_to_event, conn])
async def disconnected(self, network, conn):
pass
async def listen(self, network, _multiaddr):
self.events.append(["listened" + self.val_to_append_to_event, _multiaddr])
async def listen_close(self, network, _multiaddr):
pass
class InvalidNotifee:
def __init__(self):
pass
async def opened_stream(self):
assert False
async def closed_stream(self):
assert False
async def connected(self):
assert False
async def disconnected(self):
assert False
async def listen(self):
assert False
@pytest.mark.asyncio
async def test_one_notifier():
node_a, node_b = await perform_two_host_set_up()
# Add notifee for node_a
events = []
assert node_a.get_network().notify(MyNotifee(events, "0"))
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
# Ensure the connected and opened_stream events were hit in MyNotifee obj
# and that stream passed into opened_stream matches the stream created on
# node_a
assert events == [["connected0", stream.mplex_conn], ["opened_stream0", stream]]
messages = ["hello", "hello"]
for message in messages:
expected_resp = ACK + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_one_notifier_on_two_nodes():
events_b = []
messages = ["hello", "hello"]
async def my_stream_handler(stream):
# Ensure the connected and opened_stream events were hit in Notifee obj
# and that the stream passed into opened_stream matches the stream created on
# node_b
assert events_b == [
["connectedb", stream.mplex_conn],
["opened_streamb", stream],
]
for message in messages:
read_string = (await stream.read(len(message))).decode()
resp = ACK + read_string
await stream.write(resp.encode())
node_a, node_b = await perform_two_host_set_up(my_stream_handler)
# Add notifee for node_a
events_a = []
assert node_a.get_network().notify(MyNotifee(events_a, "a"))
# Add notifee for node_b
assert node_b.get_network().notify(MyNotifee(events_b, "b"))
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
# Ensure the connected and opened_stream events were hit in MyNotifee obj
# and that stream passed into opened_stream matches the stream created on
# node_a
assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]]
for message in messages:
expected_resp = ACK + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_one_notifier_on_two_nodes_with_listen():
events_b = []
messages = ["hello", "hello"]
node_a_key_pair = create_new_key_pair()
node_a_transport_opt = ["/ip4/127.0.0.1/tcp/0"]
node_a = await new_node(node_a_key_pair, transport_opt=node_a_transport_opt)
await node_a.get_network().listen(multiaddr.Multiaddr(node_a_transport_opt[0]))
# Set up node_b swarm to pass into host
node_b_key_pair = create_new_key_pair()
node_b_transport_opt = ["/ip4/127.0.0.1/tcp/0"]
node_b_multiaddr = multiaddr.Multiaddr(node_b_transport_opt[0])
node_b_swarm = initialize_default_swarm(
node_b_key_pair, transport_opt=node_b_transport_opt
)
node_b = BasicHost(node_b_swarm)
async def my_stream_handler(stream):
# Ensure the listened, connected and opened_stream events were hit in Notifee obj
# and that the stream passed into opened_stream matches the stream created on
# node_b
assert events_b == [
["listenedb", node_b_multiaddr],
["connectedb", stream.mplex_conn],
["opened_streamb", stream],
]
for message in messages:
read_string = (await stream.read(len(message))).decode()
resp = ACK + read_string
await stream.write(resp.encode())
# Add notifee for node_a
events_a = []
assert node_a.get_network().notify(MyNotifee(events_a, "a"))
# Add notifee for node_b
assert node_b.get_network().notify(MyNotifee(events_b, "b"))
# start listen on node_b_swarm
await node_b.get_network().listen(node_b_multiaddr)
node_b.set_stream_handler("/echo/1.0.0", my_stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
# Ensure the connected and opened_stream events were hit in MyNotifee obj
# and that stream passed into opened_stream matches the stream created on
# node_a
assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]]
for message in messages:
expected_resp = ACK + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_two_notifiers():
node_a, node_b = await perform_two_host_set_up()
# Add notifee for node_a
events0 = []
assert node_a.get_network().notify(MyNotifee(events0, "0"))
events1 = []
assert node_a.get_network().notify(MyNotifee(events1, "1"))
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
# Ensure the connected and opened_stream events were hit in both Notifee objs
# and that the stream passed into opened_stream matches the stream created on
# node_a
assert events0 == [["connected0", stream.mplex_conn], ["opened_stream0", stream]]
assert events1 == [["connected1", stream.mplex_conn], ["opened_stream1", stream]]
messages = ["hello", "hello"]
for message in messages:
expected_resp = ACK + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_ten_notifiers():
num_notifiers = 10
node_a, node_b = await perform_two_host_set_up()
# Add notifee for node_a
events_lst = []
for i in range(num_notifiers):
events_lst.append([])
assert node_a.get_network().notify(MyNotifee(events_lst[i], str(i)))
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
# Ensure the connected and opened_stream events were hit in both Notifee objs
# and that the stream passed into opened_stream matches the stream created on
# node_a
for i in range(num_notifiers):
assert events_lst[i] == [
["connected" + str(i), stream.mplex_conn],
["opened_stream" + str(i), stream],
]
messages = ["hello", "hello"]
for message in messages:
expected_resp = ACK + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_ten_notifiers_on_two_nodes():
num_notifiers = 10
events_lst_b = []
async def my_stream_handler(stream):
# Ensure the connected and opened_stream events were hit in all Notifee objs
# and that the stream passed into opened_stream matches the stream created on
# node_b
for i in range(num_notifiers):
assert events_lst_b[i] == [
["connectedb" + str(i), stream.mplex_conn],
["opened_streamb" + str(i), stream],
]
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
resp = ACK + read_string
await stream.write(resp.encode())
node_a, node_b = await perform_two_host_set_up(my_stream_handler)
# Add notifee for node_a and node_b
events_lst_a = []
for i in range(num_notifiers):
events_lst_a.append([])
events_lst_b.append([])
assert node_a.get_network().notify(MyNotifee(events_lst_a[i], "a" + str(i)))
assert node_b.get_network().notify(MyNotifee(events_lst_b[i], "b" + str(i)))
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
# Ensure the connected and opened_stream events were hit in all Notifee objs
# and that the stream passed into opened_stream matches the stream created on
# node_a
for i in range(num_notifiers):
assert events_lst_a[i] == [
["connecteda" + str(i), stream.mplex_conn],
["opened_streama" + str(i), stream],
]
messages = ["hello", "hello"]
for message in messages:
expected_resp = ACK + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_invalid_notifee():
num_notifiers = 10
node_a, node_b = await perform_two_host_set_up()
# Add notifee for node_a
events_lst = []
for _ in range(num_notifiers):
events_lst.append([])
assert not node_a.get_network().notify(InvalidNotifee())
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
# If this point is reached, this implies that the InvalidNotifee instance
# did not assert false, i.e. no functions of InvalidNotifee were called (which is correct
# given that InvalidNotifee should not have been added as a notifee)
messages = ["hello", "hello"]
for message in messages:
expected_resp = ACK + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Success, terminate pending tasks.

View File

@ -2,7 +2,11 @@ import asyncio
import pytest import pytest
from tests.factories import net_stream_pair_factory, swarm_pair_factory from tests.factories import (
net_stream_pair_factory,
swarm_conn_pair_factory,
swarm_pair_factory,
)
@pytest.fixture @pytest.fixture
@ -21,3 +25,12 @@ async def swarm_pair(is_host_secure):
yield swarm_0, swarm_1 yield swarm_0, swarm_1
finally: finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
@pytest.fixture
async def swarm_conn_pair(is_host_secure):
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure)
try:
yield conn_0, conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -7,9 +7,6 @@ from tests.constants import MAX_READ_LEN
DATA = b"data_123" DATA = b"data_123"
# TODO: Move `muxed_stream` specific(currently we are using `MplexStream`) tests to its
# own file, after `generic_protocol_handler` is refactored out of `Mplex`.
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_net_stream_read_write(net_stream_pair): async def test_net_stream_read_write(net_stream_pair):
@ -56,11 +53,9 @@ async def test_net_stream_read_until_eof(net_stream_pair):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_net_stream_read_after_remote_closed(net_stream_pair): async def test_net_stream_read_after_remote_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair
assert not stream_1.muxed_stream.event_remote_closed.is_set()
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.close() await stream_0.close()
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert stream_1.muxed_stream.event_remote_closed.is_set()
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(StreamEOF): with pytest.raises(StreamEOF):
await stream_1.read(MAX_READ_LEN) await stream_1.read(MAX_READ_LEN)

View File

@ -0,0 +1,112 @@
"""
Test Notify and Notifee by ensuring that the proper events get
called, and that the stream passed into opened_stream is correct
Note: Listen event does not get hit because MyNotifee is passed
into network after network has already started listening
TODO: Add tests for closed_stream, listen_close when those
features are implemented in swarm
"""
import asyncio
import enum
import pytest
from libp2p.network.notifee_interface import INotifee
from tests.configs import LISTEN_MADDR
from tests.factories import SwarmFactory
from tests.utils import connect_swarm
class Event(enum.Enum):
OpenedStream = 0
ClosedStream = 1 # Not implemented
Connected = 2
Disconnected = 3
Listen = 4
ListenClose = 5 # Not implemented
class MyNotifee(INotifee):
def __init__(self, events):
self.events = events
async def opened_stream(self, network, stream):
self.events.append(Event.OpenedStream)
async def closed_stream(self, network, stream):
# TODO: It is not implemented yet.
pass
async def connected(self, network, conn):
self.events.append(Event.Connected)
async def disconnected(self, network, conn):
self.events.append(Event.Disconnected)
async def listen(self, network, _multiaddr):
self.events.append(Event.Listen)
async def listen_close(self, network, _multiaddr):
# TODO: It is not implemented yet.
pass
@pytest.mark.asyncio
async def test_notify(is_host_secure):
swarms = [SwarmFactory(is_host_secure) for _ in range(2)]
events_0_0 = []
events_1_0 = []
events_0_without_listen = []
swarms[0].register_notifee(MyNotifee(events_0_0))
swarms[1].register_notifee(MyNotifee(events_1_0))
# Listen
await asyncio.gather(*[swarm.listen(LISTEN_MADDR) for swarm in swarms])
swarms[0].register_notifee(MyNotifee(events_0_without_listen))
# Connected
await connect_swarm(swarms[0], swarms[1])
# OpenedStream: first
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: second
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: third, but different direction.
await swarms[1].new_stream(swarms[0].get_peer_id())
await asyncio.sleep(0.01)
# TODO: Check `ClosedStream` and `ListenClose` events after they are ready.
# Disconnected
await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# Connected again, but different direction.
await connect_swarm(swarms[1], swarms[0])
await asyncio.sleep(0.01)
# Disconnected again, but different direction.
await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01)
expected_events_without_listen = [
Event.Connected,
Event.OpenedStream,
Event.OpenedStream,
Event.OpenedStream,
Event.Disconnected,
Event.Connected,
Event.Disconnected,
]
expected_events = [Event.Listen] + expected_events_without_listen
assert events_0_0 == expected_events
assert events_1_0 == expected_events
assert events_0_without_listen == expected_events_without_listen
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])

View File

@ -0,0 +1,45 @@
import asyncio
import pytest
@pytest.mark.asyncio
async def test_swarm_conn_close(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
await conn_0.close()
await asyncio.sleep(0.01)
assert conn_0.event_closed.is_set()
assert conn_1.event_closed.is_set()
assert conn_0 not in conn_0.swarm.connections.values()
assert conn_1 not in conn_1.swarm.connections.values()
@pytest.mark.asyncio
async def test_swarm_conn_streams(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert len(await conn_0.get_streams()) == 0
assert len(await conn_1.get_streams()) == 0
stream_0_0 = await conn_0.new_stream()
await asyncio.sleep(0.01)
assert len(await conn_0.get_streams()) == 1
assert len(await conn_1.get_streams()) == 1
stream_0_1 = await conn_0.new_stream()
await asyncio.sleep(0.01)
assert len(await conn_0.get_streams()) == 2
assert len(await conn_1.get_streams()) == 2
conn_0.remove_stream(stream_0_0)
assert len(await conn_0.get_streams()) == 1
conn_0.remove_stream(stream_0_1)
assert len(await conn_0.get_streams()) == 0
# Nothing happen if `stream_0_1` is not present or already removed.
conn_0.remove_stream(stream_0_1)

View File

@ -233,7 +233,7 @@ class FakeNetStream:
class FakeMplexConn(NamedTuple): class FakeMplexConn(NamedTuple):
peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32) peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32)
mplex_conn = FakeMplexConn() muxed_conn = FakeMplexConn()
def __init__(self) -> None: def __init__(self) -> None:
self._queue = asyncio.Queue() self._queue = asyncio.Queue()

View File

@ -53,8 +53,8 @@ async def perform_simple_test(
node2_conn = node2.get_network().connections[peer_id_for_node(node1)] node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
# Perform assertion # Perform assertion
assertion_func(node1_conn.conn.secured_conn) assertion_func(node1_conn.muxed_conn.secured_conn)
assertion_func(node2_conn.conn.secured_conn) assertion_func(node2_conn.muxed_conn.secured_conn)
# Success, terminate pending tasks. # Success, terminate pending tasks.

View File

View File

@ -0,0 +1,29 @@
import asyncio
import pytest
from tests.factories import mplex_conn_pair_factory, mplex_stream_pair_factory
@pytest.fixture
async def mplex_conn_pair(is_host_secure):
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_host_secure
)
assert mplex_conn_0.initiator
assert not mplex_conn_1.initiator
try:
yield mplex_conn_0, mplex_conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
@pytest.fixture
async def mplex_stream_pair(is_host_secure):
mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory(
is_host_secure
)
try:
yield mplex_stream_0, mplex_stream_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -0,0 +1,50 @@
import asyncio
import pytest
@pytest.mark.asyncio
async def test_mplex_conn(mplex_conn_pair):
conn_0, conn_1 = mplex_conn_pair
assert len(conn_0.streams) == 0
assert len(conn_1.streams) == 0
assert not conn_0.event_shutting_down.is_set()
assert not conn_1.event_shutting_down.is_set()
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
# Test: Open a stream, and both side get 1 more stream.
stream_0 = await conn_0.open_stream()
await asyncio.sleep(0.01)
assert len(conn_0.streams) == 1
assert len(conn_1.streams) == 1
# Test: From another side.
stream_1 = await conn_1.open_stream()
await asyncio.sleep(0.01)
assert len(conn_0.streams) == 2
assert len(conn_1.streams) == 2
# Close from one side.
await conn_0.close()
# Sleep for a while for both side to handle `close`.
await asyncio.sleep(0.01)
# Test: Both side is closed.
assert conn_0.event_shutting_down.is_set()
assert conn_0.event_closed.is_set()
assert conn_1.event_shutting_down.is_set()
assert conn_1.event_closed.is_set()
# Test: All streams should have been closed.
assert stream_0.event_remote_closed.is_set()
assert stream_0.event_reset.is_set()
assert stream_0.event_local_closed.is_set()
assert conn_0.streams is None
# Test: All streams on the other side are also closed.
assert stream_1.event_remote_closed.is_set()
assert stream_1.event_reset.is_set()
assert stream_1.event_local_closed.is_set()
assert conn_1.streams is None
# Test: No effect to close more than once between two side.
await conn_0.close()
await conn_1.close()

View File

@ -0,0 +1,182 @@
import asyncio
import pytest
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from tests.constants import MAX_READ_LEN
DATA = b"data_123"
@pytest.mark.asyncio
async def test_mplex_stream_read_write(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
read_bytes = bytearray()
stream_0, stream_1 = mplex_stream_pair
async def read_until_eof():
read_bytes.extend(await stream_1.read())
task = asyncio.ensure_future(read_until_eof())
expected_data = bytearray()
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
await asyncio.sleep(0.01)
assert read_bytes == expected_data
task.cancel()
@pytest.mark.asyncio
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
assert not stream_1.event_remote_closed.is_set()
await stream_0.write(DATA)
await stream_0.close()
await asyncio.sleep(0.01)
assert stream_1.event_remote_closed.is_set()
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(MplexStreamEOF):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
with pytest.raises(MplexStreamReset):
await stream_0.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
await stream_0.close()
with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA)
@pytest.mark.asyncio
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA)
@pytest.mark.asyncio
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_1.reset()
await asyncio.sleep(0.01)
with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA)
@pytest.mark.asyncio
async def test_mplex_stream_both_close(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
# Flags are not set initially.
assert not stream_0.event_local_closed.is_set()
assert not stream_1.event_local_closed.is_set()
assert not stream_0.event_remote_closed.is_set()
assert not stream_1.event_remote_closed.is_set()
# Streams are present in their `mplex_conn`.
assert stream_0 in stream_0.muxed_conn.streams.values()
assert stream_1 in stream_1.muxed_conn.streams.values()
# Test: Close one side.
await stream_0.close()
await asyncio.sleep(0.01)
assert stream_0.event_local_closed.is_set()
assert not stream_1.event_local_closed.is_set()
assert not stream_0.event_remote_closed.is_set()
assert stream_1.event_remote_closed.is_set()
# Streams are still present in their `mplex_conn`.
assert stream_0 in stream_0.muxed_conn.streams.values()
assert stream_1 in stream_1.muxed_conn.streams.values()
# Test: Close the other side.
await stream_1.close()
await asyncio.sleep(0.01)
# Both sides are closed.
assert stream_0.event_local_closed.is_set()
assert stream_1.event_local_closed.is_set()
assert stream_0.event_remote_closed.is_set()
assert stream_1.event_remote_closed.is_set()
# Streams are removed from their `mplex_conn`.
assert stream_0 not in stream_0.muxed_conn.streams.values()
assert stream_1 not in stream_1.muxed_conn.streams.values()
# Test: Reset after both close.
await stream_0.reset()
@pytest.mark.asyncio
async def test_mplex_stream_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
await asyncio.sleep(0.01)
# Both sides are closed.
assert stream_0.event_local_closed.is_set()
assert stream_1.event_local_closed.is_set()
assert stream_0.event_remote_closed.is_set()
assert stream_1.event_remote_closed.is_set()
# Streams are removed from their `mplex_conn`.
assert stream_0 not in stream_0.muxed_conn.streams.values()
assert stream_1 not in stream_1.muxed_conn.streams.values()
# `close` should do nothing.
await stream_0.close()
await stream_1.close()
# `reset` should do nothing as well.
await stream_0.reset()
await stream_1.reset()

View File

@ -22,6 +22,5 @@ async def test_connect(hosts, p2pds):
assert len(host.get_network().connections) == 1 assert len(host.get_network().connections) == 1
# Test: `disconnect` from Go # Test: `disconnect` from Go
await p2pd.control.disconnect(host.get_id()) await p2pd.control.disconnect(host.get_id())
# FIXME: Failed to handle disconnect
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert len(host.get_network().connections) == 0 assert len(host.get_network().connections) == 0