Merge pull request #304 from mhchia/fix/detection-of-close

Detect closed `Mplex`
This commit is contained in:
Kevin Mai-Husan Chia 2019-09-21 18:28:16 +08:00 committed by GitHub
commit 4a838033ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 392 additions and 143 deletions

View File

@ -30,6 +30,10 @@ logger.setLevel(logging.DEBUG)
class BasicHost(IHost):
"""
BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation
on a stream with multistream-select right after a stream is initialized.
"""
_network: INetwork
_router: KadmeliaPeerRouter
@ -38,7 +42,6 @@ class BasicHost(IHost):
multiselect: Multiselect
multiselect_client: MultiselectClient
# default options constructor
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
self._network = network
self._network.set_stream_handler(self._swarm_stream_handler)
@ -76,6 +79,7 @@ class BasicHost(IHost):
"""
:return: all the multiaddr addresses this host is listening to
"""
# TODO: We don't need "/p2p/{peer_id}" postfix actually.
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
addrs: List[multiaddr.Multiaddr] = []
@ -94,8 +98,6 @@ class BasicHost(IHost):
"""
self.multiselect.add_handler(protocol_id, stream_handler)
# `protocol_ids` can be a list of `protocol_id`
# stream will decide which `protocol_id` to run on
async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
) -> INetStream:

View File

@ -37,8 +37,8 @@ class RawConnection(IRawConnection):
async with self._drain_lock:
try:
await self.writer.drain()
except ConnectionResetError:
raise RawConnError()
except ConnectionResetError as error:
raise RawConnError(error)
async def read(self, n: int = -1) -> bytes:
"""

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.stream.net_stream import NetStream
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
if TYPE_CHECKING:
from libp2p.network.swarm import Swarm # noqa: F401
@ -34,17 +35,28 @@ class SwarmConn(INetConn):
if self.event_closed.is_set():
return
self.event_closed.set()
self.swarm.remove_conn(self)
await self.conn.close()
# This is just for cleaning up state. The connection has already been closed.
# We *could* optimize this but it really isn't worth it.
for stream in self.streams:
await stream.reset()
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
asyncio.ensure_future(self._notify_disconnected())
for task in self._tasks:
task.cancel()
# TODO: Reset streams for local.
# TODO: Notify closed.
async def _handle_new_streams(self) -> None:
# TODO: Break the loop when anything wrong in the connection.
while True:
try:
stream = await self.conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
# we should break the loop and close the connection.
break
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
await self.run_task(self._handle_muxed_stream(stream))
@ -57,11 +69,16 @@ class SwarmConn(INetConn):
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream)
self.streams.add(net_stream)
# Call notifiers since event occurred
for notifee in self.swarm.notifees:
await notifee.opened_stream(self.swarm, net_stream)
return net_stream
async def _notify_disconnected(self) -> None:
for notifee in self.swarm.notifees:
await notifee.disconnected(self.swarm, self.conn)
async def start(self) -> None:
await self.run_task(self._handle_new_streams())

View File

@ -272,13 +272,18 @@ class Swarm(INetwork):
async def close_peer(self, peer_id: ID) -> None:
if peer_id not in self.connections:
return
# TODO: Should be changed to close multisple connections,
# if we have several connections per peer in the future.
connection = self.connections[peer_id]
del self.connections[peer_id]
await connection.close()
logger.debug("successfully close the connection to peer %s", peer_id)
async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn:
"""
Add a `IMuxedConn` to `Swarm` as a `SwarmConn`, notify "connected",
and start to monitor the connection for its new streams and disconnection.
"""
swarm_conn = SwarmConn(muxed_conn, self)
# Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn
@ -288,3 +293,14 @@ class Swarm(INetwork):
await notifee.connected(self, muxed_conn)
await swarm_conn.start()
return swarm_conn
def remove_conn(self, swarm_conn: SwarmConn) -> None:
"""
Simply remove the connection from Swarm's records, without closing the connection.
"""
peer_id = swarm_conn.conn.peer_id
if peer_id not in self.connections:
return
# TODO: Should be changed to remove the exact connection,
# if we have several connections per peer in the future.
del self.connections[peer_id]

View File

@ -5,7 +5,7 @@ class MuxedConnError(BaseLibp2pError):
pass
class MuxedConnShutdown(MuxedConnError):
class MuxedConnUnavailable(MuxedConnError):
pass

View File

