Merge pull request #304 from mhchia/fix/detection-of-close

Detect closed `Mplex`

This commit is contained in: 4a838033ff
@@ -30,6 +30,10 @@ logger.setLevel(logging.DEBUG)

 class BasicHost(IHost):
+    """
+    BasicHost is a wrapper of an `INetwork` implementation. It performs protocol negotiation
+    on a stream with multistream-select right after a stream is initialized.
+    """

     _network: INetwork
     _router: KadmeliaPeerRouter

@@ -38,7 +42,6 @@ class BasicHost(IHost):
     multiselect: Multiselect
     multiselect_client: MultiselectClient

-    # default options constructor
     def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
         self._network = network
         self._network.set_stream_handler(self._swarm_stream_handler)

@@ -76,6 +79,7 @@ class BasicHost(IHost):
         """
         :return: all the multiaddr addresses this host is listening to
         """
+        # TODO: We don't need "/p2p/{peer_id}" postfix actually.
        p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))

         addrs: List[multiaddr.Multiaddr] = []

@@ -94,8 +98,6 @@ class BasicHost(IHost):
         """
         self.multiselect.add_handler(protocol_id, stream_handler)

-    # `protocol_ids` can be a list of `protocol_id`
-    # stream will decide which `protocol_id` to run on
     async def new_stream(
         self, peer_id: ID, protocol_ids: Sequence[TProtocol]
     ) -> INetStream:

@@ -37,8 +37,8 @@ class RawConnection(IRawConnection):
         async with self._drain_lock:
             try:
                 await self.writer.drain()
-            except ConnectionResetError:
-                raise RawConnError()
+            except ConnectionResetError as error:
+                raise RawConnError(error)

     async def read(self, n: int = -1) -> bytes:
         """

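The change above keeps the original `ConnectionResetError` instead of discarding it. A minimal, self-contained sketch of why carrying the cause matters when wrapping a low-level error (the `RawConnError` below is a stand-in class, not the libp2p import):

class RawConnError(Exception):
    pass

def flush():
    raise ConnectionResetError("peer went away")

try:
    try:
        flush()
    except ConnectionResetError as error:
        # Passing `error` along (or using `raise ... from error`) keeps the
        # underlying cause visible in tracebacks and in str(exc).
        raise RawConnError(error) from error
except RawConnError as exc:
    assert isinstance(exc.__cause__, ConnectionResetError)
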
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
 from libp2p.network.connection.net_connection_interface import INetConn
 from libp2p.network.stream.net_stream import NetStream
 from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
+from libp2p.stream_muxer.exceptions import MuxedConnUnavailable

 if TYPE_CHECKING:
     from libp2p.network.swarm import Swarm  # noqa: F401

@@ -34,17 +35,28 @@ class SwarmConn(INetConn):
         if self.event_closed.is_set():
             return
         self.event_closed.set()
         self.swarm.remove_conn(self)

         await self.conn.close()

+        # This is just for cleaning up state. The connection has already been closed.
+        # We *could* optimize this but it really isn't worth it.
+        for stream in self.streams:
+            await stream.reset()
+        # Schedule `self._notify_disconnected` to make it execute after `close` is finished.
+        asyncio.ensure_future(self._notify_disconnected())
+
         for task in self._tasks:
             task.cancel()

-        # TODO: Reset streams for local.
-        # TODO: Notify closed.
-
     async def _handle_new_streams(self) -> None:
-        # TODO: Break the loop when anything wrong in the connection.
         while True:
-            stream = await self.conn.accept_stream()
+            try:
+                stream = await self.conn.accept_stream()
+            except MuxedConnUnavailable:
+                # If there is anything wrong in the MuxedConn,
+                # we should break the loop and close the connection.
+                break
+            # Asynchronously handle the accepted stream, to avoid blocking the next stream.
             await self.run_task(self._handle_muxed_stream(stream))

@@ -57,11 +69,16 @@ class SwarmConn(INetConn):

     async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
         net_stream = NetStream(muxed_stream)
         self.streams.add(net_stream)
         # Call notifiers since event occurred
         for notifee in self.swarm.notifees:
             await notifee.opened_stream(self.swarm, net_stream)
         return net_stream

+    async def _notify_disconnected(self) -> None:
+        for notifee in self.swarm.notifees:
+            await notifee.disconnected(self.swarm, self.conn)
+
     async def start(self) -> None:
         await self.run_task(self._handle_new_streams())

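`close` above schedules `_notify_disconnected` with `asyncio.ensure_future` rather than awaiting it, so notifees observe the disconnect only after `close` has fully returned. A self-contained sketch of that ordering (toy class, not the actual `SwarmConn`):

import asyncio

events = []

class Conn:
    async def _notify_disconnected(self):
        events.append("disconnected")

    async def close(self):
        events.append("closing")
        # Only schedules the coroutine; it runs once the current task yields.
        asyncio.ensure_future(self._notify_disconnected())
        events.append("closed")

async def main():
    await Conn().close()
    assert events == ["closing", "closed"]  # notification not delivered yet
    await asyncio.sleep(0)  # yield to the event loop
    assert events == ["closing", "closed", "disconnected"]

asyncio.run(main())
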
@@ -272,13 +272,18 @@ class Swarm(INetwork):
     async def close_peer(self, peer_id: ID) -> None:
         if peer_id not in self.connections:
             return
+        # TODO: Should be changed to close multiple connections,
+        #   if we have several connections per peer in the future.
         connection = self.connections[peer_id]
         del self.connections[peer_id]
         await connection.close()

         logger.debug("successfully close the connection to peer %s", peer_id)

     async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn:
+        """
+        Add an `IMuxedConn` to `Swarm` as a `SwarmConn`, notify "connected",
+        and start to monitor the connection for its new streams and disconnection.
+        """
         swarm_conn = SwarmConn(muxed_conn, self)
         # Store muxed_conn with peer id
         self.connections[muxed_conn.peer_id] = swarm_conn

@@ -288,3 +293,14 @@ class Swarm(INetwork):
             await notifee.connected(self, muxed_conn)
         await swarm_conn.start()
         return swarm_conn
+
+    def remove_conn(self, swarm_conn: SwarmConn) -> None:
+        """
+        Simply remove the connection from Swarm's records, without closing the connection.
+        """
+        peer_id = swarm_conn.conn.peer_id
+        if peer_id not in self.connections:
+            return
+        # TODO: Should be changed to remove the exact connection,
+        #   if we have several connections per peer in the future.
+        del self.connections[peer_id]

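The two operations differ only in whether the underlying connection is shut down: `close_peer` closes and forgets, `remove_conn` merely forgets. A toy sketch of that bookkeeping (plain dict, stand-in classes, not the real `Swarm`):

class ToySwarm:
    def __init__(self):
        self.connections = {}

    async def close_peer(self, peer_id):
        # Close the connection *and* drop it from the records.
        conn = self.connections.pop(peer_id, None)
        if conn is not None:
            await conn.close()

    def remove_conn(self, conn):
        # Drop the record only; the connection object itself is untouched.
        self.connections.pop(conn.peer_id, None)
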
@@ -5,7 +5,7 @@ class MuxedConnError(BaseLibp2pError):
     pass


-class MuxedConnShutdown(MuxedConnError):
+class MuxedConnUnavailable(MuxedConnError):
     pass

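Since `MuxedConnUnavailable` subclasses `MuxedConnError`, handlers written against the base class keep working after the rename. A short sketch of the hierarchy as it stands after this diff:

class BaseLibp2pError(Exception): ...
class MuxedConnError(BaseLibp2pError): ...
class MuxedConnUnavailable(MuxedConnError): ...
class MplexUnavailable(MuxedConnUnavailable): ...

try:
    raise MplexUnavailable("Mplex is shutting down")
except MuxedConnUnavailable:
    pass  # muxer-agnostic code can catch the generic class
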
@@ -1,6 +1,6 @@
 from libp2p.stream_muxer.exceptions import (
     MuxedConnError,
-    MuxedConnShutdown,
+    MuxedConnUnavailable,
     MuxedStreamClosed,
     MuxedStreamEOF,
     MuxedStreamReset,

@@ -11,7 +11,7 @@ class MplexError(MuxedConnError):
     pass


-class MplexShutdown(MuxedConnShutdown):
+class MplexUnavailable(MuxedConnUnavailable):
     pass

@@ -1,9 +1,10 @@
 import asyncio
 from typing import Any  # noqa: F401
-from typing import Dict, List, Optional, Tuple
+from typing import Awaitable, Dict, List, Optional, Tuple

 from libp2p.exceptions import ParseError
 from libp2p.io.exceptions import IncompleteReadError
+from libp2p.network.connection.exceptions import RawConnError
 from libp2p.peer.id import ID
 from libp2p.security.secure_conn_interface import ISecureConn
 from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream

@@ -17,6 +18,7 @@ from libp2p.utils import (

 from .constants import HeaderTags
 from .datastructures import StreamID
+from .exceptions import MplexUnavailable
 from .mplex_stream import MplexStream

 MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")

@@ -36,7 +38,8 @@ class Mplex(IMuxedConn):
     streams: Dict[StreamID, MplexStream]
     streams_lock: asyncio.Lock
     new_stream_queue: "asyncio.Queue[IMuxedStream]"
-    shutdown: asyncio.Event
+    event_shutting_down: asyncio.Event
+    event_closed: asyncio.Event

     _tasks: List["asyncio.Future[Any]"]

@@ -60,7 +63,8 @@ class Mplex(IMuxedConn):
         self.streams = {}
         self.streams_lock = asyncio.Lock()
         self.new_stream_queue = asyncio.Queue()
-        self.shutdown = asyncio.Event()
+        self.event_shutting_down = asyncio.Event()
+        self.event_closed = asyncio.Event()

         self._tasks = []

@@ -75,16 +79,20 @@ class Mplex(IMuxedConn):
         """
         close the stream muxer and underlying secured connection
         """
-        for task in self._tasks:
-            task.cancel()
+        if self.event_shutting_down.is_set():
+            return
+        # Set the `event_shutting_down`, to allow graceful shutdown.
+        self.event_shutting_down.set()
         await self.secured_conn.close()
+        # Block until `event_closed` is finally set.
+        await self.event_closed.wait()

     def is_closed(self) -> bool:
         """
         check connection is fully closed
         :return: true if successful
         """
-        raise NotImplementedError()
+        return self.event_closed.is_set()

     def _get_next_channel_id(self) -> int:
         """

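The rewritten `close` relies on a two-event handshake: `event_shutting_down` asks the reader loop to stop, and `event_closed` reports that cleanup has finished. A self-contained sketch of the pattern (toy class, not the real `Mplex`):

import asyncio

class Muxer:
    def __init__(self):
        self.event_shutting_down = asyncio.Event()
        self.event_closed = asyncio.Event()

    async def close(self):
        if self.event_shutting_down.is_set():
            return  # idempotent: a second close() is a no-op
        self.event_shutting_down.set()
        await self.event_closed.wait()  # block until cleanup is done

    def is_closed(self):
        return self.event_closed.is_set()

    async def reader_loop(self):
        await self.event_shutting_down.wait()
        # ... reset streams, drop state ...
        self.event_closed.set()

async def main():
    m = Muxer()
    loop_task = asyncio.ensure_future(m.reader_loop())
    await m.close()
    assert m.is_closed()
    await loop_task

asyncio.run(main())
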
@@ -114,11 +122,29 @@ class Mplex(IMuxedConn):
         await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
         return stream

+    async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
+        task_coro = asyncio.ensure_future(coro)
+        task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
+        task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
+        done, pending = await asyncio.wait(
+            [task_coro, task_wait_closed, task_wait_shutting_down],
+            return_when=asyncio.FIRST_COMPLETED,
+        )
+        for fut in pending:
+            fut.cancel()
+        if task_wait_closed in done:
+            raise MplexUnavailable("Mplex is closed")
+        if task_wait_shutting_down in done:
+            raise MplexUnavailable("Mplex is shutting down")
+        return task_coro.result()
+
     async def accept_stream(self) -> IMuxedStream:
         """
         accepts a muxed stream opened by the other end
         """
-        return await self.new_stream_queue.get()
+        return await self._wait_until_shutting_down_or_closed(
+            self.new_stream_queue.get()
+        )

     async def send_message(
         self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID

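`_wait_until_shutting_down_or_closed` races a blocking coroutine against the two shutdown events, so a `close()` turns a would-be hang (for example, a `get()` on an empty queue) into `MplexUnavailable`. A stand-alone approximation of the same race with a single event:

import asyncio

class Unavailable(Exception):
    pass

async def wait_or_shutdown(coro, shutdown_event):
    task = asyncio.ensure_future(coro)
    waiter = asyncio.ensure_future(shutdown_event.wait())
    done, pending = await asyncio.wait(
        [task, waiter], return_when=asyncio.FIRST_COMPLETED
    )
    for fut in pending:
        fut.cancel()
    if waiter in done:
        raise Unavailable("shutting down")
    return task.result()

async def main():
    queue: "asyncio.Queue[int]" = asyncio.Queue()
    shutdown = asyncio.Event()
    # Nothing is ever put on the queue; trigger shutdown instead.
    asyncio.get_running_loop().call_later(0.01, shutdown.set)
    try:
        await wait_or_shutdown(queue.get(), shutdown)
    except Unavailable:
        print("unblocked by shutdown instead of hanging forever")

asyncio.run(main())
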
@@ -137,7 +163,9 @@ class Mplex(IMuxedConn):

         _bytes = header + encode_varint_prefixed(data)

-        return await self.write_to_stream(_bytes)
+        return await self._wait_until_shutting_down_or_closed(
+            self.write_to_stream(_bytes)
+        )

     async def write_to_stream(self, _bytes: bytes) -> int:
         """

@@ -152,55 +180,112 @@ class Mplex(IMuxedConn):
         """
         Read a message off of the secured connection and add it to the corresponding message buffer
         """
-        # TODO Deal with other types of messages using flag (currently _)
-
         while True:
-            channel_id, flag, message = await self.read_message()
-            if channel_id is not None and flag is not None and message is not None:
-                stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
-                is_stream_id_seen: bool
-                stream: MplexStream
-                async with self.streams_lock:
-                    is_stream_id_seen = stream_id in self.streams
-                    if is_stream_id_seen:
-                        stream = self.streams[stream_id]
-                # TODO: Handle more tags, and refactor `HeaderTags`
-                if flag == HeaderTags.NewStream.value:
-                    if is_stream_id_seen:
-                        # `NewStream` for the same id is received twice...
-                        # TODO: Shutdown
-                        pass
-                    mplex_stream = await self._initialize_stream(
-                        stream_id, message.decode()
-                    )
-                    # TODO: Check if `self` is shutdown.
-                    # Other consequent stream message should wait until the stream get accepted
-                    await self.new_stream_queue.put(mplex_stream)
-                elif flag in (
-                    HeaderTags.MessageInitiator.value,
-                    HeaderTags.MessageReceiver.value,
-                ):
-                    if not is_stream_id_seen:
-                        # We receive a message of the stream `stream_id` which is not accepted
-                        #   before. It is abnormal. Possibly disconnect?
-                        # TODO: Warn and emit logs about this.
-                        continue
-                    async with stream.close_lock:
-                        if stream.event_remote_closed.is_set():
-                            # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)"  # noqa: E501
-                            continue
-                    await stream.incoming_data.put(message)
-                elif flag in (
-                    HeaderTags.CloseInitiator.value,
-                    HeaderTags.CloseReceiver.value,
-                ):
-                    if not is_stream_id_seen:
-                        continue
+            try:
+                await self._handle_incoming_message()
+            except MplexUnavailable:
+                break
+            # Force context switch
+            await asyncio.sleep(0)
+        # If we enter here, it means this connection is shutting down.
+        # We should clean things up.
+        await self._cleanup()

+    async def read_message(self) -> Tuple[int, int, bytes]:
+        """
+        Read a single message off of the secured connection
+        :return: stream_id, flag, message contents
+        """
+        # FIXME: No timeout is used in Go implementation.
+        try:
+            header = await decode_uvarint_from_stream(self.secured_conn)
+            message = await asyncio.wait_for(
+                read_varint_prefixed_bytes(self.secured_conn), timeout=5
+            )
+        except (ParseError, RawConnError, IncompleteReadError) as error:
+            raise MplexUnavailable(
+                "failed to read messages correctly from the underlying connection"
+            ) from error
+        except asyncio.TimeoutError as error:
+            raise MplexUnavailable(
+                "failed to read more message body within the timeout"
+            ) from error
+
+        flag = header & 0x07
+        channel_id = header >> 3
+
+        return channel_id, flag, message
+
+    async def _handle_incoming_message(self) -> None:
+        """
+        Read and handle a new incoming message.
+        :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
+        """
+        channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
+            self.read_message()
+        )
+        stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
+
+        if flag == HeaderTags.NewStream.value:
+            await self._handle_new_stream(stream_id, message)
+        elif flag in (
+            HeaderTags.MessageInitiator.value,
+            HeaderTags.MessageReceiver.value,
+        ):
+            await self._handle_message(stream_id, message)
+        elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
+            await self._handle_close(stream_id)
+        elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
+            await self._handle_reset(stream_id)
+        else:
+            # Received a message with an unknown flag
+            # TODO: logging
+            async with self.streams_lock:
+                if stream_id in self.streams:
+                    stream = self.streams[stream_id]
+                    await stream.reset()
+
+    async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
+        async with self.streams_lock:
+            if stream_id in self.streams:
+                # `NewStream` for the same id is received twice...
+                raise MplexUnavailable(
+                    f"received NewStream message for existing stream: {stream_id}"
+                )
+        mplex_stream = await self._initialize_stream(stream_id, message.decode())
+        await self._wait_until_shutting_down_or_closed(
+            self.new_stream_queue.put(mplex_stream)
+        )
+
+    async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
+        async with self.streams_lock:
+            if stream_id not in self.streams:
+                # We receive a message of the stream `stream_id` which is not accepted
+                #   before. It is abnormal. Possibly disconnect?
+                # TODO: Warn and emit logs about this.
+                return
+            stream = self.streams[stream_id]
+        async with stream.close_lock:
+            if stream.event_remote_closed.is_set():
+                # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)"  # noqa: E501
+                return
+        await self._wait_until_shutting_down_or_closed(
+            stream.incoming_data.put(message)
+        )

@@ -210,16 +295,16 @@ class Mplex(IMuxedConn):
-                    # NOTE: If remote is already closed, then return: Technically a bug
-                    #   on the other side. We should consider killing the connection.
-                    async with stream.close_lock:
-                        if stream.event_remote_closed.is_set():
-                            continue
-                    is_local_closed: bool
-                    async with stream.close_lock:
-                        stream.event_remote_closed.set()
-                        is_local_closed = stream.event_local_closed.is_set()
-                    if is_local_closed:
-                        async with self.streams_lock:
-                            del self.streams[stream_id]
-                elif flag in (
-                    HeaderTags.ResetInitiator.value,
-                    HeaderTags.ResetReceiver.value,
-                ):
-                    if not is_stream_id_seen:
-                        # This is *ok*. We forget the stream on reset.
-                        continue
+    async def _handle_close(self, stream_id: StreamID) -> None:
+        async with self.streams_lock:
+            if stream_id not in self.streams:
+                # Ignore unmatched messages for now.
+                return
+            stream = self.streams[stream_id]
+        # NOTE: If remote is already closed, then return: Technically a bug
+        #   on the other side. We should consider killing the connection.
+        async with stream.close_lock:
+            if stream.event_remote_closed.is_set():
+                return
+        is_local_closed: bool
+        async with stream.close_lock:
+            stream.event_remote_closed.set()
+            is_local_closed = stream.event_local_closed.is_set()
+        # If local is also closed, both sides are done; remove the stream entry.
+        if is_local_closed:
+            async with self.streams_lock:
+                del self.streams[stream_id]
+
+    async def _handle_reset(self, stream_id: StreamID) -> None:
+        async with self.streams_lock:
+            if stream_id not in self.streams:
+                # This is *ok*. We forget the stream on reset.
+                return
+            stream = self.streams[stream_id]
+
+        async with stream.close_lock:
+            if not stream.event_remote_closed.is_set():
+                stream.event_reset.set()
+                stream.event_remote_closed.set()
+            if not stream.event_local_closed.is_set():
+                stream.event_local_closed.set()
+        async with self.streams_lock:
+            del self.streams[stream_id]

@@ -228,36 +313,16 @@ class Mplex(IMuxedConn):
-                    async with stream.close_lock:
-                        if not stream.event_remote_closed.is_set():
-                            # TODO: Why? Only if remote is not closed before then reset.
-                            stream.event_reset.set()
-                            stream.event_remote_closed.set()
-                        if not stream.event_local_closed.is_set():
-                            stream.event_local_closed.set()
-                    async with self.streams_lock:
-                        del self.streams[stream_id]
-                else:
-                    # TODO: logging
-                    if is_stream_id_seen:
-                        await stream.reset()
-
-            # Force context switch
-            await asyncio.sleep(0)
-
-    async def read_message(self) -> Tuple[int, int, bytes]:
-        """
-        Read a single message off of the secured connection
-        :return: stream_id, flag, message contents
-        """
-
-        # FIXME: No timeout is used in Go implementation.
-        # Timeout is set to a relatively small value to alleviate wait time to exit
-        # loop in handle_incoming
-        try:
-            header = await decode_uvarint_from_stream(self.secured_conn)
-        except ParseError:
-            return None, None, None
-        try:
-            message = await asyncio.wait_for(
-                read_varint_prefixed_bytes(self.secured_conn), timeout=5
-            )
-        except (ParseError, IncompleteReadError, asyncio.TimeoutError):
-            # TODO: Investigate what we should do if time is out.
-            return None, None, None
-
-        flag = header & 0x07
-        channel_id = header >> 3
-
-        return channel_id, flag, message
+    async def _cleanup(self) -> None:
+        if not self.event_shutting_down.is_set():
+            self.event_shutting_down.set()
+        async with self.streams_lock:
+            for stream in self.streams.values():
+                async with stream.close_lock:
+                    if not stream.event_remote_closed.is_set():
+                        stream.event_remote_closed.set()
+                        stream.event_reset.set()
+                        stream.event_local_closed.set()
+            self.streams = None
+        self.event_closed.set()

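All Mplex messages share one uvarint header that packs the stream's channel id together with a 3-bit flag (`header = channel_id << 3 | flag`), which is what `flag = header & 0x07` and `channel_id = header >> 3` undo in `read_message` above. A quick sketch of that bit math:

def encode_header(channel_id: int, flag: int) -> int:
    return channel_id << 3 | flag

def decode_header(header: int):
    return header >> 3, header & 0x07  # channel_id, flag

header = encode_header(channel_id=5, flag=2)
assert decode_header(header) == (5, 2)
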
@@ -204,6 +204,10 @@ class MplexStream(IMuxedStream):
             self.event_remote_closed.set()

+        async with self.mplex_conn.streams_lock:
+            if (
+                self.mplex_conn.streams is not None
+                and self.stream_id in self.mplex_conn.streams
+            ):
+                del self.mplex_conn.streams[self.stream_id]

     # TODO deadline not in use

@@ -41,14 +41,8 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
         if shift > SHIFT_64_BIT_MAX:
             raise ParseError("TODO: better exception msg: Integer is too large...")

-        byte = await reader.read(1)
-
-        try:
-            value = byte[0]
-        except IndexError:
-            raise ParseError(
-                "Unexpected end of stream while parsing LEB128 encoded integer"
-            )
+        byte = await read_exactly(reader, 1)
+        value = byte[0]

         res += (value & LOW_MASK) << shift

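For reference, this is the LEB128 ("uvarint") decoding that `decode_uvarint_from_stream` performs, sketched over an in-memory byte string instead of a stream reader:

LOW_MASK = 0x7F   # payload bits of each byte
HIGH_BIT = 0x80   # continuation bit

def decode_uvarint(data: bytes) -> int:
    res = 0
    shift = 0
    for byte in data:
        res += (byte & LOW_MASK) << shift
        if not byte & HIGH_BIT:
            break
        shift += 7
    return res

# 300 = 0b1_0010_1100: the low 7 bits (0101100) go out first with the
# continuation bit set, followed by the remaining bits (10).
assert decode_uvarint(bytes([0xAC, 0x02])) == 300
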
@@ -6,15 +6,14 @@ import factory
 from libp2p import generate_new_rsa_identity, initialize_default_swarm
 from libp2p.crypto.keys import KeyPair
 from libp2p.host.basic_host import BasicHost
-from libp2p.host.host_interface import IHost
 from libp2p.network.stream.net_stream_interface import INetStream
+from libp2p.network.swarm import Swarm
 from libp2p.pubsub.floodsub import FloodSub
 from libp2p.pubsub.gossipsub import GossipSub
 from libp2p.pubsub.pubsub import Pubsub
 from libp2p.security.base_transport import BaseSecureTransport
 from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
 import libp2p.security.secio.transport as secio
-from libp2p.stream_muxer.mplex.mplex import Mplex
 from libp2p.typing import TProtocol
 from tests.configs import LISTEN_MADDR
 from tests.pubsub.configs import (

@@ -22,7 +21,7 @@ from tests.pubsub.configs import (
     GOSSIPSUB_PARAMS,
     GOSSIPSUB_PROTOCOL_ID,
 )
-from tests.utils import connect
+from tests.utils import connect, connect_swarm


 def security_transport_factory(

@@ -34,12 +33,31 @@ def security_transport_factory(
     return {secio.ID: secio.Transport(key_pair)}


-def swarm_factory(is_secure: bool):
+def SwarmFactory(is_secure: bool) -> Swarm:
     key_pair = generate_new_rsa_identity()
-    sec_opt = security_transport_factory(is_secure, key_pair)
+    sec_opt = security_transport_factory(False, key_pair)
     return initialize_default_swarm(key_pair, sec_opt=sec_opt)


+class ListeningSwarmFactory(factory.Factory):
+    class Meta:
+        model = Swarm
+
+    @classmethod
+    async def create_and_listen(cls, is_secure: bool) -> Swarm:
+        swarm = SwarmFactory(is_secure)
+        await swarm.listen(LISTEN_MADDR)
+        return swarm
+
+    @classmethod
+    async def create_batch_and_listen(
+        cls, is_secure: bool, number: int
+    ) -> Tuple[Swarm, ...]:
+        return await asyncio.gather(
+            *[cls.create_and_listen(is_secure) for _ in range(number)]
+        )
+
+
 class HostFactory(factory.Factory):
     class Meta:
         model = BasicHost

@@ -47,13 +65,19 @@ class HostFactory(factory.Factory):
     class Params:
         is_secure = False

-    network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
+    network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure))

     @classmethod
-    async def create_and_listen(cls) -> IHost:
-        host = cls()
-        await host.get_network().listen(LISTEN_MADDR)
-        return host
+    async def create_and_listen(cls, is_secure: bool) -> BasicHost:
+        swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 1)
+        return BasicHost(swarms[0])
+
+    @classmethod
+    async def create_batch_and_listen(
+        cls, is_secure: bool, number: int
+    ) -> Tuple[BasicHost, ...]:
+        swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, number)
+        return tuple(BasicHost(swarm) for swarm in swarms)


 class FloodsubFactory(factory.Factory):

@@ -87,24 +111,33 @@ class PubsubFactory(factory.Factory):
     cache_size = None


-async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
+async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
+    swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2)
+    await connect_swarm(swarms[0], swarms[1])
+    return swarms[0], swarms[1]
+
+
+async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
     hosts = await asyncio.gather(
-        *[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
+        *[
+            HostFactory.create_and_listen(is_secure),
+            HostFactory.create_and_listen(is_secure),
+        ]
     )
     await connect(hosts[0], hosts[1])
     return hosts[0], hosts[1]


-async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
-    host_0, host_1 = await host_pair_factory()
-    mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
-    mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
-    return mplex_conn_0, host_0, mplex_conn_1, host_1
+# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
+#     host_0, host_1 = await host_pair_factory()
+#     mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
+#     mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
+#     return mplex_conn_0, host_0, mplex_conn_1, host_1


-async def net_stream_pair_factory() -> Tuple[
-    INetStream, BasicHost, INetStream, BasicHost
-]:
+async def net_stream_pair_factory(
+    is_secure: bool
+) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
     protocol_id = "/example/id/1"

     stream_1: INetStream

@@ -114,7 +147,7 @@ async def net_stream_pair_factory() -> Tuple[
         nonlocal stream_1
         stream_1 = stream

-    host_0, host_1 = await host_pair_factory()
+    host_0, host_1 = await host_pair_factory(is_secure)
     host_1.set_stream_handler(protocol_id, handler)

     stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])

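A hypothetical test using the factories above (illustrative only, not part of this diff):

import asyncio

import pytest

from tests.factories import swarm_pair_factory

@pytest.mark.asyncio
async def test_example():
    # Two listening swarms, already connected to each other.
    swarm_0, swarm_1 = await swarm_pair_factory(False)
    assert swarm_1.get_peer_id() in swarm_0.connections
    await asyncio.gather(swarm_0.close(), swarm_1.close())
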
@@ -1,3 +1,5 @@
+import asyncio
+
 import pytest

 from .utils import connect

@@ -21,4 +23,5 @@ async def test_connect(hosts, p2pds):
     # Test: `disconnect` from Go
     await p2pd.control.disconnect(host.get_id())
-    # FIXME: Failed to handle disconnect
-    # assert len(host.get_network().connections) == 0
+    await asyncio.sleep(0.01)
+    assert len(host.get_network().connections) == 0

@@ -2,13 +2,22 @@ import asyncio

 import pytest

-from tests.factories import net_stream_pair_factory
+from tests.factories import net_stream_pair_factory, swarm_pair_factory


 @pytest.fixture
-async def net_stream_pair():
-    stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory()
+async def net_stream_pair(is_host_secure):
+    stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
     try:
         yield stream_0, stream_1
     finally:
         await asyncio.gather(*[host_0.close(), host_1.close()])
+
+
+@pytest.fixture
+async def swarm_pair(is_host_secure):
+    swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
+    try:
+        yield swarm_0, swarm_1
+    finally:
+        await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

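The `is_host_secure` fixture referenced above is defined elsewhere in the test suite; one plausible definition, shown here only as an assumption, is:

import pytest

@pytest.fixture
def is_host_secure():
    # Could also be parametrized over both transports,
    # e.g. @pytest.fixture(params=[True, False]).
    return False
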
tests/network/test_swarm.py  (new file, 93 lines)

@@ -0,0 +1,93 @@
import asyncio

import pytest

from libp2p.network.exceptions import SwarmException
from tests.factories import ListeningSwarmFactory
from tests.utils import connect_swarm


@pytest.mark.asyncio
async def test_swarm_dial_peer(is_host_secure):
    swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3)
    # Test: No addr found.
    with pytest.raises(SwarmException):
        await swarms[0].dial_peer(swarms[1].get_peer_id())

    # Test: len(addr) in the peerstore is 0.
    swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
    with pytest.raises(SwarmException):
        await swarms[0].dial_peer(swarms[1].get_peer_id())

    # Test: Succeed if addrs of the peer_id are present in the peerstore.
    addrs = tuple(
        addr
        for transport in swarms[1].listeners.values()
        for addr in transport.get_addrs()
    )
    swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
    await swarms[0].dial_peer(swarms[1].get_peer_id())
    assert swarms[0].get_peer_id() in swarms[1].connections
    assert swarms[1].get_peer_id() in swarms[0].connections

    # Test: Reuse connections when we already have ones with a peer.
    conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
    conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
    assert conn is conn_to_1

    # Clean up
    await asyncio.gather(*[swarm.close() for swarm in swarms])


@pytest.mark.asyncio
async def test_swarm_close_peer(is_host_secure):
    swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3)
    # 0 <> 1 <> 2
    await connect_swarm(swarms[0], swarms[1])
    await connect_swarm(swarms[1], swarms[2])

    # peer 1 closes peer 0
    await swarms[1].close_peer(swarms[0].get_peer_id())
    await asyncio.sleep(0.01)
    # 0  1 <> 2
    assert len(swarms[0].connections) == 0
    assert (
        len(swarms[1].connections) == 1
        and swarms[2].get_peer_id() in swarms[1].connections
    )

    # peer 1 is closed by peer 2
    await swarms[2].close_peer(swarms[1].get_peer_id())
    await asyncio.sleep(0.01)
    # 0  1  2
    assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0

    await connect_swarm(swarms[0], swarms[1])
    # 0 <> 1  2
    assert (
        len(swarms[0].connections) == 1
        and swarms[1].get_peer_id() in swarms[0].connections
    )
    assert (
        len(swarms[1].connections) == 1
        and swarms[0].get_peer_id() in swarms[1].connections
    )
    # peer 0 closes peer 1
    await swarms[0].close_peer(swarms[1].get_peer_id())
    await asyncio.sleep(0.01)
    # 0  1  2
    assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0

    # Clean up
    await asyncio.gather(*[swarm.close() for swarm in swarms])


@pytest.mark.asyncio
async def test_swarm_remove_conn(swarm_pair):
    swarm_0, swarm_1 = swarm_pair
    conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
    swarm_0.remove_conn(conn_0)
    assert swarm_1.get_peer_id() not in swarm_0.connections
    # Test: Remove twice. There should not be errors.
    swarm_0.remove_conn(conn_0)
    assert swarm_1.get_peer_id() not in swarm_0.connections

@@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr
 from tests.constants import MAX_READ_LEN


+async def connect_swarm(swarm_0, swarm_1):
+    peer_id = swarm_1.get_peer_id()
+    addrs = tuple(
+        addr
+        for transport in swarm_1.listeners.values()
+        for addr in transport.get_addrs()
+    )
+    swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
+    await swarm_0.dial_peer(peer_id)
+    assert swarm_0.get_peer_id() in swarm_1.connections
+    assert swarm_1.get_peer_id() in swarm_0.connections
+
+
 async def connect(node1, node2):
     """
     Connect node1 to node2