Add mplex tests and fix error in SwarmConn.close

This commit is contained in:
mhchia 2019-09-18 15:44:45 +08:00
parent d61327f5f9
commit a9ad37bc6f
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
14 changed files with 96 additions and 44 deletions

View File

@ -17,8 +17,8 @@ from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTr
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.muxer_multistream import MuxerClassType
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerClass, TMuxerOptions, TSecurityOptions
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -74,8 +74,8 @@ def initialize_default_swarm(
key_pair: KeyPair, key_pair: KeyPair,
id_opt: ID = None, id_opt: ID = None,
transport_opt: Sequence[str] = None, transport_opt: Sequence[str] = None,
muxer_opt: Mapping[TProtocol, MuxerClassType] = None, muxer_opt: TMuxerOptions = None,
sec_opt: Mapping[TProtocol, ISecureTransport] = None, sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None, peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None, disc_opt: IPeerRouting = None,
) -> Swarm: ) -> Swarm:
@ -114,7 +114,7 @@ async def new_node(
key_pair: KeyPair = None, key_pair: KeyPair = None,
swarm_opt: INetwork = None, swarm_opt: INetwork = None,
transport_opt: Sequence[str] = None, transport_opt: Sequence[str] = None,
muxer_opt: Mapping[TProtocol, MuxerClassType] = None, muxer_opt: Mapping[TProtocol, TMuxerClass] = None,
sec_opt: Mapping[TProtocol, ISecureTransport] = None, sec_opt: Mapping[TProtocol, ISecureTransport] = None,
peerstore_opt: IPeerStore = None, peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None, disc_opt: IPeerRouting = None,

View File

@ -141,4 +141,4 @@ class BasicHost(IHost):
MultiselectCommunicator(net_stream) MultiselectCommunicator(net_stream)
) )
net_stream.set_protocol(protocol) net_stream.set_protocol(protocol)
asyncio.ensure_future(handler(net_stream)) await handler(net_stream)

View File

@ -66,10 +66,19 @@ class SwarmConn(INetConn):
await self.close() await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None:
try:
await self.swarm.common_stream_handler(net_stream)
# TODO: More exact exceptions
except Exception:
# TODO: Emit logs.
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
self.remove_stream(net_stream)
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = await self._add_stream(muxed_stream) net_stream = await self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None: if self.swarm.common_stream_handler is not None:
await self.run_task(self.swarm.common_stream_handler(net_stream)) await self.run_task(self._call_stream_handler(net_stream))
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)
@ -97,4 +106,6 @@ class SwarmConn(INetConn):
# TODO: Called by `Stream` whenever it is time to remove the stream. # TODO: Called by `Stream` whenever it is time to remove the stream.
def remove_stream(self, stream: NetStream) -> None: def remove_stream(self, stream: NetStream) -> None:
if stream not in self.streams:
return
self.streams.remove(stream) self.streams.remove(stream)

View File

@ -1,6 +1,5 @@
from abc import ABC from abc import ABC
from collections import OrderedDict from collections import OrderedDict
from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -9,6 +8,7 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.transport.typing import TSecurityOptions
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -31,15 +31,14 @@ class SecurityMultistream(ABC):
multiselect: Multiselect multiselect: Multiselect
multiselect_client: MultiselectClient multiselect_client: MultiselectClient
def __init__( def __init__(self, secure_transports_by_protocol: TSecurityOptions = None) -> None:
self, secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport]
) -> None:
self.transports = OrderedDict() self.transports = OrderedDict()
self.multiselect = Multiselect() self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient() self.multiselect_client = MultiselectClient()
for protocol, transport in secure_transports_by_protocol.items(): if secure_transports_by_protocol is not None:
self.add_transport(protocol, transport) for protocol, transport in secure_transports_by_protocol.items():
self.add_transport(protocol, transport)
def add_transport(self, protocol: TProtocol, transport: ISecureTransport) -> None: def add_transport(self, protocol: TProtocol, transport: ISecureTransport) -> None:
""" """