@ -1,6 +1,6 @@
from libp2p.stream_muxer.exceptions import (
MuxedConnError,
MuxedConnShutdown,
MuxedConnUnavailable,
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamReset,
@ -11,7 +11,7 @@ class MplexError(MuxedConnError):
pass
class MplexShutdown(MuxedConnShutdown):
class MplexUnavailable(MuxedConnUnavailable):
pass

View File

@ -1,9 +1,10 @@
import asyncio
from typing import Any # noqa: F401
from typing import Dict, List, Optional, Tuple
from typing import Awaitable, Dict, List, Optional, Tuple
from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError
from libp2p.network.connection.exceptions import RawConnError
from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
@ -17,6 +18,7 @@ from libp2p.utils import (
from .constants import HeaderTags
from .datastructures import StreamID
from .exceptions import MplexUnavailable
from .mplex_stream import MplexStream
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
@ -36,7 +38,8 @@ class Mplex(IMuxedConn):
streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]"
shutdown: asyncio.Event
event_shutting_down: asyncio.Event
event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"]
@ -60,7 +63,8 @@ class Mplex(IMuxedConn):
self.streams = {}
self.streams_lock = asyncio.Lock()
self.new_stream_queue = asyncio.Queue()
self.shutdown = asyncio.Event()
self.event_shutting_down = asyncio.Event()
self.event_closed = asyncio.Event()
self._tasks = []
@ -75,16 +79,20 @@ class Mplex(IMuxedConn):
"""
close the stream muxer and underlying secured connection
"""
for task in self._tasks:
task.cancel()
if self.event_shutting_down.is_set():
return
# Set the `event_shutting_down`, to allow graceful shutdown.
self.event_shutting_down.set()
await self.secured_conn.close()
# Blocked until `close` is finally set.
await self.event_closed.wait()
def is_closed(self) -> bool:
"""
check connection is fully closed
:return: true if successful
"""
raise NotImplementedError()
return self.event_closed.is_set()
def _get_next_channel_id(self) -> int:
"""
@ -114,11 +122,29 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream
async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_closed, task_wait_shutting_down],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_wait_closed in done:
raise MplexUnavailable("Mplex is closed")
if task_wait_shutting_down in done:
raise MplexUnavailable("Mplex is shutting down")
return task_coro.result()
async def accept_stream(self) -> IMuxedStream:
"""
accepts a muxed stream opened by the other end
"""
return await self.new_stream_queue.get()
return await self._wait_until_shutting_down_or_closed(
self.new_stream_queue.get()
)
async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -137,7 +163,9 @@ class Mplex(IMuxedConn):
_bytes = header + encode_varint_prefixed(data)
return await self.write_to_stream(_bytes)
return await self._wait_until_shutting_down_or_closed(
self.write_to_stream(_bytes)
)
async def write_to_stream(self, _bytes: bytes) -> int:
"""
@ -152,55 +180,112 @@ class Mplex(IMuxedConn):
"""
Read a message off of the secured connection and add it to the corresponding message buffer
"""
# TODO Deal with other types of messages using flag (currently _)
while True:
channel_id, flag, message = await self.read_message()
if channel_id is not None and flag is not None and message is not None:
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
is_stream_id_seen: bool
stream: MplexStream
async with self.streams_lock:
is_stream_id_seen = stream_id in self.streams
if is_stream_id_seen:
stream = self.streams[stream_id]
# Other consequent stream message should wait until the stream get accepted
# TODO: Handle more tags, and refactor `HeaderTags`
if flag == HeaderTags.NewStream.value:
if is_stream_id_seen:
# `NewStream` for the same id is received twice...
# TODO: Shutdown
pass
mplex_stream = await self._initialize_stream(
stream_id, message.decode()
try:
await self._handle_incoming_message()
except MplexUnavailable:
break
# Force context switch
await asyncio.sleep(0)
# If we enter here, it means this connection is shutting down.
# We should clean things up.
await self._cleanup()
async def read_message(self) -> Tuple[int, int, bytes]:
"""
Read a single message off of the secured connection
:return: stream_id, flag, message contents
"""
# FIXME: No timeout is used in Go implementation.
try:
header = await decode_uvarint_from_stream(self.secured_conn)
message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5
)
# TODO: Check if `self` is shutdown.
await self.new_stream_queue.put(mplex_stream)
except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable(
"failed to read messages correctly from the underlying connection"
) from error
except asyncio.TimeoutError as error:
raise MplexUnavailable(
"failed to read more message body within the timeout"
) from error
flag = header & 0x07
channel_id = header >> 3
return channel_id, flag, message
async def _handle_incoming_message(self) -> None:
"""
Read and handle a new incoming message.
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
"""
channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
self.read_message()
)
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
if flag == HeaderTags.NewStream.value:
await self._handle_new_stream(stream_id, message)
elif flag in (
HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value,
):
if not is_stream_id_seen:
await self._handle_message(stream_id, message)
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
await self._handle_close(stream_id)
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
await self._handle_reset(stream_id)
else:
# Receives messages with an unknown flag
# TODO: logging
async with self.streams_lock:
if stream_id in self.streams:
stream = self.streams[stream_id]
await stream.reset()
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
if stream_id in self.streams:
# `NewStream` for the same id is received twice...
raise MplexUnavailable(
f"received NewStream message for existing stream: {stream_id}"
)
mplex_stream = await self._initialize_stream(stream_id, message.decode())
await self._wait_until_shutting_down_or_closed(
self.new_stream_queue.put(mplex_stream)
)
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
if stream_id not in self.streams:
# We receive a message of the stream `stream_id` which is not accepted
# before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this.
continue
return
stream = self.streams[stream_id]
async with stream.close_lock:
if stream.event_remote_closed.is_set():
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
continue
await stream.incoming_data.put(message)
elif flag in (
HeaderTags.CloseInitiator.value,
HeaderTags.CloseReceiver.value,
):
if not is_stream_id_seen:
continue
return
await self._wait_until_shutting_down_or_closed(
stream.incoming_data.put(message)
)
async def _handle_close(self, stream_id: StreamID) -> None:
async with self.streams_lock:
if stream_id not in self.streams:
# Ignore unmatched messages for now.
return
stream = self.streams[stream_id]
# NOTE: If remote is already closed, then return: Technically a bug
# on the other side. We should consider killing the connection.
async with stream.close_lock:
if stream.event_remote_closed.is_set():
continue
return
is_local_closed: bool
async with stream.close_lock:
stream.event_remote_closed.set()
@ -210,16 +295,16 @@ class Mplex(IMuxedConn):
if is_local_closed:
async with self.streams_lock:
del self.streams[stream_id]
elif flag in (
HeaderTags.ResetInitiator.value,
HeaderTags.ResetReceiver.value,
):
if not is_stream_id_seen:
async def _handle_reset(self, stream_id: StreamID) -> None:
async with self.streams_lock:
if stream_id not in self.streams:
# This is *ok*. We forget the stream on reset.
continue
return
stream = self.streams[stream_id]
async with stream.close_lock:
if not stream.event_remote_closed.is_set():
# TODO: Why? Only if remote is not closed before then reset.
stream.event_reset.set()
stream.event_remote_closed.set()
@ -228,36 +313,16 @@ class Mplex(IMuxedConn):
stream.event_local_closed.set()
async with self.streams_lock:
del self.streams[stream_id]
else:
# TODO: logging
if is_stream_id_seen:
await stream.reset()
# Force context switch
await asyncio.sleep(0)
async def read_message(self) -> Tuple[int, int, bytes]:
"""
Read a single message off of the secured connection
:return: stream_id, flag, message contents
"""
# FIXME: No timeout is used in Go implementation.
# Timeout is set to a relatively small value to alleviate wait time to exit
# loop in handle_incoming
try:
header = await decode_uvarint_from_stream(self.secured_conn)
except ParseError:
return None, None, None
try:
message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5
)
except (ParseError, IncompleteReadError, asyncio.TimeoutError):
# TODO: Investigate what we should do if time is out.
return None, None, None
flag = header & 0x07
channel_id = header >> 3
return channel_id, flag, message
async def _cleanup(self) -> None:
if not self.event_shutting_down.is_set():
self.event_shutting_down.set()
async with self.streams_lock:
for stream in self.streams.values():
async with stream.close_lock:
if not stream.event_remote_closed.is_set():
stream.event_remote_closed.set()
stream.event_reset.set()
stream.event_local_closed.set()
self.streams = None
self.event_closed.set()

