Merge pull request #287 from mhchia/fix/mplex-stream-close-reset

Fix close/reset behavior
This commit is contained in:
Kevin Mai-Husan Chia 2019-09-10 23:57:44 +08:00 committed by GitHub
commit 988ef8c712
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 465 additions and 91 deletions

View File

@ -0,0 +1,17 @@
from libp2p.exceptions import BaseLibp2pError
class StreamError(BaseLibp2pError):
pass
class StreamEOF(StreamError, EOFError):
pass
class StreamReset(StreamError):
pass
class StreamClosed(StreamError):
pass

View File

@ -1,9 +1,18 @@
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.stream_muxer.exceptions import (
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamReset,
)
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .exceptions import StreamClosed, StreamEOF, StreamReset
from .net_stream_interface import INetStream from .net_stream_interface import INetStream
# TODO: Handle exceptions from `muxed_stream`
# TODO: Add stream state
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
class NetStream(INetStream): class NetStream(INetStream):
muxed_stream: IMuxedStream muxed_stream: IMuxedStream
@ -35,14 +44,22 @@ class NetStream(INetStream):
:param n: number of bytes to read :param n: number of bytes to read
:return: bytes of input :return: bytes of input
""" """
try:
return await self.muxed_stream.read(n) return await self.muxed_stream.read(n)
except MuxedStreamEOF as error:
raise StreamEOF from error
except MuxedStreamReset as error:
raise StreamReset from error
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
""" """
write to stream write to stream
:return: number of bytes written :return: number of bytes written
""" """
try:
return await self.muxed_stream.write(data) return await self.muxed_stream.write(data)
except MuxedStreamClosed as error:
raise StreamClosed from error
async def close(self) -> None: async def close(self) -> None:
""" """
@ -51,5 +68,5 @@ class NetStream(INetStream):
""" """
await self.muxed_stream.close() await self.muxed_stream.close()
async def reset(self) -> bool: async def reset(self) -> None:
return await self.muxed_stream.reset() await self.muxed_stream.reset()

View File

@ -23,7 +23,7 @@ class INetStream(ReadWriteCloser):
""" """
@abstractmethod @abstractmethod
async def reset(self) -> bool: async def reset(self) -> None:
""" """
Close both ends of the stream. Close both ends of the stream.
""" """

View File

@ -0,0 +1,25 @@
from libp2p.exceptions import BaseLibp2pError
class MuxedConnError(BaseLibp2pError):
pass
class MuxedConnShutdown(MuxedConnError):
pass
class MuxedStreamError(BaseLibp2pError):
pass
class MuxedStreamReset(MuxedStreamError):
pass
class MuxedStreamEOF(MuxedStreamError, EOFError):
pass
class MuxedStreamClosed(MuxedStreamError):
pass

View File

@ -1,17 +1,27 @@
from libp2p.exceptions import BaseLibp2pError from libp2p.stream_muxer.exceptions import (
MuxedConnError,
MuxedConnShutdown,
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamReset,
)
class MplexError(BaseLibp2pError): class MplexError(MuxedConnError):
pass pass
class MplexStreamReset(MplexError): class MplexShutdown(MuxedConnShutdown):
pass pass
class MplexStreamEOF(MplexError, EOFError): class MplexStreamReset(MuxedStreamReset):
pass pass
class MplexShutdown(MplexError): class MplexStreamEOF(MuxedStreamEOF):
pass
class MplexStreamClosed(MuxedStreamClosed):
pass pass

View File

