Merge pull request #304 from mhchia/fix/detection-of-close

Detect closed `Mplex`
This commit is contained in:
Kevin Mai-Husan Chia 2019-09-21 18:28:16 +08:00 committed by GitHub
commit 4a838033ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 392 additions and 143 deletions

View File

@ -30,6 +30,10 @@ logger.setLevel(logging.DEBUG)
class BasicHost(IHost): class BasicHost(IHost):
"""
BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation
on a stream with multistream-select right after a stream is initialized.
"""
_network: INetwork _network: INetwork
_router: KadmeliaPeerRouter _router: KadmeliaPeerRouter
@ -38,7 +42,6 @@ class BasicHost(IHost):
multiselect: Multiselect multiselect: Multiselect
multiselect_client: MultiselectClient multiselect_client: MultiselectClient
# default options constructor
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None: def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
self._network = network self._network = network
self._network.set_stream_handler(self._swarm_stream_handler) self._network.set_stream_handler(self._swarm_stream_handler)
@ -76,6 +79,7 @@ class BasicHost(IHost):
""" """
:return: all the multiaddr addresses this host is listening to :return: all the multiaddr addresses this host is listening to
""" """
# TODO: We don't need "/p2p/{peer_id}" postfix actually.
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty())) p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
addrs: List[multiaddr.Multiaddr] = [] addrs: List[multiaddr.Multiaddr] = []
@ -94,8 +98,6 @@ class BasicHost(IHost):
""" """
self.multiselect.add_handler(protocol_id, stream_handler) self.multiselect.add_handler(protocol_id, stream_handler)
# `protocol_ids` can be a list of `protocol_id`
# stream will decide which `protocol_id` to run on
async def new_stream( async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol] self, peer_id: ID, protocol_ids: Sequence[TProtocol]
) -> INetStream: ) -> INetStream:

View File

@ -37,8 +37,8 @@ class RawConnection(IRawConnection):
async with self._drain_lock: async with self._drain_lock:
try: try:
await self.writer.drain() await self.writer.drain()
except ConnectionResetError: except ConnectionResetError as error:
raise RawConnError() raise RawConnError(error)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
""" """

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.stream.net_stream import NetStream from libp2p.network.stream.net_stream import NetStream
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
if TYPE_CHECKING: if TYPE_CHECKING:
from libp2p.network.swarm import Swarm # noqa: F401 from libp2p.network.swarm import Swarm # noqa: F401
@ -34,17 +35,28 @@ class SwarmConn(INetConn):
if self.event_closed.is_set(): if self.event_closed.is_set():
return return
self.event_closed.set() self.event_closed.set()
self.swarm.remove_conn(self)
await self.conn.close() await self.conn.close()
# This is just for cleaning up state. The connection has already been closed.
# We *could* optimize this but it really isn't worth it.
for stream in self.streams:
await stream.reset()
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
asyncio.ensure_future(self._notify_disconnected())
for task in self._tasks: for task in self._tasks:
task.cancel() task.cancel()
# TODO: Reset streams for local.
# TODO: Notify closed.
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
# TODO: Break the loop when anything wrong in the connection.
while True: while True:
try:
stream = await self.conn.accept_stream() stream = await self.conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
# we should break the loop and close the connection.
break
# Asynchronously handle the accepted stream, to avoid blocking the next stream. # Asynchronously handle the accepted stream, to avoid blocking the next stream.
await self.run_task(self._handle_muxed_stream(stream)) await self.run_task(self._handle_muxed_stream(stream))
@ -57,11 +69,16 @@ class SwarmConn(INetConn):
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)
self.streams.add(net_stream)
# Call notifiers since event occurred # Call notifiers since event occurred
for notifee in self.swarm.notifees: for notifee in self.swarm.notifees:
await notifee.opened_stream(self.swarm, net_stream) await notifee.opened_stream(self.swarm, net_stream)
return net_stream return net_stream
async def _notify_disconnected(self) -> None:
for notifee in self.swarm.notifees:
await notifee.disconnected(self.swarm, self.conn)
async def start(self) -> None: async def start(self) -> None:
await self.run_task(self._handle_new_streams()) await self.run_task(self._handle_new_streams())

