Noise: add noise option in the factories and tests

This commit is contained in:
mhchia 2020-02-19 23:15:03 +08:00
parent 1d2a976597
commit 13e8f496a7
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
21 changed files with 331 additions and 212 deletions

View File

@ -27,7 +27,6 @@ class SecurityMultistream(ABC):
Go implementation: github.com/libp2p/go-conn-security-multistream/ssms.go Go implementation: github.com/libp2p/go-conn-security-multistream/ssms.go
""" """
# NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2.
transports: "OrderedDict[TProtocol, ISecureTransport]" transports: "OrderedDict[TProtocol, ISecureTransport]"
multiselect: Multiselect multiselect: Multiselect
multiselect_client: MultiselectClient multiselect_client: MultiselectClient

View File

@ -1,4 +1,4 @@
from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple, cast from typing import Any, AsyncIterator, Callable, Dict, List, Sequence, Tuple, cast
from async_exit_stack import AsyncExitStack from async_exit_stack import AsyncExitStack
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
@ -33,6 +33,7 @@ from libp2p.security.noise.messages import (
NoiseHandshakePayload, NoiseHandshakePayload,
make_handshake_payload_sig, make_handshake_payload_sig,
) )
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport from libp2p.security.noise.transport import Transport as NoiseTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
@ -41,20 +42,26 @@ from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerOptions from libp2p.transport.typing import TMuxerOptions, TSecurityOptions
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR
from .utils import connect, connect_swarm from .utils import connect, connect_swarm
DEFAULT_SECURITY_PROTOCOL_ID = PLAINTEXT_PROTOCOL_ID
def default_key_pair_factory() -> KeyPair:
return generate_new_rsa_identity()
class IDFactory(factory.Factory): class IDFactory(factory.Factory):
class Meta: class Meta:
model = ID model = ID
peer_id_bytes = factory.LazyFunction( peer_id_bytes = factory.LazyFunction(
lambda: generate_peer_id_from(generate_new_rsa_identity()) lambda: generate_peer_id_from(default_key_pair_factory())
) )
@ -64,15 +71,6 @@ def initialize_peerstore_with_our_keypair(self_id: ID, key_pair: KeyPair) -> Pee
return peer_store return peer_store
def security_transport_factory(
is_secure: bool, key_pair: KeyPair
) -> Dict[TProtocol, ISecureTransport]:
if not is_secure:
return {PLAINTEXT_PROTOCOL_ID: InsecureTransport(key_pair)}
else:
return {secio.ID: secio.Transport(key_pair)}
def noise_static_key_factory() -> PrivateKey: def noise_static_key_factory() -> PrivateKey:
return create_ed25519_key_pair().private_key return create_ed25519_key_pair().private_key
@ -88,15 +86,52 @@ def noise_handshake_payload_factory() -> NoiseHandshakePayload:
) )
def noise_transport_factory() -> NoiseTransport: def plaintext_transport_factory(key_pair: KeyPair) -> ISecureTransport:
return InsecureTransport(key_pair)
def secio_transport_factory(key_pair: KeyPair) -> ISecureTransport:
return secio.Transport(key_pair)
def noise_transport_factory(key_pair: KeyPair) -> ISecureTransport:
return NoiseTransport( return NoiseTransport(
libp2p_keypair=create_secp256k1_key_pair(), libp2p_keypair=key_pair,
noise_privkey=noise_static_key_factory(), noise_privkey=noise_static_key_factory(),
early_data=None, early_data=None,
with_noise_pipes=False, with_noise_pipes=False,
) )
def security_options_factory_factory(
protocol_id: TProtocol = None
) -> Callable[[KeyPair], TSecurityOptions]:
if protocol_id is None:
protocol_id = DEFAULT_SECURITY_PROTOCOL_ID
def security_options_factory(key_pair: KeyPair) -> TSecurityOptions:
transport_factory: Callable[[KeyPair], ISecureTransport]
if protocol_id == PLAINTEXT_PROTOCOL_ID:
transport_factory = plaintext_transport_factory
elif protocol_id == secio.ID:
transport_factory = secio_transport_factory
elif protocol_id == NOISE_PROTOCOL_ID:
transport_factory = noise_transport_factory
else:
raise Exception(f"security transport {protocol_id} is not supported")
return {protocol_id: transport_factory(key_pair)}
return security_options_factory
def mplex_transport_factory() -> TMuxerOptions:
return {MPLEX_PROTOCOL_ID: Mplex}
def default_muxer_transport_factory() -> TMuxerOptions:
return mplex_transport_factory()
@asynccontextmanager @asynccontextmanager
async def raw_conn_factory( async def raw_conn_factory(
nursery: trio.Nursery nursery: trio.Nursery
@ -124,8 +159,12 @@ async def raw_conn_factory(
async def noise_conn_factory( async def noise_conn_factory(
nursery: trio.Nursery nursery: trio.Nursery
) -> AsyncIterator[Tuple[ISecureConn, ISecureConn]]: ) -> AsyncIterator[Tuple[ISecureConn, ISecureConn]]:
local_transport = noise_transport_factory() local_transport = cast(
remote_transport = noise_transport_factory() NoiseTransport, noise_transport_factory(create_secp256k1_key_pair())
)
remote_transport = cast(
NoiseTransport, noise_transport_factory(create_secp256k1_key_pair())
)
local_secure_conn: ISecureConn = None local_secure_conn: ISecureConn = None
remote_secure_conn: ISecureConn = None remote_secure_conn: ISecureConn = None
@ -158,9 +197,9 @@ class SwarmFactory(factory.Factory):
model = Swarm model = Swarm
class Params: class Params:
is_secure = False key_pair = factory.LazyFunction(default_key_pair_factory)
key_pair = factory.LazyFunction(generate_new_rsa_identity) security_protocol = DEFAULT_SECURITY_PROTOCOL_ID
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} muxer_opt = factory.LazyFunction(default_muxer_transport_factory)
peer_id = factory.LazyAttribute(lambda o: generate_peer_id_from(o.key_pair)) peer_id = factory.LazyAttribute(lambda o: generate_peer_id_from(o.key_pair))
peerstore = factory.LazyAttribute( peerstore = factory.LazyAttribute(
@ -168,7 +207,8 @@ class SwarmFactory(factory.Factory):
) )
upgrader = factory.LazyAttribute( upgrader = factory.LazyAttribute(
lambda o: TransportUpgrader( lambda o: TransportUpgrader(
security_transport_factory(o.is_secure, o.key_pair), o.muxer_opt (security_options_factory_factory(o.security_protocol))(o.key_pair),
o.muxer_opt,
) )
) )
transport = factory.LazyFunction(TCP) transport = factory.LazyFunction(TCP)
@ -176,7 +216,10 @@ class SwarmFactory(factory.Factory):
@classmethod @classmethod
@asynccontextmanager @asynccontextmanager
async def create_and_listen( async def create_and_listen(
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None cls,
key_pair: KeyPair = None,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
) -> AsyncIterator[Swarm]: ) -> AsyncIterator[Swarm]:
# `factory.Factory.__init__` does *not* prepare a *default value* if we pass # `factory.Factory.__init__` does *not* prepare a *default value* if we pass
# an argument explicitly with `None`. If an argument is `None`, we don't pass it to # an argument explicitly with `None`. If an argument is `None`, we don't pass it to
@ -184,9 +227,11 @@ class SwarmFactory(factory.Factory):
optional_kwargs: Dict[str, Any] = {} optional_kwargs: Dict[str, Any] = {}
if key_pair is not None: if key_pair is not None:
optional_kwargs["key_pair"] = key_pair optional_kwargs["key_pair"] = key_pair
if security_protocol is not None:
optional_kwargs["security_protocol"] = security_protocol
if muxer_opt is not None: if muxer_opt is not None:
optional_kwargs["muxer_opt"] = muxer_opt optional_kwargs["muxer_opt"] = muxer_opt
swarm = cls(is_secure=is_secure, **optional_kwargs) swarm = cls(**optional_kwargs)
async with background_trio_service(swarm): async with background_trio_service(swarm):
await swarm.listen(LISTEN_MADDR) await swarm.listen(LISTEN_MADDR)
yield swarm yield swarm
@ -194,12 +239,17 @@ class SwarmFactory(factory.Factory):
@classmethod @classmethod
@asynccontextmanager @asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None cls,
number: int,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
) -> AsyncIterator[Tuple[Swarm, ...]]: ) -> AsyncIterator[Tuple[Swarm, ...]]:
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
ctx_mgrs = [ ctx_mgrs = [
await stack.enter_async_context( await stack.enter_async_context(
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) cls.create_and_listen(
security_protocol=security_protocol, muxer_opt=muxer_opt
)
) )
for _ in range(number) for _ in range(number)
] ]
@ -211,17 +261,27 @@ class HostFactory(factory.Factory):
model = BasicHost model = BasicHost
class Params: class Params:
is_secure = False key_pair = factory.LazyFunction(default_key_pair_factory)
key_pair = factory.LazyFunction(generate_new_rsa_identity) security_protocol: TProtocol = None
muxer_opt = factory.LazyFunction(default_muxer_transport_factory)
network = factory.LazyAttribute(lambda o: SwarmFactory(is_secure=o.is_secure)) network = factory.LazyAttribute(
lambda o: SwarmFactory(
security_protocol=o.security_protocol, muxer_opt=o.muxer_opt
)
)
@classmethod @classmethod
@asynccontextmanager @asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int cls,
number: int,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
) -> AsyncIterator[Tuple[BasicHost, ...]]: ) -> AsyncIterator[Tuple[BasicHost, ...]]:
async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms: async with SwarmFactory.create_batch_and_listen(
number, security_protocol=security_protocol, muxer_opt=muxer_opt
) as swarms:
hosts = tuple(BasicHost(swarm) for swarm in swarms) hosts = tuple(BasicHost(swarm) for swarm in swarms)
yield hosts yield hosts
@ -245,20 +305,29 @@ class RoutedHostFactory(factory.Factory):
model = RoutedHost model = RoutedHost
class Params: class Params:
is_secure = False key_pair = factory.LazyFunction(default_key_pair_factory)
security_protocol: TProtocol = None
muxer_opt = factory.LazyFunction(default_muxer_transport_factory)
network = factory.LazyAttribute( network = factory.LazyAttribute(
lambda o: HostFactory(is_secure=o.is_secure).get_network() lambda o: HostFactory(
security_protocol=o.security_protocol, muxer_opt=o.muxer_opt
).get_network()
) )
router = factory.LazyFunction(DummyRouter) router = factory.LazyFunction(DummyRouter)
@classmethod @classmethod
@asynccontextmanager @asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int cls,
number: int,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
) -> AsyncIterator[Tuple[RoutedHost, ...]]: ) -> AsyncIterator[Tuple[RoutedHost, ...]]:
routing_table = DummyRouter() routing_table = DummyRouter()
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: async with HostFactory.create_batch_and_listen(
number, security_protocol=security_protocol, muxer_opt=muxer_opt
) as hosts:
for host in hosts: for host in hosts:
routing_table._add_peer(host.get_id(), host.get_addrs()) routing_table._add_peer(host.get_id(), host.get_addrs())
routed_hosts = tuple( routed_hosts = tuple(
@ -319,11 +388,14 @@ class PubsubFactory(factory.Factory):
cls, cls,
number: int, number: int,
routers: Sequence[IPubsubRouter], routers: Sequence[IPubsubRouter],
is_secure: bool = False,
cache_size: int = None, cache_size: int = None,
strict_signing: bool = False, strict_signing: bool = False,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]: ) -> AsyncIterator[Tuple[Pubsub, ...]]:
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: async with HostFactory.create_batch_and_listen(
number, security_protocol=security_protocol, muxer_opt=muxer_opt
) as hosts:
# Pubsubs should exit before hosts # Pubsubs should exit before hosts
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
pubsubs = [ pubsubs = [
@ -339,17 +411,23 @@ class PubsubFactory(factory.Factory):
async def create_batch_with_floodsub( async def create_batch_with_floodsub(
cls, cls,
number: int, number: int,
is_secure: bool = False,
cache_size: int = None, cache_size: int = None,
strict_signing: bool = False, strict_signing: bool = False,
protocols: Sequence[TProtocol] = None, protocols: Sequence[TProtocol] = None,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]: ) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None: if protocols is not None:
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols)) floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
else: else:
floodsubs = FloodsubFactory.create_batch(number) floodsubs = FloodsubFactory.create_batch(number)
async with cls._create_batch_with_router( async with cls._create_batch_with_router(
number, floodsubs, is_secure, cache_size, strict_signing number,
floodsubs,
cache_size,
strict_signing,
security_protocol=security_protocol,
muxer_opt=muxer_opt,
) as pubsubs: ) as pubsubs:
yield pubsubs yield pubsubs
@ -359,7 +437,6 @@ class PubsubFactory(factory.Factory):
cls, cls,
number: int, number: int,
*, *,
is_secure: bool = False,
cache_size: int = None, cache_size: int = None,
strict_signing: bool = False, strict_signing: bool = False,
protocols: Sequence[TProtocol] = None, protocols: Sequence[TProtocol] = None,
@ -371,6 +448,8 @@ class PubsubFactory(factory.Factory):
gossip_history: int = GOSSIPSUB_PARAMS.gossip_history, gossip_history: int = GOSSIPSUB_PARAMS.gossip_history,
heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval, heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval,
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay, heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]: ) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None: if protocols is not None:
gossipsubs = GossipsubFactory.create_batch( gossipsubs = GossipsubFactory.create_batch(
@ -395,7 +474,12 @@ class PubsubFactory(factory.Factory):
) )
async with cls._create_batch_with_router( async with cls._create_batch_with_router(
number, gossipsubs, is_secure, cache_size, strict_signing number,
gossipsubs,
cache_size,
strict_signing,
security_protocol=security_protocol,
muxer_opt=muxer_opt,
) as pubsubs: ) as pubsubs:
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
for router in gossipsubs: for router in gossipsubs:
@ -405,10 +489,10 @@ class PubsubFactory(factory.Factory):
@asynccontextmanager @asynccontextmanager
async def swarm_pair_factory( async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None
) -> AsyncIterator[Tuple[Swarm, Swarm]]: ) -> AsyncIterator[Tuple[Swarm, Swarm]]:
async with SwarmFactory.create_batch_and_listen( async with SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt 2, security_protocol=security_protocol, muxer_opt=muxer_opt
) as swarms: ) as swarms:
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
yield swarms[0], swarms[1] yield swarms[0], swarms[1]
@ -416,18 +500,22 @@ async def swarm_pair_factory(
@asynccontextmanager @asynccontextmanager
async def host_pair_factory( async def host_pair_factory(
is_secure: bool security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: ) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol, muxer_opt=muxer_opt
) as hosts:
await connect(hosts[0], hosts[1]) await connect(hosts[0], hosts[1])
yield hosts[0], hosts[1] yield hosts[0], hosts[1]
@asynccontextmanager @asynccontextmanager
async def swarm_conn_pair_factory( async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None
) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]: ) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]:
async with swarm_pair_factory(is_secure) as swarms: async with swarm_pair_factory(
security_protocol=security_protocol, muxer_opt=muxer_opt
) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()] conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()] conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1) yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
@ -435,10 +523,11 @@ async def swarm_conn_pair_factory(
@asynccontextmanager @asynccontextmanager
async def mplex_conn_pair_factory( async def mplex_conn_pair_factory(
is_secure: bool security_protocol: TProtocol = None
) -> AsyncIterator[Tuple[Mplex, Mplex]]: ) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} async with swarm_conn_pair_factory(
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: security_protocol=security_protocol, muxer_opt=default_muxer_transport_factory()
) as swarm_pair:
yield ( yield (
cast(Mplex, swarm_pair[0].muxed_conn), cast(Mplex, swarm_pair[0].muxed_conn),
cast(Mplex, swarm_pair[1].muxed_conn), cast(Mplex, swarm_pair[1].muxed_conn),
@ -447,9 +536,11 @@ async def mplex_conn_pair_factory(
@asynccontextmanager @asynccontextmanager
async def mplex_stream_pair_factory( async def mplex_stream_pair_factory(
is_secure: bool security_protocol: TProtocol = None
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]: ) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: async with mplex_conn_pair_factory(
security_protocol=security_protocol
) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
stream_0 = cast(MplexStream, await mplex_conn_0.open_stream()) stream_0 = cast(MplexStream, await mplex_conn_0.open_stream())
await trio.sleep(0.01) await trio.sleep(0.01)
@ -463,7 +554,7 @@ async def mplex_stream_pair_factory(
@asynccontextmanager @asynccontextmanager
async def net_stream_pair_factory( async def net_stream_pair_factory(
is_secure: bool security_protocol: TProtocol = None, muxer_opt: TMuxerOptions = None
) -> AsyncIterator[Tuple[INetStream, INetStream]]: ) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1") protocol_id = TProtocol("/example/id/1")
@ -478,7 +569,9 @@ async def net_stream_pair_factory(
stream_1 = stream stream_1 = stream
await event_handler_finished.wait() await event_handler_finished.wait()
async with host_pair_factory(is_secure) as hosts: async with host_pair_factory(
security_protocol=security_protocol, muxer_opt=muxer_opt
) as hosts:
hosts[1].set_stream_handler(protocol_id, handler) hosts[1].set_stream_handler(protocol_id, handler)
stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id]) stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id])

View File

@ -8,6 +8,8 @@ import trio
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID
from libp2p.typing import TProtocol
from .constants import LOCALHOST_IP from .constants import LOCALHOST_IP
from .envs import GO_BIN_PATH from .envs import GO_BIN_PATH
@ -20,7 +22,7 @@ class P2PDProcess(BaseInteractiveProcess):
def __init__( def __init__(
self, self,
control_maddr: Multiaddr, control_maddr: Multiaddr,
is_secure: bool, security_protocol: TProtocol,
is_pubsub_enabled: bool = True, is_pubsub_enabled: bool = True,
is_gossipsub: bool = True, is_gossipsub: bool = True,
is_pubsub_signing: bool = False, is_pubsub_signing: bool = False,
@ -28,7 +30,7 @@ class P2PDProcess(BaseInteractiveProcess):
) -> None: ) -> None:
args = [f"-listen={control_maddr!s}"] args = [f"-listen={control_maddr!s}"]
# NOTE: To support `-insecure`, we need to hack `go-libp2p-daemon`. # NOTE: To support `-insecure`, we need to hack `go-libp2p-daemon`.
if not is_secure: if security_protocol == PLAINTEXT_PROTOCOL_ID:
args.append("-insecure=true") args.append("-insecure=true")
if is_pubsub_enabled: if is_pubsub_enabled:
args.append("-pubsub") args.append("-pubsub")
@ -85,7 +87,7 @@ class Daemon:
async def make_p2pd( async def make_p2pd(
daemon_control_port: int, daemon_control_port: int,
client_callback_port: int, client_callback_port: int,
is_secure: bool, security_protocol: TProtocol,
is_pubsub_enabled: bool = True, is_pubsub_enabled: bool = True,
is_gossipsub: bool = True, is_gossipsub: bool = True,
is_pubsub_signing: bool = False, is_pubsub_signing: bool = False,
@ -94,7 +96,7 @@ async def make_p2pd(
control_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{daemon_control_port}") control_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{daemon_control_port}")
p2pd_proc = P2PDProcess( p2pd_proc = P2PDProcess(
control_maddr, control_maddr,
is_secure, security_protocol,
is_pubsub_enabled, is_pubsub_enabled,
is_gossipsub, is_gossipsub,
is_pubsub_signing, is_pubsub_signing,

View File

@ -4,8 +4,8 @@ from libp2p.tools.factories import HostFactory
@pytest.fixture @pytest.fixture
def is_host_secure(): def security_protocol():
return False return None
@pytest.fixture @pytest.fixture
@ -14,6 +14,8 @@ def num_hosts():
@pytest.fixture @pytest.fixture
async def hosts(num_hosts, is_host_secure, nursery): async def hosts(num_hosts, security_protocol, nursery):
async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts: async with HostFactory.create_batch_and_listen(
num_hosts, security_protocol=security_protocol
) as _hosts:
yield _hosts yield _hosts

View File

@ -92,8 +92,11 @@ async def no_common_protocol(host_a, host_b):
"test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)] "test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)]
) )
@pytest.mark.trio @pytest.mark.trio
async def test_chat(test, is_host_secure): async def test_chat(test, security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: print("!@# ", security_protocol)
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
addr = hosts[0].get_addrs()[0] addr = hosts[0].get_addrs()[0]
info = info_from_p2p_addr(addr) info = info_from_p2p_addr(addr)
await hosts[1].connect(info) await hosts[1].connect(info)

View File

@ -8,8 +8,11 @@ from libp2p.tools.factories import host_pair_factory
@pytest.mark.trio @pytest.mark.trio
async def test_ping_once(is_host_secure): async def test_ping_once(security_protocol):
async with host_pair_factory(is_host_secure) as (host_a, host_b): async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
some_ping = secrets.token_bytes(PING_LENGTH) some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping) await stream.write(some_ping)
@ -23,8 +26,11 @@ SOME_PING_COUNT = 3
@pytest.mark.trio @pytest.mark.trio
async def test_ping_several(is_host_secure): async def test_ping_several(security_protocol):
async with host_pair_factory(is_host_secure) as (host_a, host_b): async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
for _ in range(SOME_PING_COUNT): for _ in range(SOME_PING_COUNT):
some_ping = secrets.token_bytes(PING_LENGTH) some_ping = secrets.token_bytes(PING_LENGTH)

View File

@ -7,7 +7,7 @@ from libp2p.tools.factories import HostFactory, RoutedHostFactory
@pytest.mark.trio @pytest.mark.trio
async def test_host_routing_success(): async def test_host_routing_success():
async with RoutedHostFactory.create_batch_and_listen(False, 2) as hosts: async with RoutedHostFactory.create_batch_and_listen(2) as hosts:
# forces to use routing as no addrs are provided # forces to use routing as no addrs are provided
await hosts[0].connect(PeerInfo(hosts[1].get_id(), [])) await hosts[0].connect(PeerInfo(hosts[1].get_id(), []))
await hosts[1].connect(PeerInfo(hosts[0].get_id(), [])) await hosts[1].connect(PeerInfo(hosts[0].get_id(), []))
@ -15,10 +15,9 @@ async def test_host_routing_success():
@pytest.mark.trio @pytest.mark.trio
async def test_host_routing_fail(): async def test_host_routing_fail():
is_secure = False
async with RoutedHostFactory.create_batch_and_listen( async with RoutedHostFactory.create_batch_and_listen(
is_secure, 2 2
) as routed_hosts, HostFactory.create_batch_and_listen(is_secure, 1) as basic_hosts: ) as routed_hosts, HostFactory.create_batch_and_listen(1) as basic_hosts:
# routing fails because host_c does not use routing # routing fails because host_c does not use routing
with pytest.raises(ConnectionFailure): with pytest.raises(ConnectionFailure):
await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), [])) await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), []))

