Merge pull request #337 from ralexstokes/ground-work-for-identify-to-host

Ground work for identify to host
This commit is contained in:
Alex Stokes 2019-11-07 23:50:36 +08:00 committed by GitHub
commit cce33b2f50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 88 additions and 50 deletions

View File

@ -77,7 +77,6 @@ def initialize_default_swarm(
muxer_opt: TMuxerOptions = None, muxer_opt: TMuxerOptions = None,
sec_opt: TSecurityOptions = None, sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None, peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None,
) -> Swarm: ) -> Swarm:
""" """
initialize swarm when no swarm is passed in. initialize swarm when no swarm is passed in.
@ -87,7 +86,6 @@ def initialize_default_swarm(
:param muxer_opt: optional choice of stream muxer :param muxer_opt: optional choice of stream muxer
:param sec_opt: optional choice of security upgrade :param sec_opt: optional choice of security upgrade
:param peerstore_opt: optional peerstore :param peerstore_opt: optional peerstore
:param disc_opt: optional discovery
:return: return a default swarm instance :return: return a default swarm instance
""" """
@ -147,16 +145,15 @@ async def new_node(
muxer_opt=muxer_opt, muxer_opt=muxer_opt,
sec_opt=sec_opt, sec_opt=sec_opt,
peerstore_opt=peerstore_opt, peerstore_opt=peerstore_opt,
disc_opt=disc_opt,
) )
# TODO enable support for other host type # TODO enable support for other host type
# TODO routing unimplemented # TODO routing unimplemented
host: IHost # If not explicitly typed, MyPy raises error host: IHost # If not explicitly typed, MyPy raises error
if disc_opt: if disc_opt:
host = RoutedHost(swarm_opt, disc_opt) host = RoutedHost(key_pair.public_key, swarm_opt, disc_opt)
else: else:
host = BasicHost(swarm_opt) host = BasicHost(key_pair.public_key, swarm_opt)
# Kick off cleanup job # Kick off cleanup job
asyncio.ensure_future(cleanup_done_tasks()) asyncio.ensure_future(cleanup_done_tasks())

View File

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Sequence
import multiaddr import multiaddr
from libp2p.crypto.keys import PublicKey
from libp2p.host.defaults import get_default_protocols from libp2p.host.defaults import get_default_protocols
from libp2p.host.exceptions import StreamFailure from libp2p.host.exceptions import StreamFailure
from libp2p.network.network_interface import INetwork from libp2p.network.network_interface import INetwork
@ -38,6 +39,7 @@ class BasicHost(IHost):
right after a stream is initialized. right after a stream is initialized.
""" """
_public_key: PublicKey
_network: INetwork _network: INetwork
peerstore: IPeerStore peerstore: IPeerStore
@ -46,14 +48,16 @@ class BasicHost(IHost):
def __init__( def __init__(
self, self,
public_key: PublicKey,
network: INetwork, network: INetwork,
default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None, default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None,
) -> None: ) -> None:
self._public_key = public_key
self._network = network self._network = network
self._network.set_stream_handler(self._swarm_stream_handler) self._network.set_stream_handler(self._swarm_stream_handler)
self.peerstore = self._network.peerstore self.peerstore = self._network.peerstore
# Protocol muxing # Protocol muxing
default_protocols = default_protocols or get_default_protocols() default_protocols = default_protocols or get_default_protocols(self)
self.multiselect = Multiselect(default_protocols) self.multiselect = Multiselect(default_protocols)
self.multiselect_client = MultiselectClient() self.multiselect_client = MultiselectClient()
@ -63,6 +67,9 @@ class BasicHost(IHost):
""" """
return self._network.get_peer_id() return self._network.get_peer_id()
def get_public_key(self) -> PublicKey:
return self._public_key
def get_network(self) -> INetwork: def get_network(self) -> INetwork:
""" """
:return: network instance of host :return: network instance of host

View File

@ -1,11 +1,11 @@
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from libp2p.host.host_interface import IHost
if TYPE_CHECKING: if TYPE_CHECKING:
from libp2p.typing import TProtocol, StreamHandlerFn from libp2p.typing import TProtocol, StreamHandlerFn
DEFAULT_HOST_PROTOCOLS: "OrderedDict[TProtocol, StreamHandlerFn]" = OrderedDict()
def get_default_protocols(host: IHost) -> "OrderedDict[TProtocol, StreamHandlerFn]":
def get_default_protocols() -> "OrderedDict[TProtocol, StreamHandlerFn]": return OrderedDict()
return DEFAULT_HOST_PROTOCOLS.copy()