@ -188,6 +188,10 @@ class Mplex(IMuxedConn):
# before. It is abnormal. Possibly disconnect? # before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this. # TODO: Warn and emit logs about this.
continue continue
async with stream.close_lock:
if stream.event_remote_closed.is_set():
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
continue
await stream.incoming_data.put(message) await stream.incoming_data.put(message)
elif flag in ( elif flag in (
HeaderTags.CloseInitiator.value, HeaderTags.CloseInitiator.value,

View File

@ -5,7 +5,7 @@ from libp2p.stream_muxer.abc import IMuxedStream
from .constants import HeaderTags from .constants import HeaderTags
from .datastructures import StreamID from .datastructures import StreamID
from .exceptions import MplexStreamEOF, MplexStreamReset from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset
if TYPE_CHECKING: if TYPE_CHECKING:
from libp2p.stream_muxer.mplex.mplex import Mplex from libp2p.stream_muxer.mplex.mplex import Mplex
@ -55,22 +55,46 @@ class MplexStream(IMuxedStream):
return self.stream_id.is_initiator return self.stream_id.is_initiator
async def _wait_for_data(self) -> None: async def _wait_for_data(self) -> None:
task_event_reset = asyncio.ensure_future(self.event_reset.wait())
task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get())
task_event_remote_closed = asyncio.ensure_future(
self.event_remote_closed.wait()
)
done, pending = await asyncio.wait( # type: ignore done, pending = await asyncio.wait( # type: ignore
[ [task_event_reset, task_incoming_data_get, task_event_remote_closed],
self.event_reset.wait(),
self.event_remote_closed.wait(),
self.incoming_data.get(),
],
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
for fut in pending:
fut.cancel()
if task_event_reset in done:
if self.event_reset.is_set(): if self.event_reset.is_set():
raise MplexStreamReset raise MplexStreamReset
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
if task_incoming_data_get in done:
data = task_incoming_data_get.result()
self._buf.extend(data)
return
if task_event_remote_closed in done:
if self.event_remote_closed.is_set(): if self.event_remote_closed.is_set():
raise MplexStreamEOF raise MplexStreamEOF
# TODO: Handle timeout when deadline is used. else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
data = tuple(done)[0].result() # TODO: Handle timeout when deadline is used.
self._buf.extend(data)
async def _read_until_eof(self) -> bytes: async def _read_until_eof(self) -> bytes:
while True: while True:
@ -90,7 +114,6 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read :param n: number of bytes to read
:return: bytes actually read :return: bytes actually read
""" """
# TODO: Add exceptions and handle/raise them in this class.
if n < 0 and n != -1: if n < 0 and n != -1:
raise ValueError( raise ValueError(
f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF"
@ -99,13 +122,15 @@ class MplexStream(IMuxedStream):
raise MplexStreamReset raise MplexStreamReset
if n == -1: if n == -1:
return await self._read_until_eof() return await self._read_until_eof()
if len(self._buf) == 0: if len(self._buf) == 0 and self.incoming_data.empty():
await self._wait_for_data() await self._wait_for_data()
# Read up to `n` bytes. # Now we are sure we have something to read.
# Try to put enough incoming data into `self._buf`.
while len(self._buf) < n: while len(self._buf) < n:
if self.incoming_data.empty() or self.event_remote_closed.is_set(): try:
self._buf.extend(self.incoming_data.get_nowait())
except asyncio.QueueEmpty:
break break
self._buf.extend(await self.incoming_data.get())
payload = self._buf[:n] payload = self._buf[:n]
self._buf = self._buf[len(payload) :] self._buf = self._buf[len(payload) :]
return bytes(payload) return bytes(payload)
@ -115,6 +140,8 @@ class MplexStream(IMuxedStream):
write to stream write to stream
:return: number of bytes written :return: number of bytes written
""" """
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data}")
flag = ( flag = (
HeaderTags.MessageInitiator HeaderTags.MessageInitiator
if self.is_initiator if self.is_initiator

View File

@ -4,7 +4,7 @@ import pytest
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.protocol_muxer.exceptions import MultiselectClientError from libp2p.protocol_muxer.exceptions import MultiselectClientError
from tests.utils import cleanup, set_up_nodes_by_transport_opt from tests.utils import set_up_nodes_by_transport_opt
PROTOCOL_ID = "/chat/1.0.0" PROTOCOL_ID = "/chat/1.0.0"
@ -101,5 +101,3 @@ async def test_chat(test):
await host_b.connect(info) await host_b.connect(info)
await test(host_a, host_b) await test(host_a, host_b)
await cleanup()

View File

@ -1,22 +1,28 @@
from typing import Dict import asyncio
from typing import Dict, Tuple
import factory import factory
from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p import generate_new_rsa_identity, initialize_default_swarm
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import Mplex
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import ( from tests.pubsub.configs import (
FLOODSUB_PROTOCOL_ID, FLOODSUB_PROTOCOL_ID,
GOSSIPSUB_PARAMS, GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID,
) )
from tests.utils import connect
def security_transport_factory( def security_transport_factory(
@ -43,6 +49,12 @@ class HostFactory(factory.Factory):
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
@classmethod
async def create_and_listen(cls) -> IHost:
host = cls()
await host.get_network().listen(LISTEN_MADDR)
return host
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):
class Meta: class Meta:
@ -73,3 +85,37 @@ class PubsubFactory(factory.Factory):
router = None router = None
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id()) my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
cache_size = None cache_size = None
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
hosts = await asyncio.gather(
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
)
await connect(hosts[0], hosts[1])
return hosts[0], hosts[1]
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
host_0, host_1 = await host_pair_factory()
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
return mplex_conn_0, host_0, mplex_conn_1, host_1
async def net_stream_pair_factory() -> Tuple[
INetStream, BasicHost, INetStream, BasicHost
]:
protocol_id = "/example/id/1"
stream_1: INetStream
# Just a proxy, we only care about the stream
def handler(stream: INetStream) -> None:
nonlocal stream_1
stream_1 = stream
host_0, host_1 = await host_pair_factory()
host_1.set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
return stream_0, host_0, stream_1, host_1

View File

@ -2,13 +2,16 @@ import asyncio
import sys import sys
from typing import Union from typing import Union
from p2pclient.datastructures import StreamInfo
import pexpect import pexpect
import pytest import pytest
from libp2p.io.abc import ReadWriteCloser
from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
from tests.pubsub.configs import GOSSIPSUB_PARAMS from tests.pubsub.configs import GOSSIPSUB_PARAMS
from .daemon import Daemon, make_p2pd from .daemon import Daemon, make_p2pd
from .utils import connect
@pytest.fixture @pytest.fixture
@ -78,3 +81,71 @@ def pubsubs(num_hosts, hosts, is_gossipsub):
) )
yield _pubsubs yield _pubsubs
# TODO: Clean up # TODO: Clean up
class DaemonStream(ReadWriteCloser):
stream_info: StreamInfo
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
def __init__(
self,
stream_info: StreamInfo,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
self.stream_info = stream_info
self.reader = reader
self.writer = writer
async def close(self) -> None:
self.writer.close()
await self.writer.wait_closed()
async def read(self, n: int = -1) -> bytes:
return await self.reader.read(n)
async def write(self, data: bytes) -> int:
return self.writer.write(data)
@pytest.fixture
async def is_to_fail_daemon_stream():
return False
@pytest.fixture
async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream):
assert len(hosts) >= 1
assert len(p2pds) >= 1
host = hosts[0]
p2pd = p2pds[0]
protocol_id = "/protocol/id/123"
stream_py = None
stream_daemon = None
event_stream_handled = asyncio.Event()
await connect(host, p2pd)
async def daemon_stream_handler(stream_info, reader, writer):
nonlocal stream_daemon
stream_daemon = DaemonStream(stream_info, reader, writer)
event_stream_handled.set()
await p2pd.control.stream_handler(protocol_id, daemon_stream_handler)
if is_to_fail_daemon_stream:
# FIXME: This is a workaround to make daemon reset the stream.
# We intentionally close the listener on the python side, it makes the connection from
# daemon to us fail, and therefore the daemon resets the opened stream on their side.
# Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501
# We need it because we want to test against `stream_py` after the remote side(daemon)
# is reset. This should be removed after the API `stream.reset` is exposed in daemon
# some day.
listener = p2pds[0].control.control.listener
listener.close()
await listener.wait_closed()
stream_py = await host.new_stream(p2pd.peer_id, [protocol_id])
if not is_to_fail_daemon_stream:
await event_stream_handled.wait()
# NOTE: If `is_to_fail_daemon_stream == True`, `stream_daemon == None`.
yield stream_py, stream_daemon

View File

@ -0,0 +1,74 @@
import asyncio
import pytest
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from tests.constants import MAX_READ_LEN
DATA = b"data"
@pytest.mark.asyncio
async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
assert (
stream_py.protocol_id is not None
and stream_py.protocol_id == stream_daemon.stream_info.proto
)
await stream_py.write(DATA)
assert (await stream_daemon.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
await stream_daemon.write(DATA)
await stream_daemon.close()
await asyncio.sleep(0.01)
assert (await stream_py.read(MAX_READ_LEN)) == DATA
# EOF
with pytest.raises(StreamEOF):
await stream_py.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await stream_py.reset()
with pytest.raises(StreamReset):
await stream_py.read(MAX_READ_LEN)
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
@pytest.mark.asyncio
async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await asyncio.sleep(0.01)
with pytest.raises(StreamReset):
await stream_py.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await stream_py.write(DATA)
await stream_py.close()
with pytest.raises(StreamClosed):
await stream_py.write(DATA)
@pytest.mark.asyncio
async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
await stream_py.reset()
with pytest.raises(StreamClosed):
await stream_py.write(DATA)
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
@pytest.mark.asyncio
async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await asyncio.sleep(0.01)
with pytest.raises(StreamClosed):
await stream_py.write(DATA)

View File

@ -3,7 +3,7 @@ import pytest
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.constants import MAX_READ_LEN from tests.constants import MAX_READ_LEN
from tests.utils import cleanup, set_up_nodes_by_transport_opt from tests.utils import set_up_nodes_by_transport_opt
@pytest.mark.asyncio @pytest.mark.asyncio
@ -34,7 +34,6 @@ async def test_simple_messages():
assert response == ("ack:" + message) assert response == ("ack:" + message)
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -69,7 +68,6 @@ async def test_double_response():
assert response2 == ("ack2:" + message) assert response2 == ("ack2:" + message)
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -120,7 +118,6 @@ async def test_multiple_streams():
) )
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -183,7 +180,6 @@ async def test_multiple_streams_same_initiator_different_protocols():
) )
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -264,7 +260,6 @@ async def test_multiple_streams_two_initiators():
) )
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -326,7 +321,6 @@ async def test_triangle_nodes_connection():
assert response == ("ack:" + message) assert response == ("ack:" + message)
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -353,4 +347,3 @@ async def test_host_connect():
assert addr.encapsulate(ma_node_b) in node_b.get_addrs() assert addr.encapsulate(ma_node_b) in node_b.get_addrs()
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()