View File

@ -6,8 +6,11 @@ from libp2p.tools.factories import host_pair_factory
@pytest.mark.trio @pytest.mark.trio
async def test_identify_protocol(is_host_secure): async def test_identify_protocol(security_protocol):
async with host_pair_factory(is_host_secure) as (host_a, host_b): async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read() response = await stream.read()
await stream.close() await stream.close()

View File

@ -19,8 +19,10 @@ ACK_STR_3 = "ack_3:"
@pytest.mark.trio @pytest.mark.trio
async def test_simple_messages(is_host_secure): async def test_simple_messages(security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
hosts[1].set_stream_handler( hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
) )
@ -38,8 +40,10 @@ async def test_simple_messages(is_host_secure):
@pytest.mark.trio @pytest.mark.trio
async def test_double_response(is_host_secure): async def test_double_response(security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
async def double_response_stream_handler(stream): async def double_response_stream_handler(stream):
while True: while True:
@ -78,11 +82,13 @@ async def test_double_response(is_host_secure):
@pytest.mark.trio @pytest.mark.trio
async def test_multiple_streams(is_host_secure): async def test_multiple_streams(security_protocol):
# hosts[0] should be able to open a stream with hosts[1] and then vice versa. # hosts[0] should be able to open a stream with hosts[1] and then vice versa.
# Stream IDs should be generated uniquely so that the stream state is not overwritten # Stream IDs should be generated uniquely so that the stream state is not overwritten
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
hosts[0].set_stream_handler( hosts[0].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
) )
@ -115,8 +121,10 @@ async def test_multiple_streams(is_host_secure):
@pytest.mark.trio @pytest.mark.trio
async def test_multiple_streams_same_initiator_different_protocols(is_host_secure): async def test_multiple_streams_same_initiator_different_protocols(security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
hosts[1].set_stream_handler( hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
@ -161,8 +169,10 @@ async def test_multiple_streams_same_initiator_different_protocols(is_host_secur
@pytest.mark.trio @pytest.mark.trio
async def test_multiple_streams_two_initiators(is_host_secure): async def test_multiple_streams_two_initiators(security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
hosts[0].set_stream_handler( hosts[0].set_stream_handler(
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2) PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
) )
@ -217,8 +227,10 @@ async def test_multiple_streams_two_initiators(is_host_secure):
@pytest.mark.trio @pytest.mark.trio
async def test_triangle_nodes_connection(is_host_secure): async def test_triangle_nodes_connection(security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 3) as hosts: async with HostFactory.create_batch_and_listen(
3, security_protocol=security_protocol
) as hosts:
hosts[0].set_stream_handler( hosts[0].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0) PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
@ -268,8 +280,10 @@ async def test_triangle_nodes_connection(is_host_secure):
@pytest.mark.trio @pytest.mark.trio
async def test_host_connect(is_host_secure): async def test_host_connect(security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
assert len(hosts[0].get_peerstore().peer_ids()) == 1 assert len(hosts[0].get_peerstore().peer_ids()) == 1
await connect(hosts[0], hosts[1]) await connect(hosts[0], hosts[1])

View File

@ -8,18 +8,22 @@ from libp2p.tools.factories import (
@pytest.fixture @pytest.fixture
async def net_stream_pair(is_host_secure): async def net_stream_pair(security_protocol):
async with net_stream_pair_factory(is_host_secure) as net_stream_pair: async with net_stream_pair_factory(
security_protocol=security_protocol
) as net_stream_pair:
yield net_stream_pair yield net_stream_pair
@pytest.fixture @pytest.fixture
async def swarm_pair(is_host_secure): async def swarm_pair(security_protocol):
async with swarm_pair_factory(is_host_secure) as swarms: async with swarm_pair_factory(security_protocol=security_protocol) as swarms:
yield swarms yield swarms
@pytest.fixture @pytest.fixture
async def swarm_conn_pair(is_host_secure): async def swarm_conn_pair(security_protocol):
async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair: async with swarm_conn_pair_factory(
security_protocol=security_protocol
) as swarm_conn_pair:
yield swarm_conn_pair yield swarm_conn_pair

View File

@ -55,8 +55,8 @@ class MyNotifee(INotifee):
@pytest.mark.trio @pytest.mark.trio
async def test_notify(is_host_secure): async def test_notify(security_protocol):
swarms = [SwarmFactory(is_secure=is_host_secure) for _ in range(2)] swarms = [SwarmFactory(security_protocol=security_protocol) for _ in range(2)]
events_0_0 = [] events_0_0 = []
events_1_0 = [] events_1_0 = []

View File

@ -9,8 +9,10 @@ from libp2p.tools.utils import connect_swarm
@pytest.mark.trio @pytest.mark.trio
async def test_swarm_dial_peer(is_host_secure): async def test_swarm_dial_peer(security_protocol):
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: async with SwarmFactory.create_batch_and_listen(
3, security_protocol=security_protocol
) as swarms:
# Test: No addr found. # Test: No addr found.
with pytest.raises(SwarmException): with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id()) await swarms[0].dial_peer(swarms[1].get_peer_id())
@ -38,8 +40,10 @@ async def test_swarm_dial_peer(is_host_secure):
@pytest.mark.trio @pytest.mark.trio
async def test_swarm_close_peer(is_host_secure): async def test_swarm_close_peer(security_protocol):
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: async with SwarmFactory.create_batch_and_listen(
3, security_protocol=security_protocol
) as swarms:
# 0 <> 1 <> 2 # 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2]) await connect_swarm(swarms[1], swarms[2])
@ -90,8 +94,10 @@ async def test_swarm_remove_conn(swarm_pair):
@pytest.mark.trio @pytest.mark.trio
async def test_swarm_multiaddr(is_host_secure): async def test_swarm_multiaddr(security_protocol):
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: async with SwarmFactory.create_batch_and_listen(
3, security_protocol=security_protocol
) as swarms:
def clear(): def clear():
swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id()) swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id())

View File

@ -16,9 +16,11 @@ async def perform_simple_test(
expected_selected_protocol, expected_selected_protocol,
protocols_for_client, protocols_for_client,
protocols_with_handlers, protocols_with_handlers,
is_host_secure, security_protocol,
): ):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts: async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
for protocol in protocols_with_handlers: for protocol in protocols_with_handlers:
hosts[1].set_stream_handler( hosts[1].set_stream_handler(
protocol, create_echo_stream_handler(ACK_PREFIX) protocol, create_echo_stream_handler(ACK_PREFIX)
@ -38,28 +40,28 @@ async def perform_simple_test(
@pytest.mark.trio @pytest.mark.trio
async def test_single_protocol_succeeds(is_host_secure): async def test_single_protocol_succeeds(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test( await perform_simple_test(
expected_selected_protocol, expected_selected_protocol,
[expected_selected_protocol], [expected_selected_protocol],
[expected_selected_protocol], [expected_selected_protocol],
is_host_secure, security_protocol,
) )
@pytest.mark.trio @pytest.mark.trio
async def test_single_protocol_fails(is_host_secure): async def test_single_protocol_fails(security_protocol):
with pytest.raises(StreamFailure): with pytest.raises(StreamFailure):
await perform_simple_test( await perform_simple_test(
"", [PROTOCOL_ECHO], [PROTOCOL_POTATO], is_host_secure "", [PROTOCOL_ECHO], [PROTOCOL_POTATO], security_protocol
) )
# Cleanup not reached on error # Cleanup not reached on error
@pytest.mark.trio @pytest.mark.trio
async def test_multiple_protocol_first_is_valid_succeeds(is_host_secure): async def test_multiple_protocol_first_is_valid_succeeds(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO expected_selected_protocol = PROTOCOL_ECHO
protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO] protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO]
protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO] protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
@ -67,12 +69,12 @@ async def test_multiple_protocol_first_is_valid_succeeds(is_host_secure):
expected_selected_protocol, expected_selected_protocol,
protocols_for_client, protocols_for_client,
protocols_for_listener, protocols_for_listener,
is_host_secure, security_protocol,
) )
@pytest.mark.trio @pytest.mark.trio
async def test_multiple_protocol_second_is_valid_succeeds(is_host_secure): async def test_multiple_protocol_second_is_valid_succeeds(security_protocol):
expected_selected_protocol = PROTOCOL_FOO expected_selected_protocol = PROTOCOL_FOO
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO] protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO]
protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO] protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
@ -80,15 +82,15 @@ async def test_multiple_protocol_second_is_valid_succeeds(is_host_secure):
expected_selected_protocol, expected_selected_protocol,
protocols_for_client, protocols_for_client,
protocols_for_listener, protocols_for_listener,
is_host_secure, security_protocol,
) )
@pytest.mark.trio @pytest.mark.trio
async def test_multiple_protocol_fails(is_host_secure): async def test_multiple_protocol_fails(security_protocol):
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"] protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"]
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
with pytest.raises(StreamFailure): with pytest.raises(StreamFailure):
await perform_simple_test( await perform_simple_test(
"", protocols_for_client, protocols_for_listener, is_host_secure "", protocols_for_client, protocols_for_listener, security_protocol
) )