View File

@ -204,6 +204,10 @@ class MplexStream(IMuxedStream):
self.event_remote_closed.set()
async with self.mplex_conn.streams_lock:
if (
self.mplex_conn.streams is not None
and self.stream_id in self.mplex_conn.streams
):
del self.mplex_conn.streams[self.stream_id]
# TODO deadline not in use

View File

@ -41,14 +41,8 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
if shift > SHIFT_64_BIT_MAX:
raise ParseError("TODO: better exception msg: Integer is too large...")
byte = await reader.read(1)
try:
byte = await read_exactly(reader, 1)
value = byte[0]
except IndexError:
raise ParseError(
"Unexpected end of stream while parsing LEB128 encoded integer"
)
res += (value & LOW_MASK) << shift

View File

@ -6,15 +6,14 @@ import factory
from libp2p import generate_new_rsa_identity, initialize_default_swarm
from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import Mplex
from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import (
@ -22,7 +21,7 @@ from tests.pubsub.configs import (
GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID,
)
from tests.utils import connect
from tests.utils import connect, connect_swarm
def security_transport_factory(
@ -34,12 +33,31 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)}
def swarm_factory(is_secure: bool):
def SwarmFactory(is_secure: bool) -> Swarm:
key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(is_secure, key_pair)
sec_opt = security_transport_factory(False, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
class ListeningSwarmFactory(factory.Factory):
class Meta:
model = Swarm
@classmethod
async def create_and_listen(cls, is_secure: bool) -> Swarm:
swarm = SwarmFactory(is_secure)
await swarm.listen(LISTEN_MADDR)
return swarm
@classmethod
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[Swarm, ...]:
return await asyncio.gather(
*[cls.create_and_listen(is_secure) for _ in range(number)]
)
class HostFactory(factory.Factory):
class Meta:
model = BasicHost
@ -47,13 +65,19 @@ class HostFactory(factory.Factory):
class Params:
is_secure = False
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure))
@classmethod
async def create_and_listen(cls) -> IHost:
host = cls()
await host.get_network().listen(LISTEN_MADDR)
return host
async def create_and_listen(cls, is_secure: bool) -> BasicHost:
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 1)
return BasicHost(swarms[0])
@classmethod
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, number)
return tuple(BasicHost(swarm) for swarm in range(swarms))
class FloodsubFactory(factory.Factory):
@ -87,24 +111,33 @@ class PubsubFactory(factory.Factory):
cache_size = None
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2)
await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1]
async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
hosts = await asyncio.gather(
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
*[
HostFactory.create_and_listen(is_secure),
HostFactory.create_and_listen(is_secure),
]
)
await connect(hosts[0], hosts[1])
return hosts[0], hosts[1]
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
host_0, host_1 = await host_pair_factory()
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
return mplex_conn_0, host_0, mplex_conn_1, host_1
# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
# host_0, host_1 = await host_pair_factory()
# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
# return mplex_conn_0, host_0, mplex_conn_1, host_1
async def net_stream_pair_factory() -> Tuple[
INetStream, BasicHost, INetStream, BasicHost
]:
async def net_stream_pair_factory(
is_secure: bool
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
protocol_id = "/example/id/1"
stream_1: INetStream
@ -114,7 +147,7 @@ async def net_stream_pair_factory() -> Tuple[
nonlocal stream_1
stream_1 = stream
host_0, host_1 = await host_pair_factory()
host_0, host_1 = await host_pair_factory(is_secure)
host_1.set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])

