Add tests against the daemon for close/reset

This commit is contained in:
mhchia 2019-09-10 17:51:39 +08:00
parent bb0da41eda
commit df87f5adb9
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
4 changed files with 156 additions and 1 deletions

View File

@ -112,7 +112,6 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read :param n: number of bytes to read
:return: bytes actually read :return: bytes actually read
""" """
# TODO: Add exceptions and handle/raise them in this class.
if n < 0 and n != -1: if n < 0 and n != -1:
raise ValueError( raise ValueError(
f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF"

View File

@ -2,13 +2,16 @@ import asyncio
import sys import sys
from typing import Union from typing import Union
from p2pclient.datastructures import StreamInfo
import pexpect import pexpect
import pytest import pytest
from libp2p.io.abc import ReadWriteCloser
from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
from tests.pubsub.configs import GOSSIPSUB_PARAMS from tests.pubsub.configs import GOSSIPSUB_PARAMS
from .daemon import Daemon, make_p2pd from .daemon import Daemon, make_p2pd
from .utils import connect
@pytest.fixture @pytest.fixture
@ -78,3 +81,71 @@ def pubsubs(num_hosts, hosts, is_gossipsub):
) )
yield _pubsubs yield _pubsubs
# TODO: Clean up # TODO: Clean up
class DaemonStream(ReadWriteCloser):
stream_info: StreamInfo
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
def __init__(
self,
stream_info: StreamInfo,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
self.stream_info = stream_info
self.reader = reader
self.writer = writer
async def close(self) -> None:
self.writer.close()
await self.writer.wait_closed()
async def read(self, n: int = -1) -> bytes:
return await self.reader.read(n)
async def write(self, data: bytes) -> int:
return self.writer.write(data)
@pytest.fixture
async def is_to_fail_daemon_stream():
return False
@pytest.fixture
async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream):
assert len(hosts) >= 1
assert len(p2pds) >= 1
host = hosts[0]
p2pd = p2pds[0]
protocol_id = "/protocol/id/123"
stream_py = None
stream_daemon = None
event_stream_handled = asyncio.Event()
await connect(host, p2pd)
async def daemon_stream_handler(stream_info, reader, writer):
nonlocal stream_daemon
stream_daemon = DaemonStream(stream_info, reader, writer)
event_stream_handled.set()
await p2pd.control.stream_handler(protocol_id, daemon_stream_handler)
if is_to_fail_daemon_stream:
# FIXME: This is a workaround to make daemon reset the stream.
# We intentionally close the listener on the python side, it makes the connection from
# daemon to us fail, and therefore the daemon resets the opened stream on their side.
# Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501
# We need it because we want to test against `stream_py` after the remote side(daemon)
# is reset. This should be removed after the API `stream.reset` is exposed in daemon
# some day.
listener = p2pds[0].control.control.listener
listener.close()
await listener.wait_closed()
stream_py = await host.new_stream(p2pd.peer_id, [protocol_id])
if not is_to_fail_daemon_stream:
await event_stream_handled.wait()
# NOTE: If `is_to_fail_daemon_stream == True`, `stream_daemon == None`.
yield stream_py, stream_daemon

View File

@ -0,0 +1,74 @@
import asyncio
import pytest
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from tests.constants import MAX_READ_LEN
DATA = b"data"
@pytest.mark.asyncio
async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
assert (
stream_py.protocol_id is not None
and stream_py.protocol_id == stream_daemon.stream_info.proto
)
await stream_py.write(DATA)
assert (await stream_daemon.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
await stream_daemon.write(DATA)
await stream_daemon.close()
await asyncio.sleep(0.01)
assert (await stream_py.read(MAX_READ_LEN)) == DATA
# EOF
with pytest.raises(StreamEOF):
await stream_py.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await stream_py.reset()
with pytest.raises(StreamReset):
await stream_py.read(MAX_READ_LEN)
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
@pytest.mark.asyncio
async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await asyncio.sleep(0.01)
with pytest.raises(StreamReset):
await stream_py.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await stream_py.write(DATA)
await stream_py.close()
with pytest.raises(StreamClosed):
await stream_py.write(DATA)
@pytest.mark.asyncio
async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
await stream_py.reset()
with pytest.raises(StreamClosed):
await stream_py.write(DATA)
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
@pytest.mark.asyncio
async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await asyncio.sleep(0.01)
with pytest.raises(StreamClosed):
await stream_py.write(DATA)

View File

@ -85,6 +85,17 @@ async def test_net_stream_read_after_remote_reset(net_stream_pair):
await stream_1.read(MAX_READ_LEN) await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_net_stream_write_after_local_closed(net_stream_pair): async def test_net_stream_write_after_local_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair stream_0, stream_1 = net_stream_pair