View File

@ -17,7 +17,7 @@ from libp2p.crypto.rsa import create_new_key_pair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.network.notifee_interface import INotifee from libp2p.network.notifee_interface import INotifee
from tests.constants import MAX_READ_LEN from tests.constants import MAX_READ_LEN
from tests.utils import cleanup, perform_two_host_set_up from tests.utils import perform_two_host_set_up
ACK = "ack:" ACK = "ack:"
@ -91,7 +91,6 @@ async def test_one_notifier():
assert response == expected_resp assert response == expected_resp
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -138,7 +137,6 @@ async def test_one_notifier_on_two_nodes():
assert response == expected_resp assert response == expected_resp
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -203,7 +201,6 @@ async def test_one_notifier_on_two_nodes_with_listen():
assert response == expected_resp assert response == expected_resp
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -235,7 +232,6 @@ async def test_two_notifiers():
assert response == expected_resp assert response == expected_resp
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -271,7 +267,6 @@ async def test_ten_notifiers():
assert response == expected_resp assert response == expected_resp
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -325,7 +320,6 @@ async def test_ten_notifiers_on_two_nodes():
assert response == expected_resp assert response == expected_resp
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -355,4 +349,3 @@ async def test_invalid_notifee():
assert response == expected_resp assert response == expected_resp
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()