View File

@ -82,10 +82,11 @@ async def test_lru_cache_two_nodes(monkeypatch):
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio @pytest.mark.trio
@pytest.mark.slow @pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure): async def test_gossipsub_run_with_floodsub_tests(test_case_obj, security_protocol):
await perform_test_from_obj( await perform_test_from_obj(
test_case_obj, test_case_obj,
functools.partial( functools.partial(
PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure PubsubFactory.create_batch_with_floodsub,
security_protocol=security_protocol,
), ),
) )

View File

@ -236,7 +236,7 @@ async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
@pytest.mark.trio @pytest.mark.trio
async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure): async def test_continuously_read_stream(monkeypatch, nursery, security_protocol):
async def wait_for_event_occurring(event): async def wait_for_event_occurring(event):
await trio.hazmat.checkpoint() await trio.hazmat.checkpoint()
with trio.fail_after(0.1): with trio.fail_after(0.1):
@ -271,8 +271,10 @@ async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure):
yield Events(event_push_msg, event_handle_subscription, event_handle_rpc) yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)
async with PubsubFactory.create_batch_with_floodsub( async with PubsubFactory.create_batch_with_floodsub(
1, is_secure=is_host_secure 1, security_protocol=security_protocol
) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair: ) as pubsubs_fsub, net_stream_pair_factory(
security_protocol=security_protocol
) as stream_pair:
await pubsubs_fsub[0].subscribe(TESTING_TOPIC) await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Kick off the task `continuously_read_stream` # Kick off the task `continuously_read_stream`
nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0]) nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])
@ -394,10 +396,12 @@ async def test_handle_talk():
@pytest.mark.trio @pytest.mark.trio
async def test_message_all_peers(monkeypatch, is_host_secure): async def test_message_all_peers(monkeypatch, security_protocol):
async with PubsubFactory.create_batch_with_floodsub( async with PubsubFactory.create_batch_with_floodsub(
1, is_secure=is_host_secure 1, security_protocol=security_protocol
) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair: ) as pubsubs_fsub, net_stream_pair_factory(
security_protocol=security_protocol
) as stream_pair:
peer_id = IDFactory() peer_id = IDFactory()
mock_peers = {peer_id: stream_pair[0]} mock_peers = {peer_id: stream_pair[0]}
with monkeypatch.context() as m: with monkeypatch.context() as m:

