Merge pull request #346 from ralexstokes/add-tests-for-identify

Add tests for identify
This commit is contained in:
Alex Stokes 2019-11-09 01:51:34 +08:00 committed by GitHub
commit 285bb2ed19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 76 additions and 48 deletions

View File

@ -168,7 +168,11 @@ class BasicHost(IHost):
protocol, handler = await self.multiselect.negotiate( protocol, handler = await self.multiselect.negotiate(
MultiselectCommunicator(net_stream) MultiselectCommunicator(net_stream)
) )
except MultiselectError: except MultiselectError as error:
peer_id = net_stream.muxed_conn.peer_id
logger.debug(
"failed to accept a stream from peer %s, error=%s", peer_id, error
)
await net_stream.reset() await net_stream.reset()
return return
net_stream.set_protocol(protocol) net_stream.set_protocol(protocol)

View File

@ -19,16 +19,12 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:
return maddr.to_bytes() return maddr.to_bytes()
def identify_handler_for(host: IHost) -> StreamHandlerFn: def _mk_identify_protobuf(host: IHost) -> Identify:
async def handle_identify(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
logger.debug("received a request for %s from %s", ID, peer_id)
public_key = host.get_public_key() public_key = host.get_public_key()
laddrs = host.get_addrs() laddrs = host.get_addrs()
protocols = host.get_mux().get_protocols() protocols = host.get_mux().get_protocols()
protobuf = Identify( return Identify(
protocol_version=PROTOCOL_VERSION, protocol_version=PROTOCOL_VERSION,
agent_version=AGENT_VERSION, agent_version=AGENT_VERSION,
public_key=public_key.serialize(), public_key=public_key.serialize(),
@ -37,6 +33,14 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn:
observed_addr=b"", observed_addr=b"",
protocols=protocols, protocols=protocols,
) )
def identify_handler_for(host: IHost) -> StreamHandlerFn:
async def handle_identify(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
logger.debug("received a request for %s from %s", ID, peer_id)
protobuf = _mk_identify_protobuf(host)
response = protobuf.SerializeToString() response = protobuf.SerializeToString()
await stream.write(response) await stream.write(response)

View File

@ -111,6 +111,9 @@ class SecureSession(BaseSession):
self.high_watermark = len(msg) self.high_watermark = len(msg)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
if n == 0:
return bytes()
data_from_buffer = self._drain(n) data_from_buffer = self._drain(n)
if len(data_from_buffer) > 0: if len(data_from_buffer) > 0:
return data_from_buffer return data_from_buffer

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
from typing import Any # noqa: F401 from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple from typing import Awaitable, Dict, List, Optional, Tuple
@ -23,6 +24,8 @@ from .mplex_stream import MplexStream
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
class Mplex(IMuxedConn): class Mplex(IMuxedConn):
""" """
@ -181,7 +184,8 @@ class Mplex(IMuxedConn):
while True: while True:
try: try:
await self._handle_incoming_message() await self._handle_incoming_message()
except MplexUnavailable: except MplexUnavailable as e:
logger.debug("mplex unavailable while waiting for incoming: %s", e)
break break
# Force context switch # Force context switch
await asyncio.sleep(0) await asyncio.sleep(0)

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
from contextlib import asynccontextmanager
from typing import Dict, Tuple from typing import Dict, Tuple
import factory import factory
@ -163,6 +164,14 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
return hosts[0], hosts[1] return hosts[0], hosts[1]
@asynccontextmanager
async def pair_of_connected_hosts(is_secure=True):
a, b = await host_pair_factory(is_secure)
yield a, b
close_tasks = (a.close(), b.close())
await asyncio.gather(*close_tasks)
async def swarm_conn_pair_factory( async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: ) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:

View File

@ -4,19 +4,12 @@ import secrets
import pytest import pytest
from libp2p.host.ping import ID, PING_LENGTH from libp2p.host.ping import ID, PING_LENGTH
from libp2p.peer.peerinfo import info_from_p2p_addr from tests.factories import pair_of_connected_hosts
from tests.utils import set_up_nodes_by_transport_opt
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_once(): async def test_ping_once():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] async with pair_of_connected_hosts() as (host_a, host_b):
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
addr = host_a.get_addrs()[0]
info = info_from_p2p_addr(addr)
await host_b.connect(info)
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
some_ping = secrets.token_bytes(PING_LENGTH) some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping) await stream.write(some_ping)
@ -30,13 +23,7 @@ SOME_PING_COUNT = 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_several(): async def test_ping_several():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] async with pair_of_connected_hosts() as (host_a, host_b):
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
addr = host_a.get_addrs()[0]
info = info_from_p2p_addr(addr)
await host_b.connect(info)
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
for _ in range(SOME_PING_COUNT): for _ in range(SOME_PING_COUNT):
some_ping = secrets.token_bytes(PING_LENGTH) some_ping = secrets.token_bytes(PING_LENGTH)

View File

@ -0,0 +1,17 @@
import pytest
from libp2p.identity.identify.pb.identify_pb2 import Identify
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
from tests.factories import pair_of_connected_hosts
@pytest.mark.asyncio
async def test_identify_protocol():
async with pair_of_connected_hosts() as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read()
await stream.close()
identify_response = Identify()
identify_response.ParseFromString(response)
assert identify_response == _mk_identify_protobuf(host_a)