View File

14
tests/network/conftest.py Normal file
View File

@ -0,0 +1,14 @@
import asyncio
import pytest
from tests.factories import net_stream_pair_factory
@pytest.fixture
async def net_stream_pair():
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory()
try:
yield stream_0, stream_1
finally:
await asyncio.gather(*[host_0.close(), host_1.close()])

View File

@ -0,0 +1,122 @@
import asyncio
import pytest
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from tests.constants import MAX_READ_LEN
DATA = b"data_123"
# TODO: Move `muxed_stream` specific(currently we are using `MplexStream`) tests to its
# own file, after `generic_protocol_handler` is refactored out of `Mplex`.
@pytest.mark.asyncio
async def test_net_stream_read_write(net_stream_pair):
stream_0, stream_1 = net_stream_pair
assert (
stream_0.protocol_id is not None
and stream_0.protocol_id == stream_1.protocol_id
)
await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
async def test_net_stream_read_until_eof(net_stream_pair):
read_bytes = bytearray()
stream_0, stream_1 = net_stream_pair
async def read_until_eof():
read_bytes.extend(await stream_1.read())
task = asyncio.ensure_future(read_until_eof())
expected_data = bytearray()
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
await asyncio.sleep(0.01)
assert read_bytes == expected_data
task.cancel()
@pytest.mark.asyncio
async def test_net_stream_read_after_remote_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair
assert not stream_1.muxed_stream.event_remote_closed.is_set()
await stream_0.write(DATA)
await stream_0.close()
await asyncio.sleep(0.01)
assert stream_1.muxed_stream.event_remote_closed.is_set()
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(StreamEOF):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_net_stream_read_after_local_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.reset()
with pytest.raises(StreamReset):
await stream_0.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_net_stream_read_after_remote_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
with pytest.raises(StreamReset):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
async def test_net_stream_write_after_local_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.close()
with pytest.raises(StreamClosed):
await stream_0.write(DATA)
@pytest.mark.asyncio
async def test_net_stream_write_after_local_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.reset()
with pytest.raises(StreamClosed):
await stream_0.write(DATA)
@pytest.mark.asyncio
async def test_net_stream_write_after_remote_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_1.reset()
await asyncio.sleep(0.01)
with pytest.raises(StreamClosed):
await stream_0.write(DATA)