View File

@ -29,9 +29,6 @@ class Mplex(IMuxedConn):
secured_conn: ISecureConn secured_conn: ISecureConn
peer_id: ID peer_id: ID
# TODO: `dataIn` in go implementation. Should be size of 8.
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
# to let the `MplexStream`s know that EOF arrived (#235).
next_channel_id: int next_channel_id: int
streams: Dict[StreamID, MplexStream] streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock streams_lock: asyncio.Lock

View File

@ -24,6 +24,7 @@ class MplexStream(IMuxedStream):
close_lock: asyncio.Lock close_lock: asyncio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: "asyncio.Queue[bytes]" incoming_data: "asyncio.Queue[bytes]"
event_local_closed: asyncio.Event event_local_closed: asyncio.Event

View File

@ -1,5 +1,4 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Mapping, Type
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -7,12 +6,11 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.transport.typing import TMuxerClass, TMuxerOptions
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .abc import IMuxedConn from .abc import IMuxedConn
MuxerClassType = Type[IMuxedConn]
# FIXME: add negotiate timeout to `MuxerMultistream` # FIXME: add negotiate timeout to `MuxerMultistream`
DEFAULT_NEGOTIATE_TIMEOUT = 60 DEFAULT_NEGOTIATE_TIMEOUT = 60
@ -24,20 +22,19 @@ class MuxerMultistream:
""" """
# NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2. # NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2.
transports: "OrderedDict[TProtocol, MuxerClassType]" transports: "OrderedDict[TProtocol, TMuxerClass]"
multiselect: Multiselect multiselect: Multiselect
multiselect_client: MultiselectClient multiselect_client: MultiselectClient
def __init__( def __init__(self, muxer_transports_by_protocol: TMuxerOptions = None) -> None:
self, muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType]
) -> None:
self.transports = OrderedDict() self.transports = OrderedDict()
self.multiselect = Multiselect() self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient() self.multiselect_client = MultiselectClient()
for protocol, transport in muxer_transports_by_protocol.items(): if muxer_transports_by_protocol is not None:
self.add_transport(protocol, transport) for protocol, transport in muxer_transports_by_protocol.items():
self.add_transport(protocol, transport)
def add_transport(self, protocol: TProtocol, transport: MuxerClassType) -> None: def add_transport(self, protocol: TProtocol, transport: TMuxerClass) -> None:
""" """
Add a protocol and its corresponding transport to multistream-select(multiselect). Add a protocol and its corresponding transport to multistream-select(multiselect).
The order that a protocol is added is exactly the precedence it is negotiated in The order that a protocol is added is exactly the precedence it is negotiated in
@ -51,7 +48,7 @@ class MuxerMultistream:
self.transports[protocol] = transport self.transports[protocol] = transport
self.multiselect.add_handler(protocol, None) self.multiselect.add_handler(protocol, None)
async def select_transport(self, conn: IRawConnection) -> MuxerClassType: async def select_transport(self, conn: IRawConnection) -> TMuxerClass:
""" """
Select a transport that both us and the node on the Select a transport that both us and the node on the
other end of conn support and agree on other end of conn support and agree on

View File

@ -1,4 +1,11 @@
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from typing import Awaitable, Callable from typing import Awaitable, Callable, Mapping, Type
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.typing import TProtocol
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = Type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass]

View File

