diff --git a/network/connection/raw_connection.py b/network/connection/raw_connection.py index dc43f15..1700469 100644 --- a/network/connection/raw_connection.py +++ b/network/connection/raw_connection.py @@ -2,13 +2,24 @@ from .raw_connection_interface import IRawConnection class RawConnection(IRawConnection): - # pylint: disable=too-few-public-methods - def __init__(self, ip, port, reader, writer): + def __init__(self, ip, port, reader, writer, initiator): + # pylint: disable=too-many-arguments self.conn_ip = ip self.conn_port = port self.reader = reader self.writer = writer + self._next_id = 0 if initiator else 1 + self.initiator = initiator def close(self): self.writer.close() + + def next_stream_id(self): + """ + Get next available stream id + :return: next available stream id for the connection + """ + next_id = self._next_id + self._next_id += 2 + return next_id diff --git a/network/swarm.py b/network/swarm.py index a0a39eb..7a73209 100644 --- a/network/swarm.py +++ b/network/swarm.py @@ -59,7 +59,7 @@ class Swarm(INetwork): raw_conn = await self.transport.dial(multiaddr) # Use upgrader to upgrade raw conn to muxed conn - muxed_conn = self.upgrader.upgrade_connection(raw_conn, True) + muxed_conn = self.upgrader.upgrade_connection(raw_conn) # Store muxed connection in connections self.connections[peer_id] = muxed_conn @@ -118,8 +118,8 @@ class Swarm(INetwork): # Upgrade reader/write to a net_stream and pass \ # to appropriate stream handler (using multiaddr) raw_conn = RawConnection(multiaddr.value_for_protocol('ip4'), - multiaddr.value_for_protocol('tcp'), reader, writer) - muxed_conn = self.upgrader.upgrade_connection(raw_conn, False) + multiaddr.value_for_protocol('tcp'), reader, writer, False) + muxed_conn = self.upgrader.upgrade_connection(raw_conn) # TODO: Remove protocol id from muxed_conn accept stream or # move protocol muxing into accept_stream diff --git a/stream_muxer/mplex/mplex.py b/stream_muxer/mplex/mplex.py index a65ecb2..bd324f5 100644 --- a/stream_muxer/mplex/mplex.py +++ b/stream_muxer/mplex/mplex.py @@ -10,36 +10,26 @@ class Mplex(IMuxedConn): reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go """ - def __init__(self, conn, initiator): + def __init__(self, conn): """ create a new muxed connection :param conn: an instance of raw connection :param initiator: boolean to prevent multiplex with self """ self.raw_conn = conn - self.initiator = initiator + self.initiator = conn.initiator # Mapping from stream ID -> buffer of messages for that stream self.buffers = {} self.stream_queue = asyncio.Queue() - self._next_id = 0 if self.initiator else 1 self.data_buffer = bytearray() - # The initiator need not read upon construction time. + # The initiator of the raw connection need not read upon construction time. # It should read when the user decides that it wants to read from the constructed stream. - if not initiator: + if not self.initiator: asyncio.ensure_future(self.handle_incoming()) - def _next_stream_id(self): - """ - Get next available stream id - :return: next available stream id for the connection - """ - next_id = self._next_id - self._next_id += 2 - return next_id - def close(self): """ close the stream muxer and underlying raw connection @@ -88,7 +78,7 @@ class Mplex(IMuxedConn): :param multi_addr: multi_addr that stream connects to :return: a new stream """ - stream_id = self._next_stream_id() + stream_id = self.raw_conn.next_stream_id() stream = MplexStream(stream_id, multi_addr, self) self.buffers[stream_id] = asyncio.Queue() return stream diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index f449e2e..863d4c2 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -80,6 +80,54 @@ async def test_double_response(): # Success, terminate pending tasks. return +@pytest.mark.asyncio +async def test_multiple_streams(): + # Node A should be able to open a stream with node B and then vice versa. + # Stream IDs should be generated uniquely so that the stream state is not overwritten + node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8004"]) + node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8005"]) + + async def stream_handler_a(stream): + while True: + read_string = (await stream.read()).decode() + + response = "ack_a:" + read_string + await stream.write(response.encode()) + + async def stream_handler_b(stream): + while True: + read_string = (await stream.read()).decode() + + response = "ack_b:" + read_string + await stream.write(response.encode()) + + node_a.set_stream_handler("/echo_a/1.0.0", stream_handler_a) + node_b.set_stream_handler("/echo_b/1.0.0", stream_handler_b) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) + node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10) + + stream_a = await node_a.new_stream(node_b.get_id(), ["/echo_b/1.0.0"]) + stream_b = await node_b.new_stream(node_a.get_id(), ["/echo_a/1.0.0"]) + + # A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + a_message = message + "_a" + b_message = message + "_b" + + await stream_a.write(a_message.encode()) + await stream_b.write(b_message.encode()) + + response_a = (await stream_a.read()).decode() + response_b = (await stream_b.read()).decode() + + assert response_a == ("ack_b:" + a_message) and response_b == ("ack_a:" + b_message) + + # Success, terminate pending tasks. + return + @pytest.mark.asyncio async def test_host_connect(): node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8001/"]) diff --git a/tests/network/test_connection.py b/tests/network/test_connection.py index c009648..15ebb96 100644 --- a/tests/network/test_connection.py +++ b/tests/network/test_connection.py @@ -1,9 +1,6 @@ import asyncio import pytest -# from network.connection.raw_connection import RawConnection - - async def handle_echo(reader, writer): data = await reader.read(100) writer.write(data) @@ -20,7 +17,6 @@ async def test_simple_echo(): await asyncio.start_server(handle_echo, server_ip, server_port) reader, writer = await asyncio.open_connection(server_ip, server_port) - # raw_connection = RawConnection(server_ip, server_port, reader, writer) test_message = "hello world" writer.write(test_message.encode()) diff --git a/transport/tcp/tcp.py b/transport/tcp/tcp.py index 4e6f49e..24ad69e 100644 --- a/transport/tcp/tcp.py +++ b/transport/tcp/tcp.py @@ -69,7 +69,7 @@ class TCP(ITransport): reader, writer = await asyncio.open_connection(host, port) - return RawConnection(host, port, reader, writer) + return RawConnection(host, port, reader, writer, True) def create_listener(self, handler_function, options=None): """ diff --git a/transport/upgrader.py b/transport/upgrader.py index 5a22c29..995e670 100644 --- a/transport/upgrader.py +++ b/transport/upgrader.py @@ -17,11 +17,11 @@ class TransportUpgrader(): def upgrade_security(self): pass - def upgrade_connection(self, conn, initiator): + def upgrade_connection(self, conn): """ upgrade raw connection to muxed connection """ # For PoC, no security, default to mplex # TODO do exchange to determine multiplexer - return Mplex(conn, initiator) + return Mplex(conn)