View File

@ -1,7 +1,7 @@
import pytest import pytest
from libp2p.protocol_muxer.exceptions import MultiselectClientError from libp2p.protocol_muxer.exceptions import MultiselectClientError
from tests.utils import cleanup, echo_stream_handler, set_up_nodes_by_transport_opt from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt
# TODO: Add tests for multiple streams being opened on different # TODO: Add tests for multiple streams being opened on different
# protocols through the same connection # protocols through the same connection
@ -35,7 +35,6 @@ async def perform_simple_test(
assert expected_selected_protocol == stream.get_protocol() assert expected_selected_protocol == stream.get_protocol()
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -52,7 +51,6 @@ async def test_single_protocol_fails():
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"])
# Cleanup not reached on error # Cleanup not reached on error
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -83,4 +81,3 @@ async def test_multiple_protocol_fails():
await perform_simple_test("", protocols_for_client, protocols_for_listener) await perform_simple_test("", protocols_for_client, protocols_for_listener)
# Cleanup not reached on error # Cleanup not reached on error
await cleanup()

View File

@ -4,7 +4,7 @@ import pytest
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.factories import PubsubFactory from tests.factories import PubsubFactory
from tests.utils import cleanup, connect from tests.utils import connect
from .configs import FLOODSUB_PROTOCOL_ID from .configs import FLOODSUB_PROTOCOL_ID
@ -258,4 +258,3 @@ async def perform_test_from_obj(obj, router_factory):
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()

View File

@ -3,7 +3,7 @@ from threading import Thread
import pytest import pytest
from tests.utils import cleanup, connect from tests.utils import connect
from .dummy_account_node import DummyAccountNode from .dummy_account_node import DummyAccountNode
@ -64,7 +64,6 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
assertion_func(dummy_node) assertion_func(dummy_node)
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -4,7 +4,7 @@ import pytest
from libp2p.peer.id import ID from libp2p.peer.id import ID
from tests.factories import FloodsubFactory from tests.factories import FloodsubFactory
from tests.utils import cleanup, connect from tests.utils import connect
from .floodsub_integration_test_settings import ( from .floodsub_integration_test_settings import (
floodsub_protocol_pytest_params, floodsub_protocol_pytest_params,
@ -36,7 +36,6 @@ async def test_simple_two_nodes(pubsubs_fsub):
assert res_b.topicIDs == [topic] assert res_b.topicIDs == [topic]
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
# Initialize Pubsub with a cache_size of 4 # Initialize Pubsub with a cache_size of 4
@ -82,7 +81,6 @@ async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch):
assert sub_b.empty() assert sub_b.empty()
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)

View File

@ -3,7 +3,7 @@ import random
import pytest import pytest
from tests.utils import cleanup, connect from tests.utils import connect
from .configs import GossipsubParams from .configs import GossipsubParams
from .utils import dense_connect, one_to_all_connect from .utils import dense_connect, one_to_all_connect
@ -61,8 +61,6 @@ async def test_join(num_hosts, hosts, pubsubs_gsub):
assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
assert topic not in gossipsubs[i].mesh assert topic not in gossipsubs[i].mesh
await cleanup()
@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio @pytest.mark.asyncio
@ -81,8 +79,6 @@ async def test_leave(pubsubs_gsub):
# Test re-leave # Test re-leave
await gossipsub.leave(topic) await gossipsub.leave(topic)
await cleanup()
@pytest.mark.parametrize("num_hosts", (2,)) @pytest.mark.parametrize("num_hosts", (2,))
@pytest.mark.asyncio @pytest.mark.asyncio
@ -133,8 +129,6 @@ async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch):
# Check that bob is now alice's mesh peer # Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic] assert id_bob in gossipsubs[index_alice].mesh[topic]
await cleanup()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),) "num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),)
@ -174,8 +168,6 @@ async def test_handle_prune(pubsubs_gsub, hosts):
assert id_alice not in gossipsubs[index_bob].mesh[topic] assert id_alice not in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic] assert id_bob in gossipsubs[index_alice].mesh[topic]
await cleanup()
@pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio @pytest.mark.asyncio
@ -210,7 +202,6 @@ async def test_dense(num_hosts, pubsubs_gsub, hosts):
for queue in queues: for queue in queues:
msg = await queue.get() msg = await queue.get()
assert msg.data == msg_content assert msg.data == msg_content
await cleanup()
@pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.parametrize("num_hosts", (10,))
@ -268,8 +259,6 @@ async def test_fanout(hosts, pubsubs_gsub):
msg = await queue.get() msg = await queue.get()
assert msg.data == msg_content assert msg.data == msg_content
await cleanup()
@pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio @pytest.mark.asyncio
@ -340,8 +329,6 @@ async def test_fanout_maintenance(hosts, pubsubs_gsub):
msg = await queue.get() msg = await queue.get()
assert msg.data == msg_content assert msg.data == msg_content
await cleanup()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_hosts, gossipsub_params", "num_hosts, gossipsub_params",
@ -380,5 +367,3 @@ async def test_gossip_propagation(hosts, pubsubs_gsub):
# should be able to read message # should be able to read message
msg = await queue_1.get() msg = await queue_1.get()
assert msg.data == msg_content assert msg.data == msg_content
await cleanup()