View File

@ -272,13 +272,18 @@ class Swarm(INetwork):
async def close_peer(self, peer_id: ID) -> None: async def close_peer(self, peer_id: ID) -> None:
if peer_id not in self.connections: if peer_id not in self.connections:
return return
# TODO: Should be changed to close multisple connections,
# if we have several connections per peer in the future.
connection = self.connections[peer_id] connection = self.connections[peer_id]
del self.connections[peer_id]
await connection.close() await connection.close()
logger.debug("successfully close the connection to peer %s", peer_id) logger.debug("successfully close the connection to peer %s", peer_id)
async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn: async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn:
"""
Add a `IMuxedConn` to `Swarm` as a `SwarmConn`, notify "connected",
and start to monitor the connection for its new streams and disconnection.
"""
swarm_conn = SwarmConn(muxed_conn, self) swarm_conn = SwarmConn(muxed_conn, self)
# Store muxed_conn with peer id # Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn self.connections[muxed_conn.peer_id] = swarm_conn
@ -288,3 +293,14 @@ class Swarm(INetwork):
await notifee.connected(self, muxed_conn) await notifee.connected(self, muxed_conn)
await swarm_conn.start() await swarm_conn.start()
return swarm_conn return swarm_conn
def remove_conn(self, swarm_conn: SwarmConn) -> None:
"""
Simply remove the connection from Swarm's records, without closing the connection.
"""
peer_id = swarm_conn.conn.peer_id
if peer_id not in self.connections:
return
# TODO: Should be changed to remove the exact connection,
# if we have several connections per peer in the future.
del self.connections[peer_id]

View File

@ -5,7 +5,7 @@ class MuxedConnError(BaseLibp2pError):
pass pass
class MuxedConnShutdown(MuxedConnError): class MuxedConnUnavailable(MuxedConnError):
pass pass

View File

