Add mplex tests and fix error in SwarmConn.close

This commit is contained in:
mhchia 2019-09-18 15:44:45 +08:00
parent d61327f5f9
commit a9ad37bc6f
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
14 changed files with 96 additions and 44 deletions

View File

@ -17,8 +17,8 @@ from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTr
import libp2p.security.secio.transport as secio
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.muxer_multistream import MuxerClassType
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerClass, TMuxerOptions, TSecurityOptions
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol
@ -74,8 +74,8 @@ def initialize_default_swarm(
key_pair: KeyPair,
id_opt: ID = None,
transport_opt: Sequence[str] = None,
muxer_opt: Mapping[TProtocol, MuxerClassType] = None,
sec_opt: Mapping[TProtocol, ISecureTransport] = None,
muxer_opt: TMuxerOptions = None,
sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None,
) -> Swarm:
@ -114,7 +114,7 @@ async def new_node(
key_pair: KeyPair = None,
swarm_opt: INetwork = None,
transport_opt: Sequence[str] = None,
muxer_opt: Mapping[TProtocol, MuxerClassType] = None,
muxer_opt: Mapping[TProtocol, TMuxerClass] = None,
sec_opt: Mapping[TProtocol, ISecureTransport] = None,
peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None,

View File

@ -141,4 +141,4 @@ class BasicHost(IHost):
MultiselectCommunicator(net_stream)
)
net_stream.set_protocol(protocol)
asyncio.ensure_future(handler(net_stream))
await handler(net_stream)

View File

@ -66,10 +66,19 @@ class SwarmConn(INetConn):
await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None:
try:
await self.swarm.common_stream_handler(net_stream)
# TODO: More exact exceptions
except Exception:
# TODO: Emit logs.
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
self.remove_stream(net_stream)
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = await self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None:
await self.run_task(self.swarm.common_stream_handler(net_stream))
await self.run_task(self._call_stream_handler(net_stream))
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream)
@ -97,4 +106,6 @@ class SwarmConn(INetConn):
# TODO: Called by `Stream` whenever it is time to remove the stream.
def remove_stream(self, stream: NetStream) -> None:
if stream not in self.streams:
return
self.streams.remove(stream)

View File

@ -1,6 +1,5 @@
from abc import ABC
from collections import OrderedDict
from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
@ -9,6 +8,7 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.transport.typing import TSecurityOptions
from libp2p.typing import TProtocol
@ -31,13 +31,12 @@ class SecurityMultistream(ABC):
multiselect: Multiselect
multiselect_client: MultiselectClient
def __init__(
self, secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport]
) -> None:
def __init__(self, secure_transports_by_protocol: TSecurityOptions = None) -> None:
self.transports = OrderedDict()
self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient()
if secure_transports_by_protocol is not None:
for protocol, transport in secure_transports_by_protocol.items():
self.add_transport(protocol, transport)

View File

@ -29,9 +29,6 @@ class Mplex(IMuxedConn):
secured_conn: ISecureConn
peer_id: ID
# TODO: `dataIn` in go implementation. Should be size of 8.
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
# to let the `MplexStream`s know that EOF arrived (#235).
next_channel_id: int
streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock

View File

@ -24,6 +24,7 @@ class MplexStream(IMuxedStream):
close_lock: asyncio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: "asyncio.Queue[bytes]"
event_local_closed: asyncio.Event

View File

