Merge dbe0df2200e9646878f03e3bb24ecfafd6f8aeb0 into 1f881e04648f296e4eb89450ecd8333438c3d2d3

This commit is contained in:
aratz-lasa 2020-02-28 09:05:26 +01:00 committed by GitHub
commit de3a44eea9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 35 additions and 27 deletions

View File

@ -40,7 +40,6 @@ def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
class Swarm(Service, INetworkService):
self_id: ID
peerstore: IPeerStore
upgrader: TransportUpgrader
@ -276,7 +275,9 @@ class Swarm(Service, INetworkService):
# I/O agnostic, we should change the API.
if self.listener_nursery is None:
raise SwarmException("swarm instance hasn't been run")
await listener.listen(maddr, self.listener_nursery)
await self.listener_nursery.start(
listener.listen, maddr # type: ignore
)
# Call notifiers since event occurred
await self.notify_listen(maddr)

View File

@ -391,7 +391,7 @@ class GossipSub(IPubsubRouter, Service):
await trio.sleep(self.heartbeat_interval)
def mesh_heartbeat(
self
self,
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)

View File

@ -84,7 +84,7 @@ def noise_transport_factory() -> NoiseTransport:
@asynccontextmanager
async def raw_conn_factory(
nursery: trio.Nursery
nursery: trio.Nursery,
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
conn_0 = None
conn_1 = None
@ -98,7 +98,7 @@ async def raw_conn_factory(
tcp_transport = TCP()
listener = tcp_transport.create_listener(tcp_stream_handler)
await listener.listen(LISTEN_MADDR, nursery)
await nursery.start(listener.listen, LISTEN_MADDR)
listening_maddr = listener.get_addrs()[0]
conn_0 = await tcp_transport.dial(listening_maddr)
await event.wait()
@ -401,7 +401,7 @@ async def swarm_pair_factory(
@asynccontextmanager
async def host_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1])
@ -420,7 +420,7 @@ async def swarm_conn_pair_factory(
@asynccontextmanager
async def mplex_conn_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
@ -432,7 +432,7 @@ async def mplex_conn_pair_factory(
@asynccontextmanager
async def mplex_stream_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
@ -448,7 +448,7 @@ async def mplex_stream_pair_factory(
@asynccontextmanager
async def net_stream_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1")

View File

@ -30,7 +30,7 @@ async def connect(node1: IHost, node2: IHost) -> None:
def create_echo_stream_handler(
ack_prefix: str
ack_prefix: str,
) -> Callable[[INetStream], Awaitable[None]]:
async def echo_stream_handler(stream: INetStream) -> None:
while True:

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Tuple
from typing import Any, Tuple
from multiaddr import Multiaddr
import trio
@ -7,9 +7,12 @@ import trio
class IListener(ABC):
@abstractmethod
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
async def listen(
self, maddr: Multiaddr, task_status: Any = trio.TASK_STATUS_IGNORED
) -> bool:
"""
put listener in listening mode and wait for incoming connections.
put listener in listening mode and wait for incoming connections. It
blocks until it stops to listen.
:param maddr: multiaddr of peer
:return: return True if successful

View File

@ -1,5 +1,5 @@
import logging
from typing import Awaitable, Callable, List, Sequence, Tuple
from typing import Any, Awaitable, Callable, List, Sequence, Tuple
from multiaddr import Multiaddr
import trio
@ -23,10 +23,12 @@ class TCPListener(IListener):
self.listeners = []
self.handler = handler_function
# TODO: Get rid of `nursery`?
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
async def listen(
self, maddr: Multiaddr, task_status: TaskStatus[Any] = trio.TASK_STATUS_IGNORED
) -> None:
"""
put listener in listening mode and wait for incoming connections.
put listener in listening mode and wait for incoming connections. It
blocks until it stops to listen.
:param maddr: maddr of peer
:return: return True if successful
@ -46,13 +48,15 @@ class TCPListener(IListener):
tcp_stream = TrioTCPStream(stream)
await self.handler(tcp_stream)
listeners = await nursery.start(
serve_tcp,
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
)
self.listeners.extend(listeners)
async with trio.open_nursery() as nursery:
listeners = await nursery.start(
serve_tcp,
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
)
task_status.started()
self.listeners.extend(listeners)
def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""

View File

@ -17,9 +17,9 @@ async def test_tcp_listener(nursery):
listener = transport.create_listener(handler)
assert len(listener.get_addrs()) == 0
await listener.listen(LISTEN_MADDR, nursery)
await nursery.start(listener.listen, LISTEN_MADDR)
assert len(listener.get_addrs()) == 1
await listener.listen(LISTEN_MADDR, nursery)
await nursery.start(listener.listen, LISTEN_MADDR)
assert len(listener.get_addrs()) == 2
@ -41,7 +41,7 @@ async def test_tcp_dial(nursery):
await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1"))
listener = transport.create_listener(handler)
await listener.listen(LISTEN_MADDR, nursery)
await nursery.start(listener.listen, LISTEN_MADDR)
addrs = listener.get_addrs()
assert len(addrs) == 1
listen_addr = addrs[0]