@ -1,6 +1,6 @@
from libp2p.stream_muxer.exceptions import ( from libp2p.stream_muxer.exceptions import (
MuxedConnError, MuxedConnError,
MuxedConnShutdown, MuxedConnUnavailable,
MuxedStreamClosed, MuxedStreamClosed,
MuxedStreamEOF, MuxedStreamEOF,
MuxedStreamReset, MuxedStreamReset,
@ -11,7 +11,7 @@ class MplexError(MuxedConnError):
pass pass
class MplexShutdown(MuxedConnShutdown): class MplexUnavailable(MuxedConnUnavailable):
pass pass

View File

@ -1,9 +1,10 @@
import asyncio import asyncio
from typing import Any # noqa: F401 from typing import Any # noqa: F401
from typing import Dict, List, Optional, Tuple from typing import Awaitable, Dict, List, Optional, Tuple
from libp2p.exceptions import ParseError from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError from libp2p.io.exceptions import IncompleteReadError
from libp2p.network.connection.exceptions import RawConnError
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
@ -17,6 +18,7 @@ from libp2p.utils import (
from .constants import HeaderTags from .constants import HeaderTags
from .datastructures import StreamID from .datastructures import StreamID
from .exceptions import MplexUnavailable
from .mplex_stream import MplexStream from .mplex_stream import MplexStream
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
@ -36,7 +38,8 @@ class Mplex(IMuxedConn):
streams: Dict[StreamID, MplexStream] streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock streams_lock: asyncio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]" new_stream_queue: "asyncio.Queue[IMuxedStream]"
shutdown: asyncio.Event event_shutting_down: asyncio.Event
event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"] _tasks: List["asyncio.Future[Any]"]
@ -60,7 +63,8 @@ class Mplex(IMuxedConn):
self.streams = {} self.streams = {}
self.streams_lock = asyncio.Lock() self.streams_lock = asyncio.Lock()
self.new_stream_queue = asyncio.Queue() self.new_stream_queue = asyncio.Queue()
self.shutdown = asyncio.Event() self.event_shutting_down = asyncio.Event()
self.event_closed = asyncio.Event()
self._tasks = [] self._tasks = []
@ -75,16 +79,20 @@ class Mplex(IMuxedConn):
""" """
close the stream muxer and underlying secured connection close the stream muxer and underlying secured connection
""" """
for task in self._tasks: if self.event_shutting_down.is_set():
task.cancel() return
# Set the `event_shutting_down`, to allow graceful shutdown.
self.event_shutting_down.set()
await self.secured_conn.close() await self.secured_conn.close()
# Blocked until `close` is finally set.
await self.event_closed.wait()
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
check connection is fully closed check connection is fully closed
:return: true if successful :return: true if successful
""" """
raise NotImplementedError() return self.event_closed.is_set()
def _get_next_channel_id(self) -> int: def _get_next_channel_id(self) -> int:
""" """
@ -114,11 +122,29 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream return stream
async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_closed, task_wait_shutting_down],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_wait_closed in done:
raise MplexUnavailable("Mplex is closed")
if task_wait_shutting_down in done:
raise MplexUnavailable("Mplex is shutting down")
return task_coro.result()
async def accept_stream(self) -> IMuxedStream: async def accept_stream(self) -> IMuxedStream:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
return await self.new_stream_queue.get() return await self._wait_until_shutting_down_or_closed(
self.new_stream_queue.get()
)
async def send_message( async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -137,7 +163,9 @@ class Mplex(IMuxedConn):
_bytes = header + encode_varint_prefixed(data) _bytes = header + encode_varint_prefixed(data)
return await self.write_to_stream(_bytes) return await self._wait_until_shutting_down_or_closed(
self.write_to_stream(_bytes)
)
async def write_to_stream(self, _bytes: bytes) -> int: async def write_to_stream(self, _bytes: bytes) -> int:
""" """
@ -152,55 +180,112 @@ class Mplex(IMuxedConn):
""" """
Read a message off of the secured connection and add it to the corresponding message buffer Read a message off of the secured connection and add it to the corresponding message buffer
""" """
# TODO Deal with other types of messages using flag (currently _)
while True: while True:
channel_id, flag, message = await self.read_message() try:
if channel_id is not None and flag is not None and message is not None: await self._handle_incoming_message()
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) except MplexUnavailable:
is_stream_id_seen: bool break
stream: MplexStream # Force context switch
async with self.streams_lock: await asyncio.sleep(0)
is_stream_id_seen = stream_id in self.streams # If we enter here, it means this connection is shutting down.
if is_stream_id_seen: # We should clean things up.
stream = self.streams[stream_id] await self._cleanup()
# Other consequent stream message should wait until the stream get accepted
# TODO: Handle more tags, and refactor `HeaderTags` async def read_message(self) -> Tuple[int, int, bytes]:
if flag == HeaderTags.NewStream.value: """
if is_stream_id_seen: Read a single message off of the secured connection
# `NewStream` for the same id is received twice... :return: stream_id, flag, message contents
# TODO: Shutdown """
pass
mplex_stream = await self._initialize_stream( # FIXME: No timeout is used in Go implementation.
stream_id, message.decode() try:
header = await decode_uvarint_from_stream(self.secured_conn)
message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5
) )
# TODO: Check if `self` is shutdown. except (ParseError, RawConnError, IncompleteReadError) as error:
await self.new_stream_queue.put(mplex_stream) raise MplexUnavailable(
"failed to read messages correctly from the underlying connection"
) from error
except asyncio.TimeoutError as error:
raise MplexUnavailable(
"failed to read more message body within the timeout"
) from error
flag = header & 0x07
channel_id = header >> 3
return channel_id, flag, message
async def _handle_incoming_message(self) -> None:
"""
Read and handle a new incoming message.
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
"""
channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
self.read_message()
)
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
if flag == HeaderTags.NewStream.value:
await self._handle_new_stream(stream_id, message)
elif flag in ( elif flag in (
HeaderTags.MessageInitiator.value, HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value, HeaderTags.MessageReceiver.value,
): ):
if not is_stream_id_seen: await self._handle_message(stream_id, message)
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
await self._handle_close(stream_id)
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
await self._handle_reset(stream_id)
else:
# Receives messages with an unknown flag
# TODO: logging
async with self.streams_lock:
if stream_id in self.streams:
stream = self.streams[stream_id]
await stream.reset()
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
if stream_id in self.streams:
# `NewStream` for the same id is received twice...
raise MplexUnavailable(
f"received NewStream message for existing stream: {stream_id}"
)
mplex_stream = await self._initialize_stream(stream_id, message.decode())
await self._wait_until_shutting_down_or_closed(
self.new_stream_queue.put(mplex_stream)
)
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
if stream_id not in self.streams:
# We receive a message of the stream `stream_id` which is not accepted # We receive a message of the stream `stream_id` which is not accepted
# before. It is abnormal. Possibly disconnect? # before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this. # TODO: Warn and emit logs about this.
continue return
stream = self.streams[stream_id]
async with stream.close_lock: async with stream.close_lock:
if stream.event_remote_closed.is_set(): if stream.event_remote_closed.is_set():
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
continue return
await stream.incoming_data.put(message) await self._wait_until_shutting_down_or_closed(
elif flag in ( stream.incoming_data.put(message)
HeaderTags.CloseInitiator.value, )
HeaderTags.CloseReceiver.value,
): async def _handle_close(self, stream_id: StreamID) -> None:
if not is_stream_id_seen: async with self.streams_lock:
continue if stream_id not in self.streams:
# Ignore unmatched messages for now.
return
stream = self.streams[stream_id]
# NOTE: If remote is already closed, then return: Technically a bug # NOTE: If remote is already closed, then return: Technically a bug
# on the other side. We should consider killing the connection. # on the other side. We should consider killing the connection.
async with stream.close_lock: async with stream.close_lock:
if stream.event_remote_closed.is_set(): if stream.event_remote_closed.is_set():
continue return
is_local_closed: bool is_local_closed: bool
async with stream.close_lock: async with stream.close_lock:
stream.event_remote_closed.set() stream.event_remote_closed.set()
@ -210,16 +295,16 @@ class Mplex(IMuxedConn):
if is_local_closed: if is_local_closed:
async with self.streams_lock: async with self.streams_lock:
del self.streams[stream_id] del self.streams[stream_id]
elif flag in (
HeaderTags.ResetInitiator.value, async def _handle_reset(self, stream_id: StreamID) -> None:
HeaderTags.ResetReceiver.value, async with self.streams_lock:
): if stream_id not in self.streams:
if not is_stream_id_seen:
# This is *ok*. We forget the stream on reset. # This is *ok*. We forget the stream on reset.
continue return
stream = self.streams[stream_id]
async with stream.close_lock: async with stream.close_lock:
if not stream.event_remote_closed.is_set(): if not stream.event_remote_closed.is_set():
# TODO: Why? Only if remote is not closed before then reset.
stream.event_reset.set() stream.event_reset.set()
stream.event_remote_closed.set() stream.event_remote_closed.set()
@ -228,36 +313,16 @@ class Mplex(IMuxedConn):
stream.event_local_closed.set() stream.event_local_closed.set()
async with self.streams_lock: async with self.streams_lock:
del self.streams[stream_id] del self.streams[stream_id]
else:
# TODO: logging
if is_stream_id_seen:
await stream.reset()
# Force context switch async def _cleanup(self) -> None:
await asyncio.sleep(0) if not self.event_shutting_down.is_set():
self.event_shutting_down.set()
async def read_message(self) -> Tuple[int, int, bytes]: async with self.streams_lock:
""" for stream in self.streams.values():
Read a single message off of the secured connection async with stream.close_lock:
:return: stream_id, flag, message contents if not stream.event_remote_closed.is_set():
""" stream.event_remote_closed.set()
stream.event_reset.set()
# FIXME: No timeout is used in Go implementation. stream.event_local_closed.set()
# Timeout is set to a relatively small value to alleviate wait time to exit self.streams = None
# loop in handle_incoming self.event_closed.set()
try:
header = await decode_uvarint_from_stream(self.secured_conn)
except ParseError:
return None, None, None
try:
message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5
)
except (ParseError, IncompleteReadError, asyncio.TimeoutError):
# TODO: Investigate what we should do if time is out.
return None, None, None
flag = header & 0x07
channel_id = header >> 3
return channel_id, flag, message

