Merge pull request #91 from zixuanzh/next-stream-id
Refactoring stream IDs and notion of initiator
This commit is contained in:
commit
f54ac8aaa1
@ -2,13 +2,24 @@ from .raw_connection_interface import IRawConnection
|
||||
|
||||
|
||||
class RawConnection(IRawConnection):
|
||||
# pylint: disable=too-few-public-methods
|
||||
|
||||
def __init__(self, ip, port, reader, writer):
|
||||
def __init__(self, ip, port, reader, writer, initiator):
|
||||
# pylint: disable=too-many-arguments
|
||||
self.conn_ip = ip
|
||||
self.conn_port = port
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self._next_id = 0 if initiator else 1
|
||||
self.initiator = initiator
|
||||
|
||||
def close(self):
|
||||
self.writer.close()
|
||||
|
||||
def next_stream_id(self):
|
||||
"""
|
||||
Get next available stream id
|
||||
:return: next available stream id for the connection
|
||||
"""
|
||||
next_id = self._next_id
|
||||
self._next_id += 2
|
||||
return next_id
|
||||
|
@ -59,7 +59,7 @@ class Swarm(INetwork):
|
||||
raw_conn = await self.transport.dial(multiaddr)
|
||||
|
||||
# Use upgrader to upgrade raw conn to muxed conn
|
||||
muxed_conn = self.upgrader.upgrade_connection(raw_conn, True)
|
||||
muxed_conn = self.upgrader.upgrade_connection(raw_conn)
|
||||
|
||||
# Store muxed connection in connections
|
||||
self.connections[peer_id] = muxed_conn
|
||||
@ -118,8 +118,8 @@ class Swarm(INetwork):
|
||||
# Upgrade reader/write to a net_stream and pass \
|
||||
# to appropriate stream handler (using multiaddr)
|
||||
raw_conn = RawConnection(multiaddr.value_for_protocol('ip4'),
|
||||
multiaddr.value_for_protocol('tcp'), reader, writer)
|
||||
muxed_conn = self.upgrader.upgrade_connection(raw_conn, False)
|
||||
multiaddr.value_for_protocol('tcp'), reader, writer, False)
|
||||
muxed_conn = self.upgrader.upgrade_connection(raw_conn)
|
||||
|
||||
# TODO: Remove protocol id from muxed_conn accept stream or
|
||||
# move protocol muxing into accept_stream
|
||||
|
@ -10,36 +10,26 @@ class Mplex(IMuxedConn):
|
||||
reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go
|
||||
"""
|
||||
|
||||
def __init__(self, conn, initiator):
|
||||
def __init__(self, conn):
|
||||
"""
|
||||
create a new muxed connection
|
||||
:param conn: an instance of raw connection
|
||||
:param initiator: boolean to prevent multiplex with self
|
||||
"""
|
||||
self.raw_conn = conn
|
||||
self.initiator = initiator
|
||||
self.initiator = conn.initiator
|
||||
|
||||
# Mapping from stream ID -> buffer of messages for that stream
|
||||
self.buffers = {}
|
||||
|
||||
self.stream_queue = asyncio.Queue()
|
||||
self._next_id = 0
|
||||
self.data_buffer = bytearray()
|
||||
|
||||
# The initiator need not read upon construction time.
|
||||
# The initiator of the raw connection need not read upon construction time.
|
||||
# It should read when the user decides that it wants to read from the constructed stream.
|
||||
if not initiator:
|
||||
if not self.initiator:
|
||||
asyncio.ensure_future(self.handle_incoming())
|
||||
|
||||
def _next_stream_id(self):
|
||||
"""
|
||||
Get next available stream id
|
||||
:return: next available stream id for the connection
|
||||
"""
|
||||
next_id = self._next_id
|
||||
self._next_id += 1
|
||||
return next_id
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
close the stream muxer and underlying raw connection
|
||||
@ -88,7 +78,7 @@ class Mplex(IMuxedConn):
|
||||
:param multi_addr: multi_addr that stream connects to
|
||||
:return: a new stream
|
||||
"""
|
||||
stream_id = self._next_stream_id()
|
||||
stream_id = self.raw_conn.next_stream_id()
|
||||
stream = MplexStream(stream_id, multi_addr, self)
|
||||
self.buffers[stream_id] = asyncio.Queue()
|
||||
return stream
|
||||
|
@ -80,6 +80,54 @@ async def test_double_response():
|
||||
# Success, terminate pending tasks.
|
||||
return
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_streams():
|
||||
# Node A should be able to open a stream with node B and then vice versa.
|
||||
# Stream IDs should be generated uniquely so that the stream state is not overwritten
|
||||
node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8004"])
|
||||
node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8005"])
|
||||
|
||||
async def stream_handler_a(stream):
|
||||
while True:
|
||||
read_string = (await stream.read()).decode()
|
||||
|
||||
response = "ack_a:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
async def stream_handler_b(stream):
|
||||
while True:
|
||||
read_string = (await stream.read()).decode()
|
||||
|
||||
response = "ack_b:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
node_a.set_stream_handler("/echo_a/1.0.0", stream_handler_a)
|
||||
node_b.set_stream_handler("/echo_b/1.0.0", stream_handler_b)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
|
||||
|
||||
stream_a = await node_a.new_stream(node_b.get_id(), ["/echo_b/1.0.0"])
|
||||
stream_b = await node_b.new_stream(node_a.get_id(), ["/echo_a/1.0.0"])
|
||||
|
||||
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
a_message = message + "_a"
|
||||
b_message = message + "_b"
|
||||
|
||||
await stream_a.write(a_message.encode())
|
||||
await stream_b.write(b_message.encode())
|
||||
|
||||
response_a = (await stream_a.read()).decode()
|
||||
response_b = (await stream_b.read()).decode()
|
||||
|
||||
assert response_a == ("ack_b:" + a_message) and response_b == ("ack_a:" + b_message)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
return
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_host_connect():
|
||||
node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8001/"])
|
||||
|
@ -1,9 +1,6 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
# from network.connection.raw_connection import RawConnection
|
||||
|
||||
|
||||
async def handle_echo(reader, writer):
|
||||
data = await reader.read(100)
|
||||
writer.write(data)
|
||||
@ -20,7 +17,6 @@ async def test_simple_echo():
|
||||
await asyncio.start_server(handle_echo, server_ip, server_port)
|
||||
|
||||
reader, writer = await asyncio.open_connection(server_ip, server_port)
|
||||
# raw_connection = RawConnection(server_ip, server_port, reader, writer)
|
||||
|
||||
test_message = "hello world"
|
||||
writer.write(test_message.encode())
|
||||
|
@ -69,7 +69,7 @@ class TCP(ITransport):
|
||||
|
||||
reader, writer = await asyncio.open_connection(host, port)
|
||||
|
||||
return RawConnection(host, port, reader, writer)
|
||||
return RawConnection(host, port, reader, writer, True)
|
||||
|
||||
def create_listener(self, handler_function, options=None):
|
||||
"""
|
||||
|
@ -17,11 +17,11 @@ class TransportUpgrader():
|
||||
def upgrade_security(self):
|
||||
pass
|
||||
|
||||
def upgrade_connection(self, conn, initiator):
|
||||
def upgrade_connection(self, conn):
|
||||
"""
|
||||
upgrade raw connection to muxed connection
|
||||
"""
|
||||
|
||||
# For PoC, no security, default to mplex
|
||||
# TODO do exchange to determine multiplexer
|
||||
return Mplex(conn, initiator)
|
||||
return Mplex(conn)
|
||||
|
Loading…
x
Reference in New Issue
Block a user