diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 9fd77aa..39e7d36 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -75,7 +75,7 @@ async def run(port, destination): # Start a stream with the destination. # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. - stream = await host.new_stream(peer_id, PROTOCOL_ID) + stream = await host.new_stream(peer_id, [PROTOCOL_ID]) asyncio.ensure_future(read_data(stream)) asyncio.ensure_future(write_data(stream)) diff --git a/host/basic_host.py b/host/basic_host.py index bcda2fe..9734129 100644 --- a/host/basic_host.py +++ b/host/basic_host.py @@ -48,13 +48,11 @@ class BasicHost(IHost): # protocol_id can be a list of protocol_ids # stream will decide which protocol_id to run on - async def new_stream(self, peer_id, protocol_id): + async def new_stream(self, peer_id, protocol_ids): """ :param peer_id: peer_id that host is connecting :param protocol_id: protocol id that stream runs on :return: true if successful """ - # TODO: host should return a mux stream not a raw stream - stream = await self.network.new_stream(peer_id, protocol_id) - stream.set_protocol(protocol_id) + stream = await self.network.new_stream(peer_id, protocol_ids) return stream diff --git a/host/host_interface.py b/host/host_interface.py index 8a26b4b..a65f8a7 100644 --- a/host/host_interface.py +++ b/host/host_interface.py @@ -33,9 +33,9 @@ class IHost(ABC): # protocol_id can be a list of protocol_ids # stream will decide which protocol_id to run on @abstractmethod - def new_stream(self, peer_id, protocol_id): + def new_stream(self, peer_id, protocol_ids): """ :param peer_id: peer_id that host is connecting - :param proto_id: protocol id that stream runs on + :param protocol_ids: protocol ids that stream can run on :return: true if successful """ diff --git a/network/network_interface.py b/network/network_interface.py index 189814f..d690beb 100644 --- a/network/network_interface.py +++ b/network/network_interface.py @@ -18,11 +18,11 @@ class INetwork(ABC): """ @abstractmethod - def new_stream(self, peer_id, protocol_id): + def new_stream(self, peer_id, protocol_ids): """ :param peer_id: peer_id of destination - :param protocol_id: protocol id - :return: stream instance + :param protocol_ids: available protocol ids to use for stream + :return: net stream instance """ @abstractmethod diff --git a/network/swarm.py b/network/swarm.py index 9ffc766..57d4302 100644 --- a/network/swarm.py +++ b/network/swarm.py @@ -1,9 +1,12 @@ from peer.id import ID +from protocol_muxer.multiselect_client import MultiselectClient +from protocol_muxer.multiselect import Multiselect from .network_interface import INetwork from .stream.net_stream import NetStream from .multiaddr import MultiAddr from .connection.raw_connection import RawConnection + class Swarm(INetwork): # pylint: disable=too-many-instance-attributes, cell-var-from-loop @@ -17,6 +20,10 @@ class Swarm(INetwork): self.stream_handlers = dict() self.transport = None + # Protocol muxing + self.multiselect = Multiselect() + self.multiselect_client = MultiselectClient() + def get_peer_id(self): return self.self_id @@ -26,9 +33,10 @@ class Swarm(INetwork): :param stream_handler: a stream handler instance :return: true if successful """ - self.stream_handlers[protocol_id] = stream_handler + self.multiselect.add_handler(protocol_id, stream_handler) + return True - async def new_stream(self, peer_id, protocol_id): + async def new_stream(self, peer_id, protocol_ids): """ :param peer_id: peer_id of destination :param protocol_id: protocol id @@ -58,10 +66,15 @@ class Swarm(INetwork): # Use muxed conn to open stream, which returns # a muxed stream - muxed_stream = await muxed_conn.open_stream(protocol_id, peer_id, multiaddr) + # TODO: Remove protocol id from being passed into muxed_conn + muxed_stream = await muxed_conn.open_stream(protocol_ids[0], peer_id, multiaddr) - # Create a net stream + # Perform protocol muxing to determine protocol to use + selected_protocol = await self.multiselect_client.select_one_of(protocol_ids, muxed_stream) + + # Create a net stream with the selected protocol net_stream = NetStream(muxed_stream) + net_stream.set_protocol(selected_protocol) return net_stream @@ -93,14 +106,20 @@ class Swarm(INetwork): multiaddr_dict['port'], reader, writer) muxed_conn = self.upgrader.upgrade_connection(raw_conn, False) - muxed_stream, _, protocol_id = await muxed_conn.accept_stream() + # TODO: Remove protocol id from muxed_conn accept stream or + # move protocol muxing into accept_stream + muxed_stream, _, _ = await muxed_conn.accept_stream() + + # Perform protocol muxing to determine protocol to use + selected_protocol, handler = await self.multiselect.negotiate(muxed_stream) + net_stream = NetStream(muxed_stream) - net_stream.set_protocol(protocol_id) + net_stream.set_protocol(selected_protocol) # Give to stream handler # TODO: handle case when stream handler is set # TODO: handle case of multiple protocols over same raw connection - await self.stream_handlers[protocol_id](net_stream) + await handler(net_stream) try: # Success diff --git a/protocol_muxer/__init__.py b/protocol_muxer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/protocol_muxer/multiselect.py b/protocol_muxer/multiselect.py new file mode 100644 index 0000000..b2f4532 --- /dev/null +++ b/protocol_muxer/multiselect.py @@ -0,0 +1,92 @@ +from .multiselect_muxer_interface import IMultiselectMuxer +from .multiselect_communicator import MultiselectCommunicator + +MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0" +PROTOCOL_NOT_FOUND_MSG = "na" + +class Multiselect(IMultiselectMuxer): + """ + Multiselect module that is responsible for responding to + a multiselect client and deciding on + a specific protocol and handler pair to use for communication + """ + + def __init__(self): + self.handlers = {} + + def add_handler(self, protocol, handler): + """ + Store the handler with the given protocol + :param protocol: protocol name + :param handler: handler function + """ + self.handlers[protocol] = handler + + async def negotiate(self, stream): + """ + Negotiate performs protocol selection + :param stream: stream to negotiate on + :return: selected protocol name, handler function + :raise Exception: negotiation failed exception + """ + + # Create a communicator to handle all communication across the stream + communicator = MultiselectCommunicator(stream) + + # Perform handshake to ensure multiselect protocol IDs match + await self.handshake(communicator) + + # Read and respond to commands until a valid protocol ID is sent + while True: + # Read message + command = await communicator.read_stream_until_eof() + + # Command is ls or a protocol + if command == "ls": + # TODO: handle ls command + pass + else: + protocol = command + if protocol in self.handlers: + # Tell counterparty we have decided on a protocol + await communicator.write(protocol) + + # Return the decided on protocol + return protocol, self.handlers[protocol] + # Tell counterparty this protocol was not found + await communicator.write(PROTOCOL_NOT_FOUND_MSG) + + async def handshake(self, communicator): + """ + Perform handshake to agree on multiselect protocol + :param communicator: communicator to use + :raise Exception: error in handshake + """ + + # TODO: Use format used by go repo for messages + + # Send our MULTISELECT_PROTOCOL_ID to other party + await communicator.write(MULTISELECT_PROTOCOL_ID) + + # Read in the protocol ID from other party + handshake_contents = await communicator.read_stream_until_eof() + + # Confirm that the protocols are the same + if not self.validate_handshake(handshake_contents): + raise MultiselectError("multiselect protocol ID mismatch") + + # Handshake succeeded if this point is reached + + def validate_handshake(self, handshake_contents): + """ + Determine if handshake is valid and should be confirmed + :param handshake_contents: contents of handshake message + :return: true if handshake is complete, false otherwise + """ + + # TODO: Modify this when format used by go repo for messages + # is added + return handshake_contents == MULTISELECT_PROTOCOL_ID + +class MultiselectError(ValueError): + """Raised when an error occurs in multiselect process""" diff --git a/protocol_muxer/multiselect_client.py b/protocol_muxer/multiselect_client.py new file mode 100644 index 0000000..c455877 --- /dev/null +++ b/protocol_muxer/multiselect_client.py @@ -0,0 +1,121 @@ +from .multiselect_client_interface import IMultiselectClient +from .multiselect_communicator import MultiselectCommunicator + +MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0" +PROTOCOL_NOT_FOUND_MSG = "na" + +class MultiselectClient(IMultiselectClient): + """ + Client for communicating with receiver's multiselect + module in order to select a protocol id to communicate over + """ + + def __init__(self): + pass + + async def handshake(self, communicator): + """ + Ensure that the client and multiselect + are both using the same multiselect protocol + :param stream: stream to communicate with multiselect over + :raise Exception: multiselect protocol ID mismatch + """ + + # TODO: Use format used by go repo for messages + + # Send our MULTISELECT_PROTOCOL_ID to counterparty + await communicator.write(MULTISELECT_PROTOCOL_ID) + + # Read in the protocol ID from other party + handshake_contents = await communicator.read_stream_until_eof() + + # Confirm that the protocols are the same + if not self.validate_handshake(handshake_contents): + raise MultiselectClientError("multiselect protocol ID mismatch") + + # Handshake succeeded if this point is reached + + def validate_handshake(self, handshake_contents): + """ + Determine if handshake is valid and should be confirmed + :param handshake_contents: contents of handshake message + :return: true if handshake is complete, false otherwise + """ + + # TODO: Modify this when format used by go repo for messages + # is added + return handshake_contents == MULTISELECT_PROTOCOL_ID + + async def select_protocol_or_fail(self, protocol, stream): + """ + Send message to multiselect selecting protocol + and fail if multiselect does not return same protocol + :param protocol: protocol to select + :param stream: stream to communicate with multiselect over + :return: selected protocol + """ + + # Create a communicator to handle all communication across the stream + communicator = MultiselectCommunicator(stream) + + # Perform handshake to ensure multiselect protocol IDs match + await self.handshake(communicator) + + # Try to select the given protocol + selected_protocol = await self.try_select(communicator, protocol) + + return selected_protocol + + async def select_one_of(self, protocols, stream): + """ + For each protocol, send message to multiselect selecting protocol + and fail if multiselect does not return same protocol. Returns first + protocol that multiselect agrees on (i.e. that multiselect selects) + :param protocol: protocol to select + :param stream: stream to communicate with multiselect over + :return: selected protocol + """ + + # Create a communicator to handle all communication across the stream + communicator = MultiselectCommunicator(stream) + + # Perform handshake to ensure multiselect protocol IDs match + await self.handshake(communicator) + + # For each protocol, attempt to select that protocol + # and return the first protocol selected + for protocol in protocols: + try: + selected_protocol = await self.try_select(communicator, protocol) + return selected_protocol + except MultiselectClientError: + pass + + # No protocols were found, so return no protocols supported error + raise MultiselectClientError("protocols not supported") + + async def try_select(self, communicator, protocol): + """ + Try to select the given protocol or raise exception if fails + :param communicator: communicator to use to communicate with counterparty + :param protocol: protocol to select + :raise Exception: error in protocol selection + :return: selected protocol + """ + + # Tell counterparty we want to use protocol + await communicator.write(protocol) + + # Get what counterparty says in response + response = await communicator.read_stream_until_eof() + + # Return protocol if response is equal to protocol or raise error + if response == protocol: + return protocol + if response == PROTOCOL_NOT_FOUND_MSG: + raise MultiselectClientError("protocol not supported") + else: + raise MultiselectClientError("unrecognized response: " + response) + +class MultiselectClientError(ValueError): + """Raised when an error occurs in protocol selection process""" diff --git a/protocol_muxer/multiselect_client_interface.py b/protocol_muxer/multiselect_client_interface.py new file mode 100644 index 0000000..1c8d066 --- /dev/null +++ b/protocol_muxer/multiselect_client_interface.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod + +class IMultiselectClient(ABC): + """ + Client for communicating with receiver's multiselect + module in order to select a protocol id to communicate over + """ + + @abstractmethod + def select_protocol_or_fail(self, protocol, stream): + """ + Send message to multiselect selecting protocol + and fail if multiselect does not return same protocol + :param protocol: protocol to select + :param stream: stream to communicate with multiselect over + :return: selected protocol + """ + + @abstractmethod + def select_one_of(self, protocols, stream): + """ + For each protocol, send message to multiselect selecting protocol + and fail if multiselect does not return same protocol. Returns first + protocol that multiselect agrees on (i.e. that multiselect selects) + :param protocol: protocol to select + :param stream: stream to communicate with multiselect over + :return: selected protocol + """ diff --git a/protocol_muxer/multiselect_communicator.py b/protocol_muxer/multiselect_communicator.py new file mode 100644 index 0000000..bfdeabf --- /dev/null +++ b/protocol_muxer/multiselect_communicator.py @@ -0,0 +1,25 @@ +from .multiselect_communicator_interface import IMultiselectCommunicator + +class MultiselectCommunicator(IMultiselectCommunicator): + """ + Communicator helper class that ensures both the client + and multistream module will follow the same multistream protocol, + which is necessary for them to work + """ + + def __init__(self, stream): + self.stream = stream + + async def write(self, msg_str): + """ + Write message to stream + :param msg_str: message to write + """ + await self.stream.write(msg_str.encode()) + + async def read_stream_until_eof(self): + """ + Reads message from stream until EOF + """ + read_str = (await self.stream.read()).decode() + return read_str diff --git a/protocol_muxer/multiselect_communicator_interface.py b/protocol_muxer/multiselect_communicator_interface.py new file mode 100644 index 0000000..a9990fd --- /dev/null +++ b/protocol_muxer/multiselect_communicator_interface.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + +class IMultiselectCommunicator(ABC): + """ + Communicator helper class that ensures both the client + and multistream module will follow the same multistream protocol, + which is necessary for them to work + """ + + @abstractmethod + def write(self, msg_str): + """ + Write message to stream + :param msg_str: message to write + """ + + @abstractmethod + def read_stream_until_eof(self): + """ + Reads message from stream until EOF + """ diff --git a/protocol_muxer/multiselect_muxer_interface.py b/protocol_muxer/multiselect_muxer_interface.py new file mode 100644 index 0000000..0f08fec --- /dev/null +++ b/protocol_muxer/multiselect_muxer_interface.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + +class IMultiselectMuxer(ABC): + """ + Multiselect module that is responsible for responding to + a multiselect client and deciding on + a specific protocol and handler pair to use for communication + """ + + @abstractmethod + def add_handler(self, protocol, handler): + """ + Store the handler with the given protocol + :param protocol: protocol name + :param handler: handler function + """ + + @abstractmethod + def negotiate(self, stream): + """ + Negotiate performs protocol selection + :param stream: stream to negotiate on + :return: selected protocol name, handler function + :raise Exception: negotiation failed exception + """ diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index da5b6ff..25c0409 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -20,8 +20,10 @@ async def test_simple_messages(): # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addr("node_b", "/ip4/127.0.0.1/tcp/8000", 10) + print("node_a about to open stream") - stream = await node_a.new_stream("node_b", "/echo/1.0.0") + stream = await node_a.new_stream("node_b", ["/echo/1.0.0"]) + messages = ["hello" + str(x) for x in range(10)] for message in messages: await stream.write(message.encode()) @@ -57,7 +59,7 @@ async def test_double_response(): # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addr("node_b", "/ip4/127.0.0.1/tcp/8003", 10) print("node_a about to open stream") - stream = await node_a.new_stream("node_b", "/echo/1.0.0") + stream = await node_a.new_stream("node_b", ["/echo/1.0.0"]) messages = ["hello" + str(x) for x in range(10)] for message in messages: await stream.write(message.encode()) diff --git a/tests/protocol_muxer/__init__.py b/tests/protocol_muxer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py new file mode 100644 index 0000000..ddcbfec --- /dev/null +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -0,0 +1,88 @@ +import pytest + +from libp2p.libp2p import new_node +from protocol_muxer.multiselect_client import MultiselectClientError + +# TODO: Add tests for multiple streams being opened on different +# protocols through the same connection + +# Note: async issues occurred when using the same port +# so that's why I use different ports here. +# TODO: modify tests so that those async issues don't occur +# when using the same ports across tests + +async def perform_simple_test(expected_selected_protocol, \ + protocols_for_client, protocols_with_handlers, \ + node_a_port, node_b_port): + transport_opt_a = ["/ip4/127.0.0.1/tcp/" + str(node_a_port) + "/ipfs/node_a"] + transport_opt_b = ["/ip4/127.0.0.1/tcp/" + str(node_b_port) + "/ipfs/node_b"] + node_a = await new_node(\ + transport_opt=transport_opt_a) + node_b = await new_node(\ + transport_opt=transport_opt_b) + + async def stream_handler(stream): + while True: + read_string = (await stream.read()).decode() + print("host B received:" + read_string) + + response = "ack:" + read_string + print("sending response:" + response) + await stream.write(response.encode()) + + for protocol in protocols_with_handlers: + node_b.set_stream_handler(protocol, stream_handler) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + node_a.get_peerstore().add_addr("node_b", "/ip4/127.0.0.1/tcp/" + str(node_b_port), 10) + + stream = await node_a.new_stream("node_b", protocols_for_client) + messages = ["hello" + str(x) for x in range(10)] + for message in messages: + await stream.write(message.encode()) + + response = (await stream.read()).decode() + + print("res: " + response) + assert response == ("ack:" + message) + + assert expected_selected_protocol == stream.get_protocol() + + # Success, terminate pending tasks. + return + +@pytest.mark.asyncio +async def test_single_protocol_succeeds(): + expected_selected_protocol = "/echo/1.0.0" + await perform_simple_test(expected_selected_protocol, \ + ["/echo/1.0.0"], ["/echo/1.0.0"], 8050, 8051) + +@pytest.mark.asyncio +async def test_single_protocol_fails(): + with pytest.raises(MultiselectClientError): + await perform_simple_test("", ["/echo/1.0.0"], \ + ["/potato/1.0.0"], 8052, 8053) + +@pytest.mark.asyncio +async def test_multiple_protocol_first_is_valid_succeeds(): + expected_selected_protocol = "/echo/1.0.0" + protocols_for_client = ["/echo/1.0.0", "/potato/1.0.0"] + protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"] + await perform_simple_test(expected_selected_protocol, protocols_for_client, \ + protocols_for_listener, 8054, 8055) + +@pytest.mark.asyncio +async def test_multiple_protocol_second_is_valid_succeeds(): + expected_selected_protocol = "/foo/1.0.0" + protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0"] + protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"] + await perform_simple_test(expected_selected_protocol, protocols_for_client, \ + protocols_for_listener, 8056, 8057) + +@pytest.mark.asyncio +async def test_multiple_protocol_fails(): + protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"] + protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"] + with pytest.raises(MultiselectClientError): + await perform_simple_test("", protocols_for_client, \ + protocols_for_listener, 8058, 8059)