@ -1,19 +1,16 @@
from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.security.security_multistream import SecurityMultistream from libp2p.security.security_multistream import SecurityMultistream
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream from libp2p.stream_muxer.muxer_multistream import MuxerMultistream
from libp2p.transport.exceptions import ( from libp2p.transport.exceptions import (
HandshakeFailure, HandshakeFailure,
MuxerUpgradeFailure, MuxerUpgradeFailure,
SecurityUpgradeFailure, SecurityUpgradeFailure,
) )
from libp2p.typing import TProtocol from libp2p.transport.typing import TMuxerOptions, TSecurityOptions
from .listener_interface import IListener from .listener_interface import IListener
from .transport_interface import ITransport from .transport_interface import ITransport
@ -25,8 +22,8 @@ class TransportUpgrader:
def __init__( def __init__(
self, self,
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport], secure_transports_by_protocol: TSecurityOptions,
muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType], muxer_transports_by_protocol: TMuxerOptions,
): ):
self.security_multistream = SecurityMultistream(secure_transports_by_protocol) self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)

View File

@ -15,6 +15,8 @@ from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.transport.typing import TMuxerOptions
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import ( from tests.pubsub.configs import (
@ -34,10 +36,10 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)} return {secio.ID: secio.Transport(key_pair)}
def SwarmFactory(is_secure: bool) -> Swarm: def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm:
key_pair = generate_new_rsa_identity() key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(False, key_pair) sec_opt = security_transport_factory(is_secure, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt) return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt)
class ListeningSwarmFactory(factory.Factory): class ListeningSwarmFactory(factory.Factory):
@ -45,17 +47,22 @@ class ListeningSwarmFactory(factory.Factory):
model = Swarm model = Swarm
@classmethod @classmethod
async def create_and_listen(cls, is_secure: bool) -> Swarm: async def create_and_listen(
swarm = SwarmFactory(is_secure) cls, is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Swarm:
swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt)
await swarm.listen(LISTEN_MADDR) await swarm.listen(LISTEN_MADDR)
return swarm return swarm
@classmethod @classmethod
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]: ) -> Tuple[Swarm, ...]:
return await asyncio.gather( return await asyncio.gather(
*[cls.create_and_listen(is_secure) for _ in range(number)] *[
cls.create_and_listen(is_secure, muxer_opt=muxer_opt)
for _ in range(number)
]
) )
@ -112,8 +119,12 @@ class PubsubFactory(factory.Factory):
cache_size = None cache_size = None
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: async def swarm_pair_factory(
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2) is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt
)
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1] return swarms[0], swarms[1]
@ -130,7 +141,7 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
async def swarm_conn_pair_factory( async def swarm_conn_pair_factory(
is_secure is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: ) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
swarms = await swarm_pair_factory(is_secure) swarms = await swarm_pair_factory(is_secure)
conn_0 = swarms[0].connections[swarms[1].get_peer_id()] conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
@ -138,6 +149,14 @@ async def swarm_conn_pair_factory(
return conn_0, swarms[0], conn_1, swarms[1] return conn_0, swarms[0], conn_1, swarms[1]
async def mplex_conn_pair_factory(is_secure):
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(
is_secure, muxer_opt=muxer_opt
)
return conn_0.conn, swarm_0, conn_1.conn, swarm_1
async def net_stream_pair_factory( async def net_stream_pair_factory(
is_secure: bool is_secure: bool
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: ) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:

View File

@ -41,3 +41,5 @@ async def test_swarm_conn_streams(swarm_conn_pair):
assert len(await conn_0.get_streams()) == 1 assert len(await conn_0.get_streams()) == 1
conn_0.remove_stream(stream_0_1) conn_0.remove_stream(stream_0_1)
assert len(await conn_0.get_streams()) == 0 assert len(await conn_0.get_streams()) == 0
# Nothing happen if `stream_0_1` is not present or already removed.
conn_0.remove_stream(stream_0_1)

View File

View File

@ -0,0 +1,16 @@
import asyncio
import pytest
from tests.factories import mplex_conn_pair_factory
@pytest.fixture
async def mplex_conn_pair(is_host_secure):
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_host_secure
)
try:
yield mplex_conn_0, mplex_conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -0,0 +1,6 @@
import pytest
@pytest.mark.asyncio
async def test_mplex_conn(mplex_conn_pair):
conn_0, conn_1 = mplex_conn_pair