View File

@ -3,6 +3,7 @@ from typing import Any, List, Sequence
import multiaddr import multiaddr
from libp2p.crypto.keys import PublicKey
from libp2p.network.network_interface import INetwork from libp2p.network.network_interface import INetwork
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -17,6 +18,12 @@ class IHost(ABC):
:return: peer_id of host :return: peer_id of host
""" """
@abstractmethod
def get_public_key(self) -> PublicKey:
"""
:return: the public key belonging to the peer
"""
@abstractmethod @abstractmethod
def get_network(self) -> INetwork: def get_network(self) -> INetwork:
""" """

View File

@ -1,3 +1,4 @@
from libp2p.crypto.keys import PublicKey
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.exceptions import ConnectionFailure from libp2p.host.exceptions import ConnectionFailure
from libp2p.network.network_interface import INetwork from libp2p.network.network_interface import INetwork
@ -10,8 +11,8 @@ from libp2p.routing.interfaces import IPeerRouting
class RoutedHost(BasicHost): class RoutedHost(BasicHost):
_router: IPeerRouting _router: IPeerRouting
def __init__(self, network: INetwork, router: IPeerRouting): def __init__(self, public_key: PublicKey, network: INetwork, router: IPeerRouting):
super().__init__(network) super().__init__(public_key, network)
self._router = router self._router = router
async def connect(self, peer_info: PeerInfo) -> None: async def connect(self, peer_info: PeerInfo) -> None:

View File

