Add tests for SwarmConn

This commit is contained in:
mhchia 2019-09-17 23:38:11 +08:00
parent b8b5ac5e06
commit d61327f5f9
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
5 changed files with 79 additions and 9 deletions

View File

@ -43,11 +43,15 @@ class SwarmConn(INetConn):
# We *could* optimize this but it really isn't worth it. # We *could* optimize this but it really isn't worth it.
for stream in self.streams: for stream in self.streams:
await stream.reset() await stream.reset()
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
asyncio.ensure_future(self._notify_disconnected())
for task in self._tasks: for task in self._tasks:
task.cancel() task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
asyncio.ensure_future(self._notify_disconnected())
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
while True: while True:
@ -70,7 +74,6 @@ class SwarmConn(INetConn):
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)
self.streams.add(net_stream) self.streams.add(net_stream)
# Call notifiers since event occurred
for notifee in self.swarm.notifees: for notifee in self.swarm.notifees:
await notifee.opened_stream(self.swarm, net_stream) await notifee.opened_stream(self.swarm, net_stream)
return net_stream return net_stream
@ -91,3 +94,7 @@ class SwarmConn(INetConn):
async def get_streams(self) -> Tuple[NetStream, ...]: async def get_streams(self) -> Tuple[NetStream, ...]:
return tuple(self.streams) return tuple(self.streams)
# TODO: Called by `Stream` whenever it is time to remove the stream.
def remove_stream(self, stream: NetStream) -> None:
self.streams.remove(stream)

View File

@ -66,3 +66,7 @@ class NetStream(INetStream):
async def reset(self) -> None: async def reset(self) -> None:
await self.muxed_stream.reset() await self.muxed_stream.reset()
# TODO: `remove`: Called by close and write when the stream is in specific states.
# It notify `ClosedStream` after `SwarmConn.remove_stream` is called.
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501

View File

@ -6,6 +6,7 @@ import factory
from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p import generate_new_rsa_identity, initialize_default_swarm
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
@ -128,11 +129,13 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
return hosts[0], hosts[1] return hosts[0], hosts[1]
# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: async def swarm_conn_pair_factory(
# host_0, host_1 = await host_pair_factory() is_secure
# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] ) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] swarms = await swarm_pair_factory(is_secure)
# return mplex_conn_0, host_0, mplex_conn_1, host_1 conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
return conn_0, swarms[0], conn_1, swarms[1]
async def net_stream_pair_factory( async def net_stream_pair_factory(

View File

@ -2,7 +2,11 @@ import asyncio
import pytest import pytest
from tests.factories import net_stream_pair_factory, swarm_pair_factory from tests.factories import (
net_stream_pair_factory,
swarm_conn_pair_factory,
swarm_pair_factory,
)
@pytest.fixture @pytest.fixture
@ -21,3 +25,12 @@ async def swarm_pair(is_host_secure):
yield swarm_0, swarm_1 yield swarm_0, swarm_1
finally: finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
@pytest.fixture
async def swarm_conn_pair(is_host_secure):
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure)
try:
yield conn_0, conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -0,0 +1,43 @@
import asyncio
import pytest
@pytest.mark.asyncio
async def test_swarm_conn_close(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
await conn_0.close()
await asyncio.sleep(0.01)
assert conn_0.event_closed.is_set()
assert conn_1.event_closed.is_set()
assert conn_0 not in conn_0.swarm.connections.values()
assert conn_1 not in conn_1.swarm.connections.values()
@pytest.mark.asyncio
async def test_swarm_conn_streams(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert len(await conn_0.get_streams()) == 0
assert len(await conn_1.get_streams()) == 0
stream_0_0 = await conn_0.new_stream()
await asyncio.sleep(0.01)
assert len(await conn_0.get_streams()) == 1
assert len(await conn_1.get_streams()) == 1
stream_0_1 = await conn_0.new_stream()
await asyncio.sleep(0.01)
assert len(await conn_0.get_streams()) == 2
assert len(await conn_1.get_streams()) == 2
conn_0.remove_stream(stream_0_0)
assert len(await conn_0.get_streams()) == 1
conn_0.remove_stream(stream_0_1)
assert len(await conn_0.get_streams()) == 0