Add tests for SwarmConn
This commit is contained in:
parent
b8b5ac5e06
commit
d61327f5f9
|
@ -43,11 +43,15 @@ class SwarmConn(INetConn):
|
||||||
# We *could* optimize this but it really isn't worth it.
|
# We *could* optimize this but it really isn't worth it.
|
||||||
for stream in self.streams:
|
for stream in self.streams:
|
||||||
await stream.reset()
|
await stream.reset()
|
||||||
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
|
|
||||||
asyncio.ensure_future(self._notify_disconnected())
|
|
||||||
|
|
||||||
for task in self._tasks:
|
for task in self._tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
|
||||||
|
asyncio.ensure_future(self._notify_disconnected())
|
||||||
|
|
||||||
async def _handle_new_streams(self) -> None:
|
async def _handle_new_streams(self) -> None:
|
||||||
while True:
|
while True:
|
||||||
|
@ -70,7 +74,6 @@ class SwarmConn(INetConn):
|
||||||
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
||||||
net_stream = NetStream(muxed_stream)
|
net_stream = NetStream(muxed_stream)
|
||||||
self.streams.add(net_stream)
|
self.streams.add(net_stream)
|
||||||
# Call notifiers since event occurred
|
|
||||||
for notifee in self.swarm.notifees:
|
for notifee in self.swarm.notifees:
|
||||||
await notifee.opened_stream(self.swarm, net_stream)
|
await notifee.opened_stream(self.swarm, net_stream)
|
||||||
return net_stream
|
return net_stream
|
||||||
|
@ -91,3 +94,7 @@ class SwarmConn(INetConn):
|
||||||
|
|
||||||
async def get_streams(self) -> Tuple[NetStream, ...]:
|
async def get_streams(self) -> Tuple[NetStream, ...]:
|
||||||
return tuple(self.streams)
|
return tuple(self.streams)
|
||||||
|
|
||||||
|
# TODO: Called by `Stream` whenever it is time to remove the stream.
|
||||||
|
def remove_stream(self, stream: NetStream) -> None:
|
||||||
|
self.streams.remove(stream)
|
||||||
|
|
|
@ -66,3 +66,7 @@ class NetStream(INetStream):
|
||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
await self.muxed_stream.reset()
|
await self.muxed_stream.reset()
|
||||||
|
|
||||||
|
# TODO: `remove`: Called by close and write when the stream is in specific states.
|
||||||
|
# It notify `ClosedStream` after `SwarmConn.remove_stream` is called.
|
||||||
|
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
|
||||||
|
|
|
@ -6,6 +6,7 @@ import factory
|
||||||
from libp2p import generate_new_rsa_identity, initialize_default_swarm
|
from libp2p import generate_new_rsa_identity, initialize_default_swarm
|
||||||
from libp2p.crypto.keys import KeyPair
|
from libp2p.crypto.keys import KeyPair
|
||||||
from libp2p.host.basic_host import BasicHost
|
from libp2p.host.basic_host import BasicHost
|
||||||
|
from libp2p.network.connection.swarm_connection import SwarmConn
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
from libp2p.network.swarm import Swarm
|
from libp2p.network.swarm import Swarm
|
||||||
from libp2p.pubsub.floodsub import FloodSub
|
from libp2p.pubsub.floodsub import FloodSub
|
||||||
|
@ -128,11 +129,13 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
|
||||||
return hosts[0], hosts[1]
|
return hosts[0], hosts[1]
|
||||||
|
|
||||||
|
|
||||||
# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
|
async def swarm_conn_pair_factory(
|
||||||
# host_0, host_1 = await host_pair_factory()
|
is_secure
|
||||||
# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
|
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
|
||||||
# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
|
swarms = await swarm_pair_factory(is_secure)
|
||||||
# return mplex_conn_0, host_0, mplex_conn_1, host_1
|
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||||
|
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
||||||
|
return conn_0, swarms[0], conn_1, swarms[1]
|
||||||
|
|
||||||
|
|
||||||
async def net_stream_pair_factory(
|
async def net_stream_pair_factory(
|
||||||
|
|
|
@ -2,7 +2,11 @@ import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.factories import net_stream_pair_factory, swarm_pair_factory
|
from tests.factories import (
|
||||||
|
net_stream_pair_factory,
|
||||||
|
swarm_conn_pair_factory,
|
||||||
|
swarm_pair_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -21,3 +25,12 @@ async def swarm_pair(is_host_secure):
|
||||||
yield swarm_0, swarm_1
|
yield swarm_0, swarm_1
|
||||||
finally:
|
finally:
|
||||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def swarm_conn_pair(is_host_secure):
|
||||||
|
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure)
|
||||||
|
try:
|
||||||
|
yield conn_0, conn_1
|
||||||
|
finally:
|
||||||
|
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||||
|
|
43
tests/network/test_swarm_conn.py
Normal file
43
tests/network/test_swarm_conn.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swarm_conn_close(swarm_conn_pair):
|
||||||
|
conn_0, conn_1 = swarm_conn_pair
|
||||||
|
|
||||||
|
assert not conn_0.event_closed.is_set()
|
||||||
|
assert not conn_1.event_closed.is_set()
|
||||||
|
|
||||||
|
await conn_0.close()
|
||||||
|
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
assert conn_0.event_closed.is_set()
|
||||||
|
assert conn_1.event_closed.is_set()
|
||||||
|
assert conn_0 not in conn_0.swarm.connections.values()
|
||||||
|
assert conn_1 not in conn_1.swarm.connections.values()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swarm_conn_streams(swarm_conn_pair):
|
||||||
|
conn_0, conn_1 = swarm_conn_pair
|
||||||
|
|
||||||
|
assert len(await conn_0.get_streams()) == 0
|
||||||
|
assert len(await conn_1.get_streams()) == 0
|
||||||
|
|
||||||
|
stream_0_0 = await conn_0.new_stream()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert len(await conn_0.get_streams()) == 1
|
||||||
|
assert len(await conn_1.get_streams()) == 1
|
||||||
|
|
||||||
|
stream_0_1 = await conn_0.new_stream()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert len(await conn_0.get_streams()) == 2
|
||||||
|
assert len(await conn_1.get_streams()) == 2
|
||||||
|
|
||||||
|
conn_0.remove_stream(stream_0_0)
|
||||||
|
assert len(await conn_0.get_streams()) == 1
|
||||||
|
conn_0.remove_stream(stream_0_1)
|
||||||
|
assert len(await conn_0.get_streams()) == 0
|
Loading…
Reference in New Issue
Block a user