View File

@ -1,88 +1,49 @@
import pytest import pytest
import trio
from libp2p import new_host
from libp2p.crypto.rsa import create_new_key_pair from libp2p.crypto.rsa import create_new_key_pair
from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureSession
from libp2p.tools.constants import LISTEN_MADDR from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.tools.utils import connect from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID
from libp2p.security.secure_session import SecureSession
# TODO: Add tests for multiple streams being opened on different from libp2p.tools.factories import host_pair_factory
# protocols through the same connection
def peer_id_for_node(node):
return node.get_id()
initiator_key_pair = create_new_key_pair() initiator_key_pair = create_new_key_pair()
noninitiator_key_pair = create_new_key_pair() noninitiator_key_pair = create_new_key_pair()
async def perform_simple_test( async def perform_simple_test(assertion_func, security_protocol):
assertion_func, transports_for_initiator, transports_for_noninitiator async with host_pair_factory(security_protocol=security_protocol) as hosts:
): conn_0 = hosts[0].get_network().connections[hosts[1].get_id()]
# Create libp2p nodes and connect them, then secure the connection, then check conn_1 = hosts[1].get_network().connections[hosts[0].get_id()]
# the proper security was chosen
# TODO: implement -- note we need to introduce the notion of communicating over a raw connection
# for testing, we do NOT want to communicate over a stream so we can't just create two nodes
# and use their conn because our mplex will internally relay messages to a stream
node1 = new_host(key_pair=initiator_key_pair, sec_opt=transports_for_initiator)
node2 = new_host(
key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator
)
async with node1.run(listen_addrs=[LISTEN_MADDR]), node2.run(
listen_addrs=[LISTEN_MADDR]
):
await connect(node1, node2)
# Wait a very short period to allow conns to be stored (since the functions
# storing the conns are async, they may happen at slightly different times
# on each node)
await trio.sleep(0.1)
# Get conns
node1_conn = node1.get_network().connections[peer_id_for_node(node2)]
node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
# Perform assertion # Perform assertion
assertion_func(node1_conn.muxed_conn.secured_conn) assertion_func(conn_0.muxed_conn.secured_conn)
assertion_func(node2_conn.muxed_conn.secured_conn) assertion_func(conn_1.muxed_conn.secured_conn)
@pytest.mark.trio @pytest.mark.trio
async def test_single_insecure_security_transport_succeeds(): @pytest.mark.parametrize(
transports_for_initiator = {"foo": InsecureTransport(initiator_key_pair)} "security_protocol, transport_type",
transports_for_noninitiator = {"foo": InsecureTransport(noninitiator_key_pair)} (
(PLAINTEXT_PROTOCOL_ID, InsecureSession),
(SECIO_PROTOCOL_ID, SecureSession),
(NOISE_PROTOCOL_ID, SecureSession),
),
)
@pytest.mark.trio
async def test_single_insecure_security_transport_succeeds(
security_protocol, transport_type
):
def assertion_func(conn): def assertion_func(conn):
assert isinstance(conn, InsecureSession) assert isinstance(conn, transport_type)
await perform_simple_test( await perform_simple_test(assertion_func, security_protocol)
assertion_func, transports_for_initiator, transports_for_noninitiator
)
@pytest.mark.trio @pytest.mark.trio
async def test_default_insecure_security(): async def test_default_insecure_security():
transports_for_initiator = None
transports_for_noninitiator = None
conn1 = None
conn2 = None
def assertion_func(conn): def assertion_func(conn):
nonlocal conn1 assert isinstance(conn, InsecureSession)
nonlocal conn2
if not conn1:
conn1 = conn
elif not conn2:
conn2 = conn
else:
assert conn1 == conn2
await perform_simple_test( await perform_simple_test(assertion_func, None)
assertion_func, transports_for_initiator, transports_for_noninitiator
)