@ -1,11 +1,10 @@
import logging import logging
from typing import Sequence
from multiaddr import Multiaddr from multiaddr import Multiaddr
from libp2p.crypto.keys import PublicKey from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn
from .pb.identify_pb2 import Identify from .pb.identify_pb2 import Identify
@ -20,13 +19,15 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:
return maddr.to_bytes() return maddr.to_bytes()
def identify_handler_for( def identify_handler_for(host: IHost) -> StreamHandlerFn:
public_key: PublicKey, laddrs: Sequence[Multiaddr], protocols: Sequence[TProtocol]
) -> StreamHandlerFn:
async def handle_identify(stream: INetStream) -> None: async def handle_identify(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id peer_id = stream.muxed_conn.peer_id
logger.debug("received a request for %s from %s", ID, peer_id) logger.debug("received a request for %s from %s", ID, peer_id)
public_key = host.get_public_key()
laddrs = host.get_addrs()
protocols = host.get_mux().get_protocols()
protobuf = Identify( protobuf = Identify(
protocol_version=PROTOCOL_VERSION, protocol_version=PROTOCOL_VERSION,
agent_version=AGENT_VERSION, agent_version=AGENT_VERSION,

View File

@ -22,6 +22,9 @@ class IMultiselectMuxer(ABC):
:param handler: handler function :param handler: handler function
""" """
def get_protocols(self) -> Tuple[TProtocol, ...]:
return tuple(self.handlers.keys())
@abstractmethod @abstractmethod
async def negotiate( async def negotiate(
self, communicator: IMultiselectCommunicator self, communicator: IMultiselectCommunicator

View File

@ -3,12 +3,13 @@ from typing import Dict, Tuple
import factory import factory
from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
@ -17,7 +18,9 @@ from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTr
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerOptions from libp2p.transport.typing import TMuxerOptions
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import ( from tests.pubsub.configs import (
@ -37,21 +40,37 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)} return {secio.ID: secio.Transport(key_pair)}
def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm: class SwarmFactory(factory.Factory):
key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(is_secure, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt)
class ListeningSwarmFactory(factory.Factory):
class Meta: class Meta:
model = Swarm model = Swarm
class Params:
is_secure = False
key_pair = factory.LazyFunction(generate_new_rsa_identity)
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
peer_id = factory.LazyAttribute(lambda o: generate_peer_id_from(o.key_pair))
peerstore = factory.LazyFunction(PeerStore)
upgrader = factory.LazyAttribute(
lambda o: TransportUpgrader(
security_transport_factory(o.is_secure, o.key_pair), o.muxer_opt
)
)
transport = factory.LazyFunction(TCP)
@classmethod @classmethod
async def create_and_listen( async def create_and_listen(
cls, is_secure: bool, muxer_opt: TMuxerOptions = None cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
) -> Swarm: ) -> Swarm:
swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt) # `factory.Factory.__init__` does *not* prepare a *default value* if we pass
# an argument explicitly with `None`. If an argument is `None`, we don't pass it to
# `factory.Factory.__init__`, in order to let the function initialize it.
optional_kwargs = {}
if key_pair is not None:
optional_kwargs["key_pair"] = key_pair
if muxer_opt is not None:
optional_kwargs["muxer_opt"] = muxer_opt
swarm = cls(is_secure=is_secure, **optional_kwargs)
await swarm.listen(LISTEN_MADDR) await swarm.listen(LISTEN_MADDR)
return swarm return swarm
@ -61,7 +80,7 @@ class ListeningSwarmFactory(factory.Factory):
) -> Tuple[Swarm, ...]: ) -> Tuple[Swarm, ...]:
return await asyncio.gather( return await asyncio.gather(
*[ *[
cls.create_and_listen(is_secure, muxer_opt=muxer_opt) cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
for _ in range(number) for _ in range(number)
] ]
) )
@ -73,20 +92,28 @@ class HostFactory(factory.Factory):
class Params: class Params:
is_secure = False is_secure = False
key_pair = factory.LazyFunction(generate_new_rsa_identity)
network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure)) public_key = factory.LazyAttribute(lambda o: o.key_pair.public_key)
network = factory.LazyAttribute(
@classmethod lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair)
async def create_and_listen(cls, is_secure: bool) -> BasicHost: )
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 1)
return BasicHost(swarms[0])
@classmethod @classmethod
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]: ) -> Tuple[BasicHost, ...]:
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, number) key_pairs = [generate_new_rsa_identity() for _ in range(number)]
return tuple(BasicHost(swarm) for swarm in range(swarms)) swarms = await asyncio.gather(
*[
SwarmFactory.create_and_listen(is_secure, key_pair)
for key_pair in key_pairs
]
)
return tuple(
BasicHost(key_pair.public_key, swarm)
for key_pair, swarm in zip(key_pairs, swarms)
)
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):
@ -123,7 +150,7 @@ class PubsubFactory(factory.Factory):
async def swarm_pair_factory( async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]: ) -> Tuple[Swarm, Swarm]:
swarms = await ListeningSwarmFactory.create_batch_and_listen( swarms = await SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt is_secure, 2, muxer_opt=muxer_opt
) )
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
@ -131,12 +158,7 @@ async def swarm_pair_factory(
async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
hosts = await asyncio.gather( hosts = await HostFactory.create_batch_and_listen(is_secure, 2)
*[
HostFactory.create_and_listen(is_secure),
HostFactory.create_and_listen(is_secure),
]
)
await connect(hosts[0], hosts[1]) await connect(hosts[0], hosts[1])
return hosts[0], hosts[1] return hosts[0], hosts[1]

View File

@ -7,8 +7,8 @@ from libp2p.host.defaults import get_default_protocols
def test_default_protocols(): def test_default_protocols():
key_pair = create_new_key_pair() key_pair = create_new_key_pair()
swarm = initialize_default_swarm(key_pair) swarm = initialize_default_swarm(key_pair)
host = BasicHost(swarm) host = BasicHost(key_pair.public_key, swarm)
mux = host.get_mux() mux = host.get_mux()
handlers = mux.handlers handlers = mux.handlers
assert handlers == get_default_protocols() assert handlers == get_default_protocols(host)

View File

@ -56,7 +56,7 @@ class MyNotifee(INotifee):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_notify(is_host_secure): async def test_notify(is_host_secure):
swarms = [SwarmFactory(is_host_secure) for _ in range(2)] swarms = [SwarmFactory(is_secure=is_host_secure) for _ in range(2)]
events_0_0 = [] events_0_0 = []
events_1_0 = [] events_1_0 = []

View File

@ -3,13 +3,13 @@ import asyncio
import pytest import pytest
from libp2p.network.exceptions import SwarmException from libp2p.network.exceptions import SwarmException
from tests.factories import ListeningSwarmFactory from tests.factories import SwarmFactory
from tests.utils import connect_swarm from tests.utils import connect_swarm
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_swarm_dial_peer(is_host_secure): async def test_swarm_dial_peer(is_host_secure):
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
# Test: No addr found. # Test: No addr found.
with pytest.raises(SwarmException): with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id()) await swarms[0].dial_peer(swarms[1].get_peer_id())
@ -41,7 +41,7 @@ async def test_swarm_dial_peer(is_host_secure):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_swarm_close_peer(is_host_secure): async def test_swarm_close_peer(is_host_secure):
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
# 0 <> 1 <> 2 # 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2]) await connect_swarm(swarms[1], swarms[2])