Merge pull request #91 from zixuanzh/next-stream-id

Refactoring stream IDs and notion of initiator
This commit is contained in:
Robert Zajac 2018-12-01 12:43:53 -05:00 committed by GitHub
commit f54ac8aaa1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 72 additions and 27 deletions

View File

@ -2,13 +2,24 @@ from .raw_connection_interface import IRawConnection
class RawConnection(IRawConnection):
# pylint: disable=too-few-public-methods
def __init__(self, ip, port, reader, writer):
def __init__(self, ip, port, reader, writer, initiator):
# pylint: disable=too-many-arguments
self.conn_ip = ip
self.conn_port = port
self.reader = reader
self.writer = writer
self._next_id = 0 if initiator else 1
self.initiator = initiator
def close(self):
self.writer.close()
def next_stream_id(self):
"""
Get next available stream id
:return: next available stream id for the connection
"""
next_id = self._next_id
self._next_id += 2
return next_id

View File

@ -59,7 +59,7 @@ class Swarm(INetwork):
raw_conn = await self.transport.dial(multiaddr)
# Use upgrader to upgrade raw conn to muxed conn
muxed_conn = self.upgrader.upgrade_connection(raw_conn, True)
muxed_conn = self.upgrader.upgrade_connection(raw_conn)
# Store muxed connection in connections
self.connections[peer_id] = muxed_conn
@ -118,8 +118,8 @@ class Swarm(INetwork):
# Upgrade reader/write to a net_stream and pass \
# to appropriate stream handler (using multiaddr)
raw_conn = RawConnection(multiaddr.value_for_protocol('ip4'),
multiaddr.value_for_protocol('tcp'), reader, writer)
muxed_conn = self.upgrader.upgrade_connection(raw_conn, False)
multiaddr.value_for_protocol('tcp'), reader, writer, False)
muxed_conn = self.upgrader.upgrade_connection(raw_conn)
# TODO: Remove protocol id from muxed_conn accept stream or
# move protocol muxing into accept_stream

View File

@ -10,36 +10,26 @@ class Mplex(IMuxedConn):
reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go
"""
def __init__(self, conn, initiator):
def __init__(self, conn):
"""
create a new muxed connection
:param conn: an instance of raw connection
:param initiator: boolean to prevent multiplex with self
"""
self.raw_conn = conn
self.initiator = initiator
self.initiator = conn.initiator
# Mapping from stream ID -> buffer of messages for that stream
self.buffers = {}
self.stream_queue = asyncio.Queue()
self._next_id = 0
self.data_buffer = bytearray()
# The initiator need not read upon construction time.
# The initiator of the raw connection need not read upon construction time.
# It should read when the user decides that it wants to read from the constructed stream.
if not initiator:
if not self.initiator:
asyncio.ensure_future(self.handle_incoming())
def _next_stream_id(self):
"""
Get next available stream id
:return: next available stream id for the connection
"""
next_id = self._next_id
self._next_id += 1
return next_id
def close(self):
"""
close the stream muxer and underlying raw connection
@ -88,7 +78,7 @@ class Mplex(IMuxedConn):
:param multi_addr: multi_addr that stream connects to
:return: a new stream
"""
stream_id = self._next_stream_id()
stream_id = self.raw_conn.next_stream_id()
stream = MplexStream(stream_id, multi_addr, self)
self.buffers[stream_id] = asyncio.Queue()
return stream

View File

@ -80,6 +80,54 @@ async def test_double_response():
# Success, terminate pending tasks.
return
@pytest.mark.asyncio
async def test_multiple_streams():
# Node A should be able to open a stream with node B and then vice versa.
# Stream IDs should be generated uniquely so that the stream state is not overwritten
node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8004"])
node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8005"])
async def stream_handler_a(stream):
while True:
read_string = (await stream.read()).decode()
response = "ack_a:" + read_string
await stream.write(response.encode())
async def stream_handler_b(stream):
while True:
read_string = (await stream.read()).decode()
response = "ack_b:" + read_string
await stream.write(response.encode())
node_a.set_stream_handler("/echo_a/1.0.0", stream_handler_a)
node_b.set_stream_handler("/echo_b/1.0.0", stream_handler_b)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
stream_a = await node_a.new_stream(node_b.get_id(), ["/echo_b/1.0.0"])
stream_b = await node_b.new_stream(node_a.get_id(), ["/echo_a/1.0.0"])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a_message = message + "_a"
b_message = message + "_b"
await stream_a.write(a_message.encode())
await stream_b.write(b_message.encode())
response_a = (await stream_a.read()).decode()
response_b = (await stream_b.read()).decode()
assert response_a == ("ack_b:" + a_message) and response_b == ("ack_a:" + b_message)
# Success, terminate pending tasks.
return
@pytest.mark.asyncio
async def test_host_connect():
node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/8001/"])

View File

@ -1,9 +1,6 @@
import asyncio
import pytest
# from network.connection.raw_connection import RawConnection
async def handle_echo(reader, writer):
data = await reader.read(100)
writer.write(data)
@ -20,7 +17,6 @@ async def test_simple_echo():
await asyncio.start_server(handle_echo, server_ip, server_port)
reader, writer = await asyncio.open_connection(server_ip, server_port)
# raw_connection = RawConnection(server_ip, server_port, reader, writer)
test_message = "hello world"
writer.write(test_message.encode())

View File

@ -69,7 +69,7 @@ class TCP(ITransport):
reader, writer = await asyncio.open_connection(host, port)
return RawConnection(host, port, reader, writer)
return RawConnection(host, port, reader, writer, True)
def create_listener(self, handler_function, options=None):
"""

View File

@ -17,11 +17,11 @@ class TransportUpgrader():
def upgrade_security(self):
pass
def upgrade_connection(self, conn, initiator):
def upgrade_connection(self, conn):
"""
upgrade raw connection to muxed connection
"""
# For PoC, no security, default to mplex
# TODO do exchange to determine multiplexer
return Mplex(conn, initiator)
return Mplex(conn)