Merge pull request #346 from ralexstokes/add-tests-for-identify

Add tests for identify
This commit is contained in:
Alex Stokes 2019-11-09 01:51:34 +08:00 committed by GitHub
commit 285bb2ed19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 76 additions and 48 deletions

View File

@ -168,7 +168,11 @@ class BasicHost(IHost):
protocol, handler = await self.multiselect.negotiate(
MultiselectCommunicator(net_stream)
)
except MultiselectError:
except MultiselectError as error:
peer_id = net_stream.muxed_conn.peer_id
logger.debug(
"failed to accept a stream from peer %s, error=%s", peer_id, error
)
await net_stream.reset()
return
net_stream.set_protocol(protocol)

View File

@ -19,24 +19,28 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:
return maddr.to_bytes()
def _mk_identify_protobuf(host: IHost) -> Identify:
public_key = host.get_public_key()
laddrs = host.get_addrs()
protocols = host.get_mux().get_protocols()
return Identify(
protocol_version=PROTOCOL_VERSION,
agent_version=AGENT_VERSION,
public_key=public_key.serialize(),
listen_addrs=map(_multiaddr_to_bytes, laddrs),
# TODO send observed address from ``stream``
observed_addr=b"",
protocols=protocols,
)
def identify_handler_for(host: IHost) -> StreamHandlerFn:
async def handle_identify(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
logger.debug("received a request for %s from %s", ID, peer_id)
public_key = host.get_public_key()
laddrs = host.get_addrs()
protocols = host.get_mux().get_protocols()
protobuf = Identify(
protocol_version=PROTOCOL_VERSION,
agent_version=AGENT_VERSION,
public_key=public_key.serialize(),
listen_addrs=map(_multiaddr_to_bytes, laddrs),
# TODO send observed address from ``stream``
observed_addr=b"",
protocols=protocols,
)
protobuf = _mk_identify_protobuf(host)
response = protobuf.SerializeToString()
await stream.write(response)

View File

@ -111,6 +111,9 @@ class SecureSession(BaseSession):
self.high_watermark = len(msg)
async def read(self, n: int = -1) -> bytes:
if n == 0:
return bytes()
data_from_buffer = self._drain(n)
if len(data_from_buffer) > 0:
return data_from_buffer

View File

@ -1,4 +1,5 @@
import asyncio
import logging
from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple
@ -23,6 +24,8 @@ from .mplex_stream import MplexStream
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
class Mplex(IMuxedConn):
"""
@ -181,7 +184,8 @@ class Mplex(IMuxedConn):
while True:
try:
await self._handle_incoming_message()
except MplexUnavailable:
except MplexUnavailable as e:
logger.debug("mplex unavailable while waiting for incoming: %s", e)
break
# Force context switch
await asyncio.sleep(0)

View File

@ -1,4 +1,5 @@
import asyncio
from contextlib import asynccontextmanager
from typing import Dict, Tuple
import factory
@ -163,6 +164,14 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
return hosts[0], hosts[1]
@asynccontextmanager
async def pair_of_connected_hosts(is_secure=True):
a, b = await host_pair_factory(is_secure)
yield a, b
close_tasks = (a.close(), b.close())
await asyncio.gather(*close_tasks)
async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:

View File

@ -4,25 +4,18 @@ import secrets
import pytest
from libp2p.host.ping import ID, PING_LENGTH
from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.utils import set_up_nodes_by_transport_opt
from tests.factories import pair_of_connected_hosts
@pytest.mark.asyncio
async def test_ping_once():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
addr = host_a.get_addrs()[0]
info = info_from_p2p_addr(addr)
await host_b.connect(info)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping)
some_pong = await stream.read(PING_LENGTH)
assert some_ping == some_pong
await stream.close()
async with pair_of_connected_hosts() as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping)
some_pong = await stream.read(PING_LENGTH)
assert some_ping == some_pong
await stream.close()
SOME_PING_COUNT = 3
@ -30,21 +23,15 @@ SOME_PING_COUNT = 3
@pytest.mark.asyncio
async def test_ping_several():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
addr = host_a.get_addrs()[0]
info = info_from_p2p_addr(addr)
await host_b.connect(info)
stream = await host_b.new_stream(host_a.get_id(), (ID,))
for _ in range(SOME_PING_COUNT):
some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping)
some_pong = await stream.read(PING_LENGTH)
assert some_ping == some_pong
# NOTE: simulate some time to sleep to mirror a real
# world usage where a peer sends pings on some periodic interval
# NOTE: this interval can be `0` for this test.
await asyncio.sleep(0)
await stream.close()
async with pair_of_connected_hosts() as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
for _ in range(SOME_PING_COUNT):
some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping)
some_pong = await stream.read(PING_LENGTH)
assert some_ping == some_pong
# NOTE: simulate some time to sleep to mirror a real
# world usage where a peer sends pings on some periodic interval
# NOTE: this interval can be `0` for this test.
await asyncio.sleep(0)
await stream.close()

View File

@ -0,0 +1,17 @@
import pytest
from libp2p.identity.identify.pb.identify_pb2 import Identify
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
from tests.factories import pair_of_connected_hosts
@pytest.mark.asyncio
async def test_identify_protocol():
async with pair_of_connected_hosts() as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read()
await stream.close()
identify_response = Identify()
identify_response.ParseFromString(response)
assert identify_response == _mk_identify_protobuf(host_a)