View File

@ -4,14 +4,18 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa
@pytest.fixture @pytest.fixture
async def mplex_conn_pair(is_host_secure): async def mplex_conn_pair(security_protocol):
async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair: async with mplex_conn_pair_factory(
security_protocol=security_protocol
) as mplex_conn_pair:
assert mplex_conn_pair[0].is_initiator assert mplex_conn_pair[0].is_initiator
assert not mplex_conn_pair[1].is_initiator assert not mplex_conn_pair[1].is_initiator
yield mplex_conn_pair[0], mplex_conn_pair[1] yield mplex_conn_pair[0], mplex_conn_pair[1]
@pytest.fixture @pytest.fixture
async def mplex_stream_pair(is_host_secure): async def mplex_stream_pair(security_protocol):
async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair: async with mplex_stream_pair_factory(
security_protocol=security_protocol
) as mplex_stream_pair:
yield mplex_stream_pair yield mplex_stream_pair

View File

@ -6,14 +6,15 @@ import pytest
import trio import trio
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID
from libp2p.tools.factories import HostFactory, PubsubFactory from libp2p.tools.factories import HostFactory, PubsubFactory
from libp2p.tools.interop.daemon import make_p2pd from libp2p.tools.interop.daemon import make_p2pd
from libp2p.tools.interop.utils import connect from libp2p.tools.interop.utils import connect
@pytest.fixture @pytest.fixture
def is_host_secure(): def security_protocol():
return False return PLAINTEXT_PROTOCOL_ID
@pytest.fixture @pytest.fixture
@ -38,7 +39,11 @@ def is_pubsub_signing_strict():
@pytest.fixture @pytest.fixture
async def p2pds( async def p2pds(
num_p2pds, is_host_secure, is_gossipsub, is_pubsub_signing, is_pubsub_signing_strict num_p2pds,
security_protocol,
is_gossipsub,
is_pubsub_signing,
is_pubsub_signing_strict,
): ):
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
p2pds = [ p2pds = [
@ -46,7 +51,7 @@ async def p2pds(
make_p2pd( make_p2pd(
get_unused_tcp_port(), get_unused_tcp_port(),
get_unused_tcp_port(), get_unused_tcp_port(),
is_host_secure, security_protocol,
is_gossipsub=is_gossipsub, is_gossipsub=is_gossipsub,
is_pubsub_signing=is_pubsub_signing, is_pubsub_signing=is_pubsub_signing,
is_pubsub_signing_strict=is_pubsub_signing_strict, is_pubsub_signing_strict=is_pubsub_signing_strict,
@ -62,14 +67,16 @@ async def p2pds(
@pytest.fixture @pytest.fixture
async def pubsubs(num_hosts, is_host_secure, is_gossipsub, is_pubsub_signing_strict): async def pubsubs(num_hosts, security_protocol, is_gossipsub, is_pubsub_signing_strict):
if is_gossipsub: if is_gossipsub:
yield PubsubFactory.create_batch_with_gossipsub( yield PubsubFactory.create_batch_with_gossipsub(
num_hosts, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict num_hosts,
security_protocol=security_protocol,
strict_signing=is_pubsub_signing_strict,
) )
else: else:
yield PubsubFactory.create_batch_with_floodsub( yield PubsubFactory.create_batch_with_floodsub(
num_hosts, is_host_secure, strict_signing=is_pubsub_signing_strict num_hosts, security_protocol, strict_signing=is_pubsub_signing_strict
) )
@ -97,8 +104,10 @@ async def is_to_fail_daemon_stream():
@pytest.fixture @pytest.fixture
async def py_to_daemon_stream_pair(p2pds, is_host_secure, is_to_fail_daemon_stream): async def py_to_daemon_stream_pair(p2pds, security_protocol, is_to_fail_daemon_stream):
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: async with HostFactory.create_batch_and_listen(
1, security_protocol=security_protocol
) as hosts:
assert len(p2pds) >= 1 assert len(p2pds) >= 1
host = hosts[0] host = hosts[0]
p2pd = p2pds[0] p2pd = p2pds[0]

View File

@ -6,8 +6,10 @@ from libp2p.tools.interop.utils import connect
@pytest.mark.trio @pytest.mark.trio
async def test_connect(is_host_secure, p2pds): async def test_connect(security_protocol, p2pds):
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: async with HostFactory.create_batch_and_listen(
1, security_protocol=security_protocol
) as hosts:
p2pd = p2pds[0] p2pd = p2pds[0]
host = hosts[0] host = hosts[0]
assert len(await p2pd.control.list_peers()) == 0 assert len(await p2pd.control.list_peers()) == 0

View File

@ -6,6 +6,7 @@ import pytest
import trio import trio
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID
from libp2p.tools.factories import HostFactory from libp2p.tools.factories import HostFactory
from libp2p.tools.interop.envs import GO_BIN_PATH from libp2p.tools.interop.envs import GO_BIN_PATH
from libp2p.tools.interop.process import BaseInteractiveProcess from libp2p.tools.interop.process import BaseInteractiveProcess
@ -20,10 +21,10 @@ class EchoProcess(BaseInteractiveProcess):
_peer_info: PeerInfo _peer_info: PeerInfo
def __init__( def __init__(
self, port: int, is_secure: bool, destination: Multiaddr = None self, port: int, security_protocol: TProtocol, destination: Multiaddr = None
) -> None: ) -> None:
args = [f"-l={port}"] args = [f"-l={port}"]
if not is_secure: if security_protocol == PLAINTEXT_PROTOCOL_ID:
args.append("-insecure") args.append("-insecure")
if destination is not None: if destination is not None:
args.append(f"-d={str(destination)}") args.append(f"-d={str(destination)}")
@ -61,9 +62,11 @@ class EchoProcess(BaseInteractiveProcess):
@pytest.mark.trio @pytest.mark.trio
async def test_insecure_conn_py_to_go(is_host_secure): async def test_insecure_conn_py_to_go(security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: async with HostFactory.create_batch_and_listen(
go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure) 1, security_protocol=security_protocol
) as hosts:
go_proc = EchoProcess(get_unused_tcp_port(), security_protocol)
await go_proc.start() await go_proc.start()
host = hosts[0] host = hosts[0]
@ -78,8 +81,10 @@ async def test_insecure_conn_py_to_go(is_host_secure):
@pytest.mark.trio @pytest.mark.trio
async def test_insecure_conn_go_to_py(is_host_secure): async def test_insecure_conn_go_to_py(security_protocol):
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts: async with HostFactory.create_batch_and_listen(
1, security_protocol=security_protocol
) as hosts:
host = hosts[0] host = hosts[0]
expected_data = "Hello, world!\n" expected_data = "Hello, world!\n"
reply_data = "Replyooo!\n" reply_data = "Replyooo!\n"
@ -94,6 +99,6 @@ async def test_insecure_conn_go_to_py(is_host_secure):
host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo) host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo)
py_maddr = host.get_addrs()[0] py_maddr = host.get_addrs()[0]
go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure, py_maddr) go_proc = EchoProcess(get_unused_tcp_port(), security_protocol, py_maddr)
await go_proc.start() await go_proc.start()
await event_handler_finished.wait() await event_handler_finished.wait()

View File

@ -54,7 +54,7 @@ def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) ->
@pytest.mark.parametrize("num_p2pds", (2,)) @pytest.mark.parametrize("num_p2pds", (2,))
@pytest.mark.trio @pytest.mark.trio
async def test_pubsub( async def test_pubsub(
p2pds, is_gossipsub, is_host_secure, is_pubsub_signing_strict, nursery p2pds, is_gossipsub, security_protocol, is_pubsub_signing_strict, nursery
): ):
pubsub_factory = None pubsub_factory = None
if is_gossipsub: if is_gossipsub:
@ -63,7 +63,7 @@ async def test_pubsub(
pubsub_factory = PubsubFactory.create_batch_with_floodsub pubsub_factory = PubsubFactory.create_batch_with_floodsub
async with pubsub_factory( async with pubsub_factory(
1, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict 1, security_protocol=security_protocol, strict_signing=is_pubsub_signing_strict
) as pubsubs: ) as pubsubs:
# #
# Test: Recognize pubsub peers on connection. # Test: Recognize pubsub peers on connection.