Add mplex tests and fix error in SwarmConn.close
This commit is contained in:
parent
d61327f5f9
commit
a9ad37bc6f
|
@ -17,8 +17,8 @@ from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTr
|
|||
import libp2p.security.secio.transport as secio
|
||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
|
||||
from libp2p.stream_muxer.muxer_multistream import MuxerClassType
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.typing import TMuxerClass, TMuxerOptions, TSecurityOptions
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
|
@ -74,8 +74,8 @@ def initialize_default_swarm(
|
|||
key_pair: KeyPair,
|
||||
id_opt: ID = None,
|
||||
transport_opt: Sequence[str] = None,
|
||||
muxer_opt: Mapping[TProtocol, MuxerClassType] = None,
|
||||
sec_opt: Mapping[TProtocol, ISecureTransport] = None,
|
||||
muxer_opt: TMuxerOptions = None,
|
||||
sec_opt: TSecurityOptions = None,
|
||||
peerstore_opt: IPeerStore = None,
|
||||
disc_opt: IPeerRouting = None,
|
||||
) -> Swarm:
|
||||
|
@ -114,7 +114,7 @@ async def new_node(
|
|||
key_pair: KeyPair = None,
|
||||
swarm_opt: INetwork = None,
|
||||
transport_opt: Sequence[str] = None,
|
||||
muxer_opt: Mapping[TProtocol, MuxerClassType] = None,
|
||||
muxer_opt: Mapping[TProtocol, TMuxerClass] = None,
|
||||
sec_opt: Mapping[TProtocol, ISecureTransport] = None,
|
||||
peerstore_opt: IPeerStore = None,
|
||||
disc_opt: IPeerRouting = None,
|
||||
|
|
|
@ -141,4 +141,4 @@ class BasicHost(IHost):
|
|||
MultiselectCommunicator(net_stream)
|
||||
)
|
||||
net_stream.set_protocol(protocol)
|
||||
asyncio.ensure_future(handler(net_stream))
|
||||
await handler(net_stream)
|
||||
|
|
|
@ -66,10 +66,19 @@ class SwarmConn(INetConn):
|
|||
|
||||
await self.close()
|
||||
|
||||
async def _call_stream_handler(self, net_stream: NetStream) -> None:
|
||||
try:
|
||||
await self.swarm.common_stream_handler(net_stream)
|
||||
# TODO: More exact exceptions
|
||||
except Exception:
|
||||
# TODO: Emit logs.
|
||||
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
|
||||
self.remove_stream(net_stream)
|
||||
|
||||
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
|
||||
net_stream = await self._add_stream(muxed_stream)
|
||||
if self.swarm.common_stream_handler is not None:
|
||||
await self.run_task(self.swarm.common_stream_handler(net_stream))
|
||||
await self.run_task(self._call_stream_handler(net_stream))
|
||||
|
||||
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
||||
net_stream = NetStream(muxed_stream)
|
||||
|
@ -97,4 +106,6 @@ class SwarmConn(INetConn):
|
|||
|
||||
# TODO: Called by `Stream` whenever it is time to remove the stream.
|
||||
def remove_stream(self, stream: NetStream) -> None:
|
||||
if stream not in self.streams:
|
||||
return
|
||||
self.streams.remove(stream)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from abc import ABC
|
||||
from collections import OrderedDict
|
||||
from typing import Mapping
|
||||
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
from libp2p.peer.id import ID
|
||||
|
@ -9,6 +8,7 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
|||
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
|
||||
from libp2p.security.secure_conn_interface import ISecureConn
|
||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||
from libp2p.transport.typing import TSecurityOptions
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
|
||||
|
@ -31,15 +31,14 @@ class SecurityMultistream(ABC):
|
|||
multiselect: Multiselect
|
||||
multiselect_client: MultiselectClient
|
||||
|
||||
def __init__(
|
||||
self, secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport]
|
||||
) -> None:
|
||||
def __init__(self, secure_transports_by_protocol: TSecurityOptions = None) -> None:
|
||||
self.transports = OrderedDict()
|
||||
self.multiselect = Multiselect()
|
||||
self.multiselect_client = MultiselectClient()
|
||||
|
||||
for protocol, transport in secure_transports_by_protocol.items():
|
||||
self.add_transport(protocol, transport)
|
||||
if secure_transports_by_protocol is not None:
|
||||
for protocol, transport in secure_transports_by_protocol.items():
|
||||
self.add_transport(protocol, transport)
|
||||
|
||||
def add_transport(self, protocol: TProtocol, transport: ISecureTransport) -> None:
|
||||
"""
|
||||
|
|
|
@ -29,9 +29,6 @@ class Mplex(IMuxedConn):
|
|||
|
||||
secured_conn: ISecureConn
|
||||
peer_id: ID
|
||||
# TODO: `dataIn` in go implementation. Should be size of 8.
|
||||
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
|
||||
# to let the `MplexStream`s know that EOF arrived (#235).
|
||||
next_channel_id: int
|
||||
streams: Dict[StreamID, MplexStream]
|
||||
streams_lock: asyncio.Lock
|
||||
|
|
|
@ -24,6 +24,7 @@ class MplexStream(IMuxedStream):
|
|||
|
||||
close_lock: asyncio.Lock
|
||||
|
||||
# NOTE: `dataIn` is size of 8 in Go implementation.
|
||||
incoming_data: "asyncio.Queue[bytes]"
|
||||
|
||||
event_local_closed: asyncio.Event
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from collections import OrderedDict
|
||||
from typing import Mapping, Type
|
||||
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
from libp2p.peer.id import ID
|
||||
|
@ -7,12 +6,11 @@ from libp2p.protocol_muxer.multiselect import Multiselect
|
|||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
|
||||
from libp2p.security.secure_conn_interface import ISecureConn
|
||||
from libp2p.transport.typing import TMuxerClass, TMuxerOptions
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
from .abc import IMuxedConn
|
||||
|
||||
MuxerClassType = Type[IMuxedConn]
|
||||
|
||||
# FIXME: add negotiate timeout to `MuxerMultistream`
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 60
|
||||
|
||||
|
@ -24,20 +22,19 @@ class MuxerMultistream:
|
|||
"""
|
||||
|
||||
# NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2.
|
||||
transports: "OrderedDict[TProtocol, MuxerClassType]"
|
||||
transports: "OrderedDict[TProtocol, TMuxerClass]"
|
||||
multiselect: Multiselect
|
||||
multiselect_client: MultiselectClient
|
||||
|
||||
def __init__(
|
||||
self, muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType]
|
||||
) -> None:
|
||||
def __init__(self, muxer_transports_by_protocol: TMuxerOptions = None) -> None:
|
||||
self.transports = OrderedDict()
|
||||
self.multiselect = Multiselect()
|
||||
self.multiselect_client = MultiselectClient()
|
||||
for protocol, transport in muxer_transports_by_protocol.items():
|
||||
self.add_transport(protocol, transport)
|
||||
if muxer_transports_by_protocol is not None:
|
||||
for protocol, transport in muxer_transports_by_protocol.items():
|
||||
self.add_transport(protocol, transport)
|
||||
|
||||
def add_transport(self, protocol: TProtocol, transport: MuxerClassType) -> None:
|
||||
def add_transport(self, protocol: TProtocol, transport: TMuxerClass) -> None:
|
||||
"""
|
||||
Add a protocol and its corresponding transport to multistream-select(multiselect).
|
||||
The order that a protocol is added is exactly the precedence it is negotiated in
|
||||
|
@ -51,7 +48,7 @@ class MuxerMultistream:
|
|||
self.transports[protocol] = transport
|
||||
self.multiselect.add_handler(protocol, None)
|
||||
|
||||
async def select_transport(self, conn: IRawConnection) -> MuxerClassType:
|
||||
async def select_transport(self, conn: IRawConnection) -> TMuxerClass:
|
||||
"""
|
||||
Select a transport that both us and the node on the
|
||||
other end of conn support and agree on
|
||||
|
|
|
@ -1,4 +1,11 @@
|
|||
from asyncio import StreamReader, StreamWriter
|
||||
from typing import Awaitable, Callable
|
||||
from typing import Awaitable, Callable, Mapping, Type
|
||||
|
||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||
from libp2p.stream_muxer.abc import IMuxedConn
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
||||
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
|
||||
TMuxerClass = Type[IMuxedConn]
|
||||
TMuxerOptions = Mapping[TProtocol, TMuxerClass]
|
||||
|
|
|
@ -1,19 +1,16 @@
|
|||
from typing import Mapping
|
||||
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
|
||||
from libp2p.security.secure_conn_interface import ISecureConn
|
||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||
from libp2p.security.security_multistream import SecurityMultistream
|
||||
from libp2p.stream_muxer.abc import IMuxedConn
|
||||
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream
|
||||
from libp2p.stream_muxer.muxer_multistream import MuxerMultistream
|
||||
from libp2p.transport.exceptions import (
|
||||
HandshakeFailure,
|
||||
MuxerUpgradeFailure,
|
||||
SecurityUpgradeFailure,
|
||||
)
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.transport.typing import TMuxerOptions, TSecurityOptions
|
||||
|
||||
from .listener_interface import IListener
|
||||
from .transport_interface import ITransport
|
||||
|
@ -25,8 +22,8 @@ class TransportUpgrader:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport],
|
||||
muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType],
|
||||
secure_transports_by_protocol: TSecurityOptions,
|
||||
muxer_transports_by_protocol: TMuxerOptions,
|
||||
):
|
||||
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
||||
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
|
||||
|
|
|
@ -15,6 +15,8 @@ from libp2p.pubsub.pubsub import Pubsub
|
|||
from libp2p.security.base_transport import BaseSecureTransport
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
import libp2p.security.secio.transport as secio
|
||||
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
|
||||
from libp2p.transport.typing import TMuxerOptions
|
||||
from libp2p.typing import TProtocol
|
||||
from tests.configs import LISTEN_MADDR
|
||||
from tests.pubsub.configs import (
|
||||
|
@ -34,10 +36,10 @@ def security_transport_factory(
|
|||
return {secio.ID: secio.Transport(key_pair)}
|
||||
|
||||
|
||||
def SwarmFactory(is_secure: bool) -> Swarm:
|
||||
def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm:
|
||||
key_pair = generate_new_rsa_identity()
|
||||
sec_opt = security_transport_factory(False, key_pair)
|
||||
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
|
||||
sec_opt = security_transport_factory(is_secure, key_pair)
|
||||
return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt)
|
||||
|
||||
|
||||
class ListeningSwarmFactory(factory.Factory):
|
||||
|
@ -45,17 +47,22 @@ class ListeningSwarmFactory(factory.Factory):
|
|||
model = Swarm
|
||||
|
||||
@classmethod
|
||||
async def create_and_listen(cls, is_secure: bool) -> Swarm:
|
||||
swarm = SwarmFactory(is_secure)
|
||||
async def create_and_listen(
|
||||
cls, is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||
) -> Swarm:
|
||||
swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt)
|
||||
await swarm.listen(LISTEN_MADDR)
|
||||
return swarm
|
||||
|
||||
@classmethod
|
||||
async def create_batch_and_listen(
|
||||
cls, is_secure: bool, number: int
|
||||
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
|
||||
) -> Tuple[Swarm, ...]:
|
||||
return await asyncio.gather(
|
||||
*[cls.create_and_listen(is_secure) for _ in range(number)]
|
||||
*[
|
||||
cls.create_and_listen(is_secure, muxer_opt=muxer_opt)
|
||||
for _ in range(number)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -112,8 +119,12 @@ class PubsubFactory(factory.Factory):
|
|||
cache_size = None
|
||||
|
||||
|
||||
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
|
||||
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2)
|
||||
async def swarm_pair_factory(
|
||||
is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||
) -> Tuple[Swarm, Swarm]:
|
||||
swarms = await ListeningSwarmFactory.create_batch_and_listen(
|
||||
is_secure, 2, muxer_opt=muxer_opt
|
||||
)
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
return swarms[0], swarms[1]
|
||||
|
||||
|
@ -130,7 +141,7 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
|
|||
|
||||
|
||||
async def swarm_conn_pair_factory(
|
||||
is_secure
|
||||
is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
|
||||
swarms = await swarm_pair_factory(is_secure)
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
|
@ -138,6 +149,14 @@ async def swarm_conn_pair_factory(
|
|||
return conn_0, swarms[0], conn_1, swarms[1]
|
||||
|
||||
|
||||
async def mplex_conn_pair_factory(is_secure):
|
||||
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
|
||||
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(
|
||||
is_secure, muxer_opt=muxer_opt
|
||||
)
|
||||
return conn_0.conn, swarm_0, conn_1.conn, swarm_1
|
||||
|
||||
|
||||
async def net_stream_pair_factory(
|
||||
is_secure: bool
|
||||
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
|
||||
|
|
|
@ -41,3 +41,5 @@ async def test_swarm_conn_streams(swarm_conn_pair):
|
|||
assert len(await conn_0.get_streams()) == 1
|
||||
conn_0.remove_stream(stream_0_1)
|
||||
assert len(await conn_0.get_streams()) == 0
|
||||
# Nothing happen if `stream_0_1` is not present or already removed.
|
||||
conn_0.remove_stream(stream_0_1)
|
||||
|
|
0
tests/stream_muxer/__init__.py
Normal file
0
tests/stream_muxer/__init__.py
Normal file
16
tests/stream_muxer/conftest.py
Normal file
16
tests/stream_muxer/conftest.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.factories import mplex_conn_pair_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mplex_conn_pair(is_host_secure):
|
||||
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
|
||||
is_host_secure
|
||||
)
|
||||
try:
|
||||
yield mplex_conn_0, mplex_conn_1
|
||||
finally:
|
||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
6
tests/stream_muxer/test_mplex_conn.py
Normal file
6
tests/stream_muxer/test_mplex_conn.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mplex_conn(mplex_conn_pair):
|
||||
conn_0, conn_1 = mplex_conn_pair
|
Loading…
Reference in New Issue
Block a user