View File

@ -6,7 +6,7 @@ from libp2p import new_node
from libp2p.crypto.rsa import create_new_key_pair from libp2p.crypto.rsa import create_new_key_pair
from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from libp2p.security.insecure.transport import InsecureSession, InsecureTransport
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.utils import cleanup, connect from tests.utils import connect
# TODO: Add tests for multiple streams being opened on different # TODO: Add tests for multiple streams being opened on different
# protocols through the same connection # protocols through the same connection
@ -57,7 +57,6 @@ async def perform_simple_test(
assertion_func(node2_conn.secured_conn) assertion_func(node2_conn.secured_conn)
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -1,6 +1,3 @@
import asyncio
from contextlib import suppress
import multiaddr import multiaddr
from libp2p import new_node from libp2p import new_node
@ -17,17 +14,6 @@ async def connect(node1, node2):
await node1.connect(info) await node1.connect(info)
async def cleanup():
pending = asyncio.all_tasks()
for task in pending:
task.cancel()
# Now we should await task to execute it's cancellation.
# Cancelled task raises asyncio.CancelledError that we can suppress:
with suppress(asyncio.CancelledError):
await task
async def set_up_nodes_by_transport_opt(transport_opt_list): async def set_up_nodes_by_transport_opt(transport_opt_list):
nodes_list = [] nodes_list = []
for transport_opt in transport_opt_list: for transport_opt in transport_opt_list: