0548d28568
Since we don't catch `Exception` raised in stream handlers, the stream handlers defined in the tests must catch errors (such as `StreamError`) themselves.
286 lines
11 KiB
Python
286 lines
11 KiB
Python
import multiaddr
|
|
import pytest
|
|
|
|
from libp2p.network.stream.exceptions import StreamError
|
|
from libp2p.tools.constants import MAX_READ_LEN
|
|
from libp2p.tools.factories import HostFactory
|
|
from libp2p.tools.utils import connect, create_echo_stream_handler
|
|
from libp2p.typing import TProtocol
|
|
|
|
# Protocol IDs under which the echo handlers in the tests below are registered.
PROTOCOL_ID_0, PROTOCOL_ID_1, PROTOCOL_ID_2, PROTOCOL_ID_3 = map(
    TProtocol, ("/echo/0", "/echo/1", "/echo/2", "/echo/3")
)
|
|
|
|
# Prefixes prepended to each echoed message. A distinct prefix per handler lets
# a test confirm which handler produced a given response.
ACK_STR_0, ACK_STR_1, ACK_STR_2, ACK_STR_3 = "ack_0:", "ack_1:", "ack_2:", "ack_3:"
|
|
|
|
|
|
@pytest.mark.trio
async def test_simple_messages(is_host_secure):
    """A single stream to an echo handler returns every message with its ack prefix."""
    async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
        dialer, listener = hosts[0], hosts[1]
        listener.set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )

        # Associate the peer with local ip address (see default parameters of Libp2p())
        dialer.get_peerstore().add_addrs(listener.get_id(), listener.get_addrs(), 10)

        stream = await dialer.new_stream(listener.get_id(), [PROTOCOL_ID_0])

        for i in range(10):
            message = f"hello{i}"
            await stream.write(message.encode())
            reply = (await stream.read(MAX_READ_LEN)).decode()
            assert reply == ACK_STR_0 + message
|
|
|
|
|
|
@pytest.mark.trio
async def test_double_response(is_host_secure):
    """A handler may answer each request with two separate writes.

    The dialer checks that the two acknowledgements arrive as two distinct
    reads, in the order they were written.
    """
    async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:

        async def double_response_stream_handler(stream):
            # Exceptions escaping a stream handler are not caught for us, so
            # stop serving as soon as the stream can no longer be read from or
            # written to.
            while True:
                try:
                    request = (await stream.read(MAX_READ_LEN)).decode()
                    for ack in (ACK_STR_0, ACK_STR_1):
                        await stream.write((ack + request).encode())
                except StreamError:
                    return

        hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler)

        # Associate the peer with local ip address (see default parameters of Libp2p())
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])

        for i in range(10):
            message = f"hello{i}"
            await stream.write(message.encode())
            # Expect the ACK_STR_0 reply first, then the ACK_STR_1 reply.
            for expected_prefix in (ACK_STR_0, ACK_STR_1):
                reply = (await stream.read(MAX_READ_LEN)).decode()
                assert reply == expected_prefix + message
|
|
|
|
|
|
@pytest.mark.trio
async def test_multiple_streams(is_host_secure):
    """Each host can open a stream to the other, in both directions at once.

    hosts[0] should be able to open a stream with hosts[1] and then vice versa.
    Stream IDs should be generated uniquely so that the stream state is not
    overwritten.
    """
    async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
        hosts[0].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
        )

        # Associate the peer with local ip address (see default parameters of Libp2p())
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)

        stream_a = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
        stream_b = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])

        # hosts[0] writes to hosts[1]'s PROTOCOL_ID_1 handler via stream_a, and
        # hosts[1] writes to hosts[0]'s PROTOCOL_ID_0 handler via stream_b.
        messages = [f"hello{x}" for x in range(10)]
        for message in messages:
            a_message = message + "_a"
            b_message = message + "_b"

            await stream_a.write(a_message.encode())
            await stream_b.write(b_message.encode())

            response_a = (await stream_a.read(MAX_READ_LEN)).decode()
            response_b = (await stream_b.read(MAX_READ_LEN)).decode()

            # Separate asserts so a failure report shows which direction broke.
            assert response_a == ACK_STR_1 + a_message
            assert response_b == ACK_STR_0 + b_message
|
|
|
|
|
|
@pytest.mark.trio
async def test_multiple_streams_same_initiator_different_protocols(is_host_secure):
    """One host can hold concurrent streams to a peer over different protocols.

    Each stream's replies must carry the ack prefix of the handler registered
    for that stream's protocol.
    """
    async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:

        hosts[1].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
        )

        # Associate the peer with local ip address (see default parameters of Libp2p())
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)

        # Open streams from hosts[0] to hosts[1] over PROTOCOL_ID_0, PROTOCOL_ID_1
        # and PROTOCOL_ID_2.
        stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
        stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
        stream_a3 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_2])

        messages = [f"hello{x}" for x in range(10)]
        for message in messages:
            a1_message = message + "_a1"
            a2_message = message + "_a2"
            a3_message = message + "_a3"

            await stream_a1.write(a1_message.encode())
            await stream_a2.write(a2_message.encode())
            await stream_a3.write(a3_message.encode())

            response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
            response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
            response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()

            # Separate asserts so a failure report names the misrouted stream.
            assert response_a1 == ACK_STR_0 + a1_message
            assert response_a2 == ACK_STR_1 + a2_message
            assert response_a3 == ACK_STR_2 + a3_message
|
|
|
|
|
|
@pytest.mark.trio
async def test_multiple_streams_two_initiators(is_host_secure):
    """Both hosts act as initiators, each holding two streams to the other.

    Every stream's replies must carry the ack prefix of the handler registered
    for that stream's protocol on the remote host.
    """
    async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
        hosts[0].set_stream_handler(
            PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
        )
        hosts[0].set_stream_handler(
            PROTOCOL_ID_3, create_echo_stream_handler(ACK_STR_3)
        )

        hosts[1].set_stream_handler(
            PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
        )
        hosts[1].set_stream_handler(
            PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
        )

        # Associate the peer with local ip address (see default parameters of Libp2p())
        hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
        hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)

        stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
        stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])

        stream_b1 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_2])
        stream_b2 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_3])

        # hosts[0] writes to hosts[1]'s PROTOCOL_ID_0/PROTOCOL_ID_1 handlers via
        # stream_a1/stream_a2, and hosts[1] writes to hosts[0]'s
        # PROTOCOL_ID_2/PROTOCOL_ID_3 handlers via stream_b1/stream_b2.
        messages = [f"hello{x}" for x in range(10)]
        for message in messages:
            a1_message = message + "_a1"
            a2_message = message + "_a2"

            b1_message = message + "_b1"
            b2_message = message + "_b2"

            await stream_a1.write(a1_message.encode())
            await stream_a2.write(a2_message.encode())

            await stream_b1.write(b1_message.encode())
            await stream_b2.write(b2_message.encode())

            response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
            response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()

            response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
            response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()

            # Separate asserts so a failure report names the failing stream.
            assert response_a1 == ACK_STR_0 + a1_message
            assert response_a2 == ACK_STR_1 + a2_message
            assert response_b1 == ACK_STR_2 + b1_message
            assert response_b2 == ACK_STR_3 + b2_message
|
|
|
|
|
|
@pytest.mark.trio
async def test_triangle_nodes_connection(is_host_secure):
    """Three hosts each open an echo stream to both of the other two.

    All six directed streams must echo correctly, interleaved round-robin.
    """
    n_hosts = 3
    async with HostFactory.create_batch_and_listen(is_host_secure, n_hosts) as hosts:
        for host in hosts:
            host.set_stream_handler(
                PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
            )

        # Every ordered (dialer, listener) pair of distinct hosts, produced in
        # the order 0->1, 0->2, 1->0, 1->2, 2->0, 2->1.
        pairs = [
            (src, dst)
            for src in range(n_hosts)
            for dst in range(n_hosts)
            if src != dst
        ]

        # Associate the peer with local ip address (see default parameters of Libp2p())
        # Associate all permutations
        for src, dst in pairs:
            hosts[src].get_peerstore().add_addrs(
                hosts[dst].get_id(), hosts[dst].get_addrs(), 10
            )

        streams = []
        for src, dst in pairs:
            streams.append(
                await hosts[src].new_stream(hosts[dst].get_id(), [PROTOCOL_ID_0])
            )

        for i in range(5):
            message = f"hello{i}"
            for stream in streams:
                await stream.write(message.encode())
                reply = (await stream.read(MAX_READ_LEN)).decode()
                assert reply == ACK_STR_0 + message
|
|
|
|
|
|
@pytest.mark.trio
async def test_host_connect(is_host_secure):
    """`connect` records the remote peer in the peerstore exactly once.

    A repeated `connect` to the same peer must not add another peerstore
    entry. Every stored address, with the peer's `/p2p/<id>` component
    appended, must be one that the remote host advertises.
    """
    async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
        peerstore = hosts[0].get_peerstore()
        # Before connecting there is a single entry (presumably hosts[0] itself).
        assert len(peerstore.peer_ids()) == 1

        await connect(hosts[0], hosts[1])
        assert len(peerstore.peer_ids()) == 2

        # Connecting again must not create a second peerstore entry.
        await connect(hosts[0], hosts[1])
        assert len(peerstore.peer_ids()) == 2

        remote_id = hosts[1].get_id()
        assert remote_id in peerstore.peer_ids()

        ma_node_b = multiaddr.Multiaddr(f"/p2p/{remote_id.pretty()}")
        stored_addrs = peerstore.addrs(remote_id)
        # Without this guard the loop below would pass vacuously on no addresses.
        assert stored_addrs
        for addr in stored_addrs:
            assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs()
|