@ -1,5 +1,4 @@
from collections import OrderedDict
from typing import Mapping, Type
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
@ -7,12 +6,11 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.transport.typing import TMuxerClass, TMuxerOptions
from libp2p.typing import TProtocol
from .abc import IMuxedConn
MuxerClassType = Type[IMuxedConn]
# FIXME: add negotiate timeout to `MuxerMultistream`
DEFAULT_NEGOTIATE_TIMEOUT = 60
@ -24,20 +22,19 @@ class MuxerMultistream:
"""
# NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2.
transports: "OrderedDict[TProtocol, MuxerClassType]"
transports: "OrderedDict[TProtocol, TMuxerClass]"
multiselect: Multiselect
multiselect_client: MultiselectClient
def __init__(
self, muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType]
) -> None:
def __init__(self, muxer_transports_by_protocol: TMuxerOptions = None) -> None:
self.transports = OrderedDict()
self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient()
if muxer_transports_by_protocol is not None:
for protocol, transport in muxer_transports_by_protocol.items():
self.add_transport(protocol, transport)
def add_transport(self, protocol: TProtocol, transport: MuxerClassType) -> None:
def add_transport(self, protocol: TProtocol, transport: TMuxerClass) -> None:
"""
Add a protocol and its corresponding transport to multistream-select(multiselect).
The order that a protocol is added is exactly the precedence it is negotiated in
@ -51,7 +48,7 @@ class MuxerMultistream:
self.transports[protocol] = transport
self.multiselect.add_handler(protocol, None)
async def select_transport(self, conn: IRawConnection) -> MuxerClassType:
async def select_transport(self, conn: IRawConnection) -> TMuxerClass:
"""
Select a transport that both us and the node on the
other end of conn support and agree on

View File

@ -1,4 +1,11 @@
from asyncio import StreamReader, StreamWriter
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Mapping, Type
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.typing import TProtocol
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = Type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass]

View File

@ -1,19 +1,16 @@
from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.security.security_multistream import SecurityMultistream
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream
from libp2p.stream_muxer.muxer_multistream import MuxerMultistream
from libp2p.transport.exceptions import (
HandshakeFailure,
MuxerUpgradeFailure,
SecurityUpgradeFailure,
)
from libp2p.typing import TProtocol
from libp2p.transport.typing import TMuxerOptions, TSecurityOptions
from .listener_interface import IListener
from .transport_interface import ITransport
@ -25,8 +22,8 @@ class TransportUpgrader:
def __init__(
self,
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport],
muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType],
secure_transports_by_protocol: TSecurityOptions,
muxer_transports_by_protocol: TMuxerOptions,
):
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)

View File

@ -15,6 +15,8 @@ from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.transport.typing import TMuxerOptions
from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import (
@ -34,10 +36,10 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)}
def SwarmFactory(is_secure: bool) -> Swarm:
def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm:
key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(False, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
sec_opt = security_transport_factory(is_secure, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt)
class ListeningSwarmFactory(factory.Factory):
@ -45,17 +47,22 @@ class ListeningSwarmFactory(factory.Factory):
model = Swarm
@classmethod
async def create_and_listen(cls, is_secure: bool) -> Swarm:
swarm = SwarmFactory(is_secure)
async def create_and_listen(
cls, is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Swarm:
swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt)
await swarm.listen(LISTEN_MADDR)
return swarm
@classmethod
async def create_batch_and_listen(
cls, is_secure: bool, number: int
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]:
return await asyncio.gather(
*[cls.create_and_listen(is_secure) for _ in range(number)]
*[
cls.create_and_listen(is_secure, muxer_opt=muxer_opt)
for _ in range(number)
]
)
@ -112,8 +119,12 @@ class PubsubFactory(factory.Factory):
cache_size = None
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2)
async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt
)
await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1]
@ -130,7 +141,7 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
async def swarm_conn_pair_factory(
is_secure
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
swarms = await swarm_pair_factory(is_secure)
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
@ -138,6 +149,14 @@ async def swarm_conn_pair_factory(
return conn_0, swarms[0], conn_1, swarms[1]
async def mplex_conn_pair_factory(is_secure):
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(
is_secure, muxer_opt=muxer_opt
)
return conn_0.conn, swarm_0, conn_1.conn, swarm_1
async def net_stream_pair_factory(
is_secure: bool
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:

View File

@ -41,3 +41,5 @@ async def test_swarm_conn_streams(swarm_conn_pair):
assert len(await conn_0.get_streams()) == 1
conn_0.remove_stream(stream_0_1)
assert len(await conn_0.get_streams()) == 0
# Nothing happen if `stream_0_1` is not present or already removed.
conn_0.remove_stream(stream_0_1)

View File

View File

@ -0,0 +1,16 @@
import asyncio
import pytest
from tests.factories import mplex_conn_pair_factory
@pytest.fixture
async def mplex_conn_pair(is_host_secure):
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_host_secure
)
try:
yield mplex_conn_0, mplex_conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -0,0 +1,6 @@
import pytest
@pytest.mark.asyncio
async def test_mplex_conn(mplex_conn_pair):
conn_0, conn_1 = mplex_conn_pair