View File

@ -1,3 +1,5 @@
import asyncio
import pytest
from .utils import connect
@ -21,4 +23,5 @@ async def test_connect(hosts, p2pds):
# Test: `disconnect` from Go
await p2pd.control.disconnect(host.get_id())
# FIXME: Failed to handle disconnect
# assert len(host.get_network().connections) == 0
await asyncio.sleep(0.01)
assert len(host.get_network().connections) == 0

View File

@ -2,13 +2,22 @@ import asyncio
import pytest
from tests.factories import net_stream_pair_factory
from tests.factories import net_stream_pair_factory, swarm_pair_factory
@pytest.fixture
async def net_stream_pair():
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory()
async def net_stream_pair(is_host_secure):
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
try:
yield stream_0, stream_1
finally:
await asyncio.gather(*[host_0.close(), host_1.close()])
@pytest.fixture
async def swarm_pair(is_host_secure):
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
try:
yield swarm_0, swarm_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -0,0 +1,93 @@
import asyncio
import pytest
from libp2p.network.exceptions import SwarmException
from tests.factories import ListeningSwarmFactory
from tests.utils import connect_swarm
@pytest.mark.asyncio
async def test_swarm_dial_peer(is_host_secure):
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3)
# Test: No addr found.
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: len(addr) in the peerstore is 0.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
@pytest.mark.asyncio
async def test_swarm_close_peer(is_host_secure):
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3)
# 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2])
# peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 <> 2
assert len(swarms[0].connections) == 0
assert (
len(swarms[1].connections) == 1
and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2
assert (
len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections
)
assert (
len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections
)
# peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
@pytest.mark.asyncio
async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections
# Test: Remove twice. There should not be errors.
swarm_0.remove_conn(conn_0)
assert swarm_1.get_peer_id() not in swarm_0.connections

View File

@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.constants import MAX_READ_LEN
async def connect_swarm(swarm_0, swarm_1):
peer_id = swarm_1.get_peer_id()
addrs = tuple(
addr
for transport in swarm_1.listeners.values()
for addr in transport.get_addrs()
)
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id)
assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections
async def connect(node1, node2):
"""
Connect node1 to node2