View File

@ -204,6 +204,10 @@ class MplexStream(IMuxedStream):
self.event_remote_closed.set() self.event_remote_closed.set()
async with self.mplex_conn.streams_lock: async with self.mplex_conn.streams_lock:
if (
self.mplex_conn.streams is not None
and self.stream_id in self.mplex_conn.streams
):
del self.mplex_conn.streams[self.stream_id] del self.mplex_conn.streams[self.stream_id]
# TODO deadline not in use # TODO deadline not in use

View File

@ -41,14 +41,8 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
if shift > SHIFT_64_BIT_MAX: if shift > SHIFT_64_BIT_MAX:
raise ParseError("TODO: better exception msg: Integer is too large...") raise ParseError("TODO: better exception msg: Integer is too large...")
byte = await reader.read(1) byte = await read_exactly(reader, 1)
try:
value = byte[0] value = byte[0]
except IndexError:
raise ParseError(
"Unexpected end of stream while parsing LEB128 encoded integer"
)
res += (value & LOW_MASK) << shift res += (value & LOW_MASK) << shift

View File

@ -6,15 +6,14 @@ import factory
from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p import generate_new_rsa_identity, initialize_default_swarm
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import Mplex
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import ( from tests.pubsub.configs import (
@ -22,7 +21,7 @@ from tests.pubsub.configs import (
GOSSIPSUB_PARAMS, GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID,
) )
from tests.utils import connect from tests.utils import connect, connect_swarm
def security_transport_factory( def security_transport_factory(
@ -34,12 +33,31 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)} return {secio.ID: secio.Transport(key_pair)}
def swarm_factory(is_secure: bool): def SwarmFactory(is_secure: bool) -> Swarm:
key_pair = generate_new_rsa_identity() key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(is_secure, key_pair) sec_opt = security_transport_factory(False, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt) return initialize_default_swarm(key_pair, sec_opt=sec_opt)
class ListeningSwarmFactory(factory.Factory):
class Meta:
model = Swarm
@classmethod
async def create_and_listen(cls, is_secure: bool) -> Swarm:
swarm = SwarmFactory(is_secure)
await swarm.listen(LISTEN_MADDR)
return swarm
@classmethod
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[Swarm, ...]:
return await asyncio.gather(
*[cls.create_and_listen(is_secure) for _ in range(number)]
)
class HostFactory(factory.Factory): class HostFactory(factory.Factory):
class Meta: class Meta:
model = BasicHost model = BasicHost
@ -47,13 +65,19 @@ class HostFactory(factory.Factory):
class Params: class Params:
is_secure = False is_secure = False
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure))
@classmethod @classmethod
async def create_and_listen(cls) -> IHost: async def create_and_listen(cls, is_secure: bool) -> BasicHost:
host = cls() swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 1)
await host.get_network().listen(LISTEN_MADDR) return BasicHost(swarms[0])
return host
@classmethod
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, number)
return tuple(BasicHost(swarm) for swarm in range(swarms))
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):
@ -87,24 +111,33 @@ class PubsubFactory(factory.Factory):
cache_size = None cache_size = None
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]: async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2)
await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1]
async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
hosts = await asyncio.gather( hosts = await asyncio.gather(
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()] *[
HostFactory.create_and_listen(is_secure),
HostFactory.create_and_listen(is_secure),
]
) )
await connect(hosts[0], hosts[1]) await connect(hosts[0], hosts[1])
return hosts[0], hosts[1] return hosts[0], hosts[1]
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: # async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
host_0, host_1 = await host_pair_factory() # host_0, host_1 = await host_pair_factory()
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] # mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] # mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
return mplex_conn_0, host_0, mplex_conn_1, host_1 # return mplex_conn_0, host_0, mplex_conn_1, host_1
async def net_stream_pair_factory() -> Tuple[ async def net_stream_pair_factory(
INetStream, BasicHost, INetStream, BasicHost is_secure: bool
]: ) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
protocol_id = "/example/id/1" protocol_id = "/example/id/1"
stream_1: INetStream stream_1: INetStream
@ -114,7 +147,7 @@ async def net_stream_pair_factory() -> Tuple[
nonlocal stream_1 nonlocal stream_1
stream_1 = stream stream_1 = stream
host_0, host_1 = await host_pair_factory() host_0, host_1 = await host_pair_factory(is_secure)
host_1.set_stream_handler(protocol_id, handler) host_1.set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])

View File

@ -1,3 +1,5 @@
import asyncio
import pytest import pytest
from .utils import connect from .utils import connect
@ -21,4 +23,5 @@ async def test_connect(hosts, p2pds):
# Test: `disconnect` from Go # Test: `disconnect` from Go
await p2pd.control.disconnect(host.get_id()) await p2pd.control.disconnect(host.get_id())
# FIXME: Failed to handle disconnect # FIXME: Failed to handle disconnect
# assert len(host.get_network().connections) == 0 await asyncio.sleep(0.01)
assert len(host.get_network().connections) == 0

View File

@ -2,13 +2,22 @@ import asyncio
import pytest import pytest
from tests.factories import net_stream_pair_factory from tests.factories import net_stream_pair_factory, swarm_pair_factory
@pytest.fixture @pytest.fixture
async def net_stream_pair(): async def net_stream_pair(is_host_secure):
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory() stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
try: try:
yield stream_0, stream_1 yield stream_0, stream_1
finally: finally:
await asyncio.gather(*[host_0.close(), host_1.close()]) await asyncio.gather(*[host_0.close(), host_1.close()])
@pytest.fixture
async def swarm_pair(is_host_secure):
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
try:
yield swarm_0, swarm_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -0,0 +1,93 @@
import asyncio
import pytest
from libp2p.network.exceptions import SwarmException
from tests.factories import ListeningSwarmFactory
from tests.utils import connect_swarm
@pytest.mark.asyncio
async def test_swarm_dial_peer(is_host_secure):
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3)
# Test: No addr found.
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: len(addr) in the peerstore is 0.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
@pytest.mark.asyncio
async def test_swarm_close_peer(is_host_secure):
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3)
# 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2])
# peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 <> 2
assert len(swarms[0].connections) == 0
assert (
len(swarms[1].connections) == 1
and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2
assert (
len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections
)
assert (
len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections
)
# peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
@pytest.mark.asyncio
async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections
# Test: Remove twice. There should not be errors.
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections

View File

@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.constants import MAX_READ_LEN from tests.constants import MAX_READ_LEN
async def connect_swarm(swarm_0, swarm_1):
peer_id = swarm_1.get_peer_id()
addrs = tuple(
addr
for transport in swarm_1.listeners.values()
for addr in transport.get_addrs()
)
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id)
assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections
async def connect(node1, node2): async def connect(node1, node2):
""" """
Connect node1 to node2 Connect node1 to node2