From 5307c0506ba7b4cf1f32482d7c765043d54a4a2d Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 15 Sep 2019 21:41:29 +0800 Subject: [PATCH 01/14] Change `IMuxedConn` to `INetConn` in `Notifee` --- libp2p/network/connection/swarm_connection.py | 2 +- libp2p/network/notifee_interface.py | 6 ++-- libp2p/network/stream/net_stream.py | 6 ++-- libp2p/network/stream/net_stream_interface.py | 2 +- libp2p/network/swarm.py | 3 +- libp2p/pubsub/pubsub.py | 2 +- libp2p/pubsub/pubsub_notifee.py | 8 +++--- libp2p/stream_muxer/abc.py | 2 +- libp2p/stream_muxer/mplex/mplex_stream.py | 28 +++++++++---------- tests/libp2p/test_notify.py | 22 +++++++-------- tests/pubsub/test_pubsub.py | 2 +- 11 files changed, 40 insertions(+), 43 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 15816fc..10b83a1 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -77,7 +77,7 @@ class SwarmConn(INetConn): async def _notify_disconnected(self) -> None: for notifee in self.swarm.notifees: - await notifee.disconnected(self.swarm, self.conn) + await notifee.disconnected(self.swarm, self) async def start(self) -> None: await self.run_task(self._handle_new_streams()) diff --git a/libp2p/network/notifee_interface.py b/libp2p/network/notifee_interface.py index ef996bf..c31f473 100644 --- a/libp2p/network/notifee_interface.py +++ b/libp2p/network/notifee_interface.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING from multiaddr import Multiaddr +from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream_interface import INetStream -from libp2p.stream_muxer.abc import IMuxedConn if TYPE_CHECKING: from .network_interface import INetwork # noqa: F401 @@ -26,14 +26,14 @@ class INotifee(ABC): """ @abstractmethod - async def connected(self, network: "INetwork", conn: IMuxedConn) -> None: + async def connected(self, network: "INetwork", conn: INetConn) -> None: """ :param network: network the connection was opened on :param conn: connection that was opened """ @abstractmethod - async def disconnected(self, network: "INetwork", conn: IMuxedConn) -> None: + async def disconnected(self, network: "INetwork", conn: INetConn) -> None: """ :param network: network the connection was closed on :param conn: connection that was closed diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index d500c08..3ae7c9c 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,4 +1,4 @@ -from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.abc import IMuxedStream from libp2p.stream_muxer.exceptions import ( MuxedStreamClosed, MuxedStreamEOF, @@ -16,13 +16,11 @@ from .net_stream_interface import INetStream class NetStream(INetStream): muxed_stream: IMuxedStream - # TODO: Why we expose `mplex_conn` here? 
- mplex_conn: IMuxedConn protocol_id: TProtocol def __init__(self, muxed_stream: IMuxedStream) -> None: self.muxed_stream = muxed_stream - self.mplex_conn = muxed_stream.mplex_conn + self.muxed_conn = muxed_stream.muxed_conn self.protocol_id = None def get_protocol(self) -> TProtocol: diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index d054789..41bf423 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -7,7 +7,7 @@ from libp2p.typing import TProtocol class INetStream(ReadWriteCloser): - mplex_conn: IMuxedConn + muxed_conn: IMuxedConn @abstractmethod def get_protocol(self) -> TProtocol: diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 272a8a9..a0d9d77 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -278,8 +278,7 @@ class Swarm(INetwork): self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred for notifee in self.notifees: - # TODO: Call with other type of conn? - await notifee.connected(self, muxed_conn) + await notifee.connected(self, swarm_conn) await swarm_conn.start() return swarm_conn diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index b162b89..17cf034 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -151,7 +151,7 @@ class Pubsub: messages from other nodes :param stream: stream to continously read from """ - peer_id = stream.mplex_conn.peer_id + peer_id = stream.muxed_conn.peer_id while True: incoming: bytes = await read_varint_prefixed_bytes(stream) diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 6ecab1a..627152e 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING from multiaddr import Multiaddr +from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.network_interface import INetwork from libp2p.network.notifee_interface import INotifee from libp2p.network.stream.net_stream_interface import INetStream -from libp2p.stream_muxer.abc import IMuxedConn if TYPE_CHECKING: import asyncio # noqa: F401 @@ -29,16 +29,16 @@ class PubsubNotifee(INotifee): async def closed_stream(self, network: INetwork, stream: INetStream) -> None: pass - async def connected(self, network: INetwork, conn: IMuxedConn) -> None: + async def connected(self, network: INetwork, conn: INetConn) -> None: """ Add peer_id to initiator_peers_queue, so that this peer_id can be used to create a stream and we only want to have one pubsub stream with each peer. 
:param network: network the connection was opened on :param conn: connection that was opened """ - await self.initiator_peers_queue.put(conn.peer_id) + await self.initiator_peers_queue.put(conn.conn.peer_id) - async def disconnected(self, network: INetwork, conn: IMuxedConn) -> None: + async def disconnected(self, network: INetwork, conn: INetConn) -> None: pass async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 78438f2..4af110b 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -55,7 +55,7 @@ class IMuxedConn(ABC): class IMuxedStream(ReadWriteCloser): - mplex_conn: IMuxedConn + muxed_conn: IMuxedConn @abstractmethod async def reset(self) -> None: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 8cabccc..06b90fa 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -18,7 +18,7 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - mplex_conn: "Mplex" + muxed_conn: "Mplex" read_deadline: int write_deadline: int @@ -32,15 +32,15 @@ class MplexStream(IMuxedStream): _buf: bytearray - def __init__(self, name: str, stream_id: StreamID, mplex_conn: "Mplex") -> None: + def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None: """ create new MuxedStream in muxer :param stream_id: stream id of this stream - :param mplex_conn: muxed connection of this muxed_stream + :param muxed_conn: muxed connection of this muxed_stream """ self.name = name self.stream_id = stream_id - self.mplex_conn = mplex_conn + self.muxed_conn = muxed_conn self.read_deadline = None self.write_deadline = None self.event_local_closed = asyncio.Event() @@ -147,7 +147,7 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.MessageReceiver ) - return await self.mplex_conn.send_message(flag, data, self.stream_id) + return await self.muxed_conn.send_message(flag, data, self.stream_id) async def close(self) -> None: """ @@ -163,8 +163,8 @@ class MplexStream(IMuxedStream): flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) - # TODO: Raise when `mplex_conn.send_message` fails and `Mplex` isn't shutdown. - await self.mplex_conn.send_message(flag, None, self.stream_id) + # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. + await self.muxed_conn.send_message(flag, None, self.stream_id) _is_remote_closed: bool async with self.close_lock: @@ -173,8 +173,8 @@ class MplexStream(IMuxedStream): if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. 
- async with self.mplex_conn.streams_lock: - del self.mplex_conn.streams[self.stream_id] + async with self.muxed_conn.streams_lock: + del self.muxed_conn.streams[self.stream_id] async def reset(self) -> None: """ @@ -196,19 +196,19 @@ class MplexStream(IMuxedStream): else HeaderTags.ResetReceiver ) asyncio.ensure_future( - self.mplex_conn.send_message(flag, None, self.stream_id) + self.muxed_conn.send_message(flag, None, self.stream_id) ) await asyncio.sleep(0) self.event_local_closed.set() self.event_remote_closed.set() - async with self.mplex_conn.streams_lock: + async with self.muxed_conn.streams_lock: if ( - self.mplex_conn.streams is not None - and self.stream_id in self.mplex_conn.streams + self.muxed_conn.streams is not None + and self.stream_id in self.muxed_conn.streams ): - del self.mplex_conn.streams[self.stream_id] + del self.muxed_conn.streams[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index b9a8707..e47a044 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -34,7 +34,7 @@ class MyNotifee(INotifee): pass async def connected(self, network, conn): - self.events.append(["connected" + self.val_to_append_to_event, conn]) + self.events.append(["connected" + self.val_to_append_to_event, conn.conn]) async def disconnected(self, network, conn): pass @@ -79,7 +79,7 @@ async def test_one_notifier(): # Ensure the connected and opened_stream events were hit in MyNotifee obj # and that stream passed into opened_stream matches the stream created on # node_a - assert events == [["connected0", stream.mplex_conn], ["opened_stream0", stream]] + assert events == [["connected0", stream.muxed_conn], ["opened_stream0", stream]] messages = ["hello", "hello"] for message in messages: @@ -103,7 +103,7 @@ async def test_one_notifier_on_two_nodes(): # and that the stream passed into opened_stream matches the stream created on # node_b assert events_b == [ - ["connectedb", stream.mplex_conn], + ["connectedb", stream.muxed_conn], ["opened_streamb", stream], ] for message in messages: @@ -126,7 +126,7 @@ async def test_one_notifier_on_two_nodes(): # Ensure the connected and opened_stream events were hit in MyNotifee obj # and that stream passed into opened_stream matches the stream created on # node_a - assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] + assert events_a == [["connecteda", stream.muxed_conn], ["opened_streama", stream]] for message in messages: expected_resp = ACK + message @@ -164,7 +164,7 @@ async def test_one_notifier_on_two_nodes_with_listen(): # node_b assert events_b == [ ["listenedb", node_b_multiaddr], - ["connectedb", stream.mplex_conn], + ["connectedb", stream.muxed_conn], ["opened_streamb", stream], ] for message in messages: @@ -190,7 +190,7 @@ async def test_one_notifier_on_two_nodes_with_listen(): # Ensure the connected and opened_stream events were hit in MyNotifee obj # and that stream passed into opened_stream matches the stream created on # node_a - assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] + assert events_a == [["connecteda", stream.muxed_conn], ["opened_streama", stream]] for message in messages: expected_resp = ACK + message @@ -219,8 +219,8 @@ async def test_two_notifiers(): # Ensure the connected and opened_stream events were hit in both Notifee objs # and that the stream passed into opened_stream matches the stream created on # node_a - assert events0 == 
[["connected0", stream.mplex_conn], ["opened_stream0", stream]] - assert events1 == [["connected1", stream.mplex_conn], ["opened_stream1", stream]] + assert events0 == [["connected0", stream.muxed_conn], ["opened_stream0", stream]] + assert events1 == [["connected1", stream.muxed_conn], ["opened_stream1", stream]] messages = ["hello", "hello"] for message in messages: @@ -253,7 +253,7 @@ async def test_ten_notifiers(): # node_a for i in range(num_notifiers): assert events_lst[i] == [ - ["connected" + str(i), stream.mplex_conn], + ["connected" + str(i), stream.muxed_conn], ["opened_stream" + str(i), stream], ] @@ -280,7 +280,7 @@ async def test_ten_notifiers_on_two_nodes(): # node_b for i in range(num_notifiers): assert events_lst_b[i] == [ - ["connectedb" + str(i), stream.mplex_conn], + ["connectedb" + str(i), stream.muxed_conn], ["opened_streamb" + str(i), stream], ] while True: @@ -306,7 +306,7 @@ async def test_ten_notifiers_on_two_nodes(): # node_a for i in range(num_notifiers): assert events_lst_a[i] == [ - ["connecteda" + str(i), stream.mplex_conn], + ["connecteda" + str(i), stream.muxed_conn], ["opened_streama" + str(i), stream], ] diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 3413949..29fdf36 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -233,7 +233,7 @@ class FakeNetStream: class FakeMplexConn(NamedTuple): peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32) - mplex_conn = FakeMplexConn() + muxed_conn = FakeMplexConn() def __init__(self) -> None: self._queue = asyncio.Queue() From 675c61ce3b527ba27c74c837950298a4f136f673 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 15 Sep 2019 21:45:44 +0800 Subject: [PATCH 02/14] Move test_notify from libp2p to network --- tests/{libp2p => network}/test_notify.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{libp2p => network}/test_notify.py (100%) diff --git a/tests/libp2p/test_notify.py b/tests/network/test_notify.py similarity index 100% rename from tests/libp2p/test_notify.py rename to tests/network/test_notify.py From b8b5ac5e06026fed267ac6972cd572098a04b833 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 17 Sep 2019 21:54:20 +0800 Subject: [PATCH 03/14] Add test for notifee disconnected --- libp2p/host/basic_host.py | 2 +- libp2p/network/network_interface.py | 8 +- libp2p/network/swarm.py | 15 +- libp2p/pubsub/pubsub.py | 2 +- tests/network/test_notify.py | 393 ++++++---------------------- 5 files changed, 87 insertions(+), 333 deletions(-) diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 862fd5c..5f0ccfd 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -100,7 +100,7 @@ class BasicHost(IHost): :return: stream: new stream created """ - net_stream = await self._network.new_stream(peer_id, protocol_ids) + net_stream = await self._network.new_stream(peer_id) # Perform protocol muxing to determine protocol to use selected_protocol = await self.multiselect_client.select_one_of( diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index 470da1a..94ddba2 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -7,7 +7,7 @@ from libp2p.network.connection.net_connection_interface import INetConn from libp2p.peer.id import ID from libp2p.peer.peerstore_interface import IPeerStore from libp2p.transport.listener_interface import IListener -from libp2p.typing import StreamHandlerFn, TProtocol +from libp2p.typing import StreamHandlerFn from 
.stream.net_stream_interface import INetStream @@ -38,9 +38,7 @@ class INetwork(ABC): """ @abstractmethod - async def new_stream( - self, peer_id: ID, protocol_ids: Sequence[TProtocol] - ) -> INetStream: + async def new_stream(self, peer_id: ID) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_ids: available protocol ids to use for stream @@ -61,7 +59,7 @@ class INetwork(ABC): """ @abstractmethod - def notify(self, notifee: "INotifee") -> bool: + def register_notifee(self, notifee: "INotifee") -> None: """ :param notifee: object implementing Notifee interface :return: true if notifee registered successfully, false otherwise diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index a0d9d77..18218d6 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Optional from multiaddr import Multiaddr @@ -14,7 +14,7 @@ from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFail from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.upgrader import TransportUpgrader -from libp2p.typing import StreamHandlerFn, TProtocol +from libp2p.typing import StreamHandlerFn from .connection.raw_connection import RawConnection from .connection.swarm_connection import SwarmConn @@ -131,9 +131,7 @@ class Swarm(INetwork): return swarm_conn - async def new_stream( - self, peer_id: ID, protocol_ids: Sequence[TProtocol] - ) -> INetStream: + async def new_stream(self, peer_id: ID) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_id: protocol id @@ -229,15 +227,12 @@ class Swarm(INetwork): # No maddr succeeded return False - def notify(self, notifee: INotifee) -> bool: + def register_notifee(self, notifee: INotifee) -> None: """ :param notifee: object implementing Notifee interface :return: true if notifee registered successfully, false otherwise """ - if isinstance(notifee, INotifee): - self.notifees.append(notifee) - return True - return False + self.notifees.append(notifee) def add_router(self, router: IPeerRouting) -> None: self.router = router diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 17cf034..a295b69 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -95,7 +95,7 @@ class Pubsub: # Register a notifee self.peer_queue = asyncio.Queue() - self.host.get_network().notify(PubsubNotifee(self.peer_queue)) + self.host.get_network().register_notifee(PubsubNotifee(self.peer_queue)) # Register stream handlers for each pubsub router protocol to handle # the pubsub streams opened on those protocols diff --git a/tests/network/test_notify.py b/tests/network/test_notify.py index e47a044..aaf0ed5 100644 --- a/tests/network/test_notify.py +++ b/tests/network/test_notify.py @@ -5,347 +5,108 @@ called, and that the stream passed into opened_stream is correct Note: Listen event does not get hit because MyNotifee is passed into network after network has already started listening -TODO: Add tests for closed_stream disconnected, listen_close when those +TODO: Add tests for closed_stream, listen_close when those features are implemented in swarm """ -import multiaddr +import asyncio +import enum + import pytest -from libp2p import initialize_default_swarm, new_node -from libp2p.crypto.rsa import create_new_key_pair -from libp2p.host.basic_host import BasicHost from 
libp2p.network.notifee_interface import INotifee -from tests.constants import MAX_READ_LEN -from tests.utils import perform_two_host_set_up +from tests.configs import LISTEN_MADDR +from tests.factories import SwarmFactory +from tests.utils import connect_swarm -ACK = "ack:" + +class Event(enum.Enum): + OpenedStream = 0 + ClosedStream = 1 # Not implemented + Connected = 2 + Disconnected = 3 + Listen = 4 + ListenClose = 5 # Not implemented class MyNotifee(INotifee): - def __init__(self, events, val_to_append_to_event): + def __init__(self, events): self.events = events - self.val_to_append_to_event = val_to_append_to_event async def opened_stream(self, network, stream): - self.events.append(["opened_stream" + self.val_to_append_to_event, stream]) + self.events.append(Event.OpenedStream) async def closed_stream(self, network, stream): + # TODO: It is not implemented yet. pass async def connected(self, network, conn): - self.events.append(["connected" + self.val_to_append_to_event, conn.conn]) + self.events.append(Event.Connected) async def disconnected(self, network, conn): - pass + self.events.append(Event.Disconnected) async def listen(self, network, _multiaddr): - self.events.append(["listened" + self.val_to_append_to_event, _multiaddr]) + self.events.append(Event.Listen) async def listen_close(self, network, _multiaddr): + # TODO: It is not implemented yet. pass -class InvalidNotifee: - def __init__(self): - pass - - async def opened_stream(self): - assert False - - async def closed_stream(self): - assert False - - async def connected(self): - assert False - - async def disconnected(self): - assert False - - async def listen(self): - assert False - - @pytest.mark.asyncio -async def test_one_notifier(): - node_a, node_b = await perform_two_host_set_up() - - # Add notifee for node_a - events = [] - assert node_a.get_network().notify(MyNotifee(events, "0")) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in MyNotifee obj - # and that stream passed into opened_stream matches the stream created on - # node_a - assert events == [["connected0", stream.muxed_conn], ["opened_stream0", stream]] - - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. 
- - -@pytest.mark.asyncio -async def test_one_notifier_on_two_nodes(): - events_b = [] - messages = ["hello", "hello"] - - async def my_stream_handler(stream): - # Ensure the connected and opened_stream events were hit in Notifee obj - # and that the stream passed into opened_stream matches the stream created on - # node_b - assert events_b == [ - ["connectedb", stream.muxed_conn], - ["opened_streamb", stream], - ] - for message in messages: - read_string = (await stream.read(len(message))).decode() - - resp = ACK + read_string - await stream.write(resp.encode()) - - node_a, node_b = await perform_two_host_set_up(my_stream_handler) - - # Add notifee for node_a - events_a = [] - assert node_a.get_network().notify(MyNotifee(events_a, "a")) - - # Add notifee for node_b - assert node_b.get_network().notify(MyNotifee(events_b, "b")) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in MyNotifee obj - # and that stream passed into opened_stream matches the stream created on - # node_a - assert events_a == [["connecteda", stream.muxed_conn], ["opened_streama", stream]] - - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_one_notifier_on_two_nodes_with_listen(): - events_b = [] - messages = ["hello", "hello"] - - node_a_key_pair = create_new_key_pair() - node_a_transport_opt = ["/ip4/127.0.0.1/tcp/0"] - node_a = await new_node(node_a_key_pair, transport_opt=node_a_transport_opt) - await node_a.get_network().listen(multiaddr.Multiaddr(node_a_transport_opt[0])) - - # Set up node_b swarm to pass into host - node_b_key_pair = create_new_key_pair() - node_b_transport_opt = ["/ip4/127.0.0.1/tcp/0"] - node_b_multiaddr = multiaddr.Multiaddr(node_b_transport_opt[0]) - node_b_swarm = initialize_default_swarm( - node_b_key_pair, transport_opt=node_b_transport_opt - ) - node_b = BasicHost(node_b_swarm) - - async def my_stream_handler(stream): - # Ensure the listened, connected and opened_stream events were hit in Notifee obj - # and that the stream passed into opened_stream matches the stream created on - # node_b - assert events_b == [ - ["listenedb", node_b_multiaddr], - ["connectedb", stream.muxed_conn], - ["opened_streamb", stream], - ] - for message in messages: - read_string = (await stream.read(len(message))).decode() - resp = ACK + read_string - await stream.write(resp.encode()) - - # Add notifee for node_a - events_a = [] - assert node_a.get_network().notify(MyNotifee(events_a, "a")) - - # Add notifee for node_b - assert node_b.get_network().notify(MyNotifee(events_b, "b")) - - # start listen on node_b_swarm - await node_b.get_network().listen(node_b_multiaddr) - - node_b.set_stream_handler("/echo/1.0.0", my_stream_handler) - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in MyNotifee obj - # and that stream passed into opened_stream matches the stream created on - # node_a - assert events_a == [["connecteda", stream.muxed_conn], ["opened_streama", stream]] - - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - 
response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_two_notifiers(): - node_a, node_b = await perform_two_host_set_up() - - # Add notifee for node_a - events0 = [] - assert node_a.get_network().notify(MyNotifee(events0, "0")) - - events1 = [] - assert node_a.get_network().notify(MyNotifee(events1, "1")) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in both Notifee objs - # and that the stream passed into opened_stream matches the stream created on - # node_a - assert events0 == [["connected0", stream.muxed_conn], ["opened_stream0", stream]] - assert events1 == [["connected1", stream.muxed_conn], ["opened_stream1", stream]] - - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_ten_notifiers(): - num_notifiers = 10 - - node_a, node_b = await perform_two_host_set_up() - - # Add notifee for node_a - events_lst = [] - for i in range(num_notifiers): - events_lst.append([]) - assert node_a.get_network().notify(MyNotifee(events_lst[i], str(i))) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in both Notifee objs - # and that the stream passed into opened_stream matches the stream created on - # node_a - for i in range(num_notifiers): - assert events_lst[i] == [ - ["connected" + str(i), stream.muxed_conn], - ["opened_stream" + str(i), stream], - ] - - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. 
- - -@pytest.mark.asyncio -async def test_ten_notifiers_on_two_nodes(): - num_notifiers = 10 - events_lst_b = [] - - async def my_stream_handler(stream): - # Ensure the connected and opened_stream events were hit in all Notifee objs - # and that the stream passed into opened_stream matches the stream created on - # node_b - for i in range(num_notifiers): - assert events_lst_b[i] == [ - ["connectedb" + str(i), stream.muxed_conn], - ["opened_streamb" + str(i), stream], - ] - while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() - - resp = ACK + read_string - await stream.write(resp.encode()) - - node_a, node_b = await perform_two_host_set_up(my_stream_handler) - - # Add notifee for node_a and node_b - events_lst_a = [] - for i in range(num_notifiers): - events_lst_a.append([]) - events_lst_b.append([]) - assert node_a.get_network().notify(MyNotifee(events_lst_a[i], "a" + str(i))) - assert node_b.get_network().notify(MyNotifee(events_lst_b[i], "b" + str(i))) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # Ensure the connected and opened_stream events were hit in all Notifee objs - # and that the stream passed into opened_stream matches the stream created on - # node_a - for i in range(num_notifiers): - assert events_lst_a[i] == [ - ["connecteda" + str(i), stream.muxed_conn], - ["opened_streama" + str(i), stream], - ] - - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. - - -@pytest.mark.asyncio -async def test_invalid_notifee(): - num_notifiers = 10 - - node_a, node_b = await perform_two_host_set_up() - - # Add notifee for node_a - events_lst = [] - for _ in range(num_notifiers): - events_lst.append([]) - assert not node_a.get_network().notify(InvalidNotifee()) - - stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) - - # If this point is reached, this implies that the InvalidNotifee instance - # did not assert false, i.e. no functions of InvalidNotifee were called (which is correct - # given that InvalidNotifee should not have been added as a notifee) - messages = ["hello", "hello"] - for message in messages: - expected_resp = ACK + message - await stream.write(message.encode()) - - response = (await stream.read(len(expected_resp))).decode() - - assert response == expected_resp - - # Success, terminate pending tasks. +async def test_notify(is_host_secure): + swarms = [SwarmFactory(is_host_secure) for _ in range(2)] + + events_0_0 = [] + events_1_0 = [] + events_0_without_listen = [] + swarms[0].register_notifee(MyNotifee(events_0_0)) + swarms[1].register_notifee(MyNotifee(events_1_0)) + # Listen + await asyncio.gather(*[swarm.listen(LISTEN_MADDR) for swarm in swarms]) + + swarms[0].register_notifee(MyNotifee(events_0_without_listen)) + + # Connected + await connect_swarm(swarms[0], swarms[1]) + # OpenedStream: first + await swarms[0].new_stream(swarms[1].get_peer_id()) + # OpenedStream: second + await swarms[0].new_stream(swarms[1].get_peer_id()) + # OpenedStream: third, but different direction. + await swarms[1].new_stream(swarms[0].get_peer_id()) + + await asyncio.sleep(0.01) + + # TODO: Check `ClosedStream` and `ListenClose` events after they are ready. + + # Disconnected + await swarms[0].close_peer(swarms[1].get_peer_id()) + await asyncio.sleep(0.01) + + # Connected again, but different direction. 
+ await connect_swarm(swarms[1], swarms[0]) + await asyncio.sleep(0.01) + + # Disconnected again, but different direction. + await swarms[1].close_peer(swarms[0].get_peer_id()) + await asyncio.sleep(0.01) + + expected_events_without_listen = [ + Event.Connected, + Event.OpenedStream, + Event.OpenedStream, + Event.OpenedStream, + Event.Disconnected, + Event.Connected, + Event.Disconnected, + ] + expected_events = [Event.Listen] + expected_events_without_listen + + assert events_0_0 == expected_events + assert events_1_0 == expected_events + assert events_0_without_listen == expected_events_without_listen + + # Clean up + await asyncio.gather(*[swarm.close() for swarm in swarms]) From d61327f5f9449e9fbb9c3392c809e114c633f66f Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 17 Sep 2019 23:38:11 +0800 Subject: [PATCH 04/14] Add tests for SwarmConn --- libp2p/network/connection/swarm_connection.py | 13 ++++-- libp2p/network/stream/net_stream.py | 4 ++ tests/factories.py | 13 +++--- tests/network/conftest.py | 15 ++++++- tests/network/test_swarm_conn.py | 43 +++++++++++++++++++ 5 files changed, 79 insertions(+), 9 deletions(-) create mode 100644 tests/network/test_swarm_conn.py diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 10b83a1..78e6ead 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -43,11 +43,15 @@ class SwarmConn(INetConn): # We *could* optimize this but it really isn't worth it. for stream in self.streams: await stream.reset() - # Schedule `self._notify_disconnected` to make it execute after `close` is finished. - asyncio.ensure_future(self._notify_disconnected()) for task in self._tasks: task.cancel() + try: + await task + except asyncio.CancelledError: + pass + # Schedule `self._notify_disconnected` to make it execute after `close` is finished. + asyncio.ensure_future(self._notify_disconnected()) async def _handle_new_streams(self) -> None: while True: @@ -70,7 +74,6 @@ class SwarmConn(INetConn): async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - # Call notifiers since event occurred for notifee in self.swarm.notifees: await notifee.opened_stream(self.swarm, net_stream) return net_stream @@ -91,3 +94,7 @@ class SwarmConn(INetConn): async def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) + + # TODO: Called by `Stream` whenever it is time to remove the stream. + def remove_stream(self, stream: NetStream) -> None: + self.streams.remove(stream) diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 3ae7c9c..018ef6d 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -66,3 +66,7 @@ class NetStream(INetStream): async def reset(self) -> None: await self.muxed_stream.reset() + + # TODO: `remove`: Called by close and write when the stream is in specific states. + # It notify `ClosedStream` after `SwarmConn.remove_stream` is called. 
+ # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 diff --git a/tests/factories.py b/tests/factories.py index e39b12d..af4d529 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -6,6 +6,7 @@ import factory from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost +from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.pubsub.floodsub import FloodSub @@ -128,11 +129,13 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: return hosts[0], hosts[1] -# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: -# host_0, host_1 = await host_pair_factory() -# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] -# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] -# return mplex_conn_0, host_0, mplex_conn_1, host_1 +async def swarm_conn_pair_factory( + is_secure +) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: + swarms = await swarm_pair_factory(is_secure) + conn_0 = swarms[0].connections[swarms[1].get_peer_id()] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + return conn_0, swarms[0], conn_1, swarms[1] async def net_stream_pair_factory( diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 47d5c5f..018e822 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -2,7 +2,11 @@ import asyncio import pytest -from tests.factories import net_stream_pair_factory, swarm_pair_factory +from tests.factories import ( + net_stream_pair_factory, + swarm_conn_pair_factory, + swarm_pair_factory, +) @pytest.fixture @@ -21,3 +25,12 @@ async def swarm_pair(is_host_secure): yield swarm_0, swarm_1 finally: await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + + +@pytest.fixture +async def swarm_conn_pair(is_host_secure): + conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure) + try: + yield conn_0, conn_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py new file mode 100644 index 0000000..f9974e1 --- /dev/null +++ b/tests/network/test_swarm_conn.py @@ -0,0 +1,43 @@ +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_swarm_conn_close(swarm_conn_pair): + conn_0, conn_1 = swarm_conn_pair + + assert not conn_0.event_closed.is_set() + assert not conn_1.event_closed.is_set() + + await conn_0.close() + + await asyncio.sleep(0.01) + + assert conn_0.event_closed.is_set() + assert conn_1.event_closed.is_set() + assert conn_0 not in conn_0.swarm.connections.values() + assert conn_1 not in conn_1.swarm.connections.values() + + +@pytest.mark.asyncio +async def test_swarm_conn_streams(swarm_conn_pair): + conn_0, conn_1 = swarm_conn_pair + + assert len(await conn_0.get_streams()) == 0 + assert len(await conn_1.get_streams()) == 0 + + stream_0_0 = await conn_0.new_stream() + await asyncio.sleep(0.01) + assert len(await conn_0.get_streams()) == 1 + assert len(await conn_1.get_streams()) == 1 + + stream_0_1 = await conn_0.new_stream() + await asyncio.sleep(0.01) + assert len(await conn_0.get_streams()) == 2 + assert len(await conn_1.get_streams()) == 2 + + conn_0.remove_stream(stream_0_0) + assert len(await conn_0.get_streams()) == 1 + 
conn_0.remove_stream(stream_0_1) + assert len(await conn_0.get_streams()) == 0 From a9ad37bc6f8a715086feac6262b42dbf978e079e Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 18 Sep 2019 15:44:45 +0800 Subject: [PATCH 05/14] Add mplex tests and fix error in `SwarmConn.close` --- libp2p/__init__.py | 8 ++-- libp2p/host/basic_host.py | 2 +- libp2p/network/connection/swarm_connection.py | 13 ++++++- libp2p/security/security_multistream.py | 11 +++--- libp2p/stream_muxer/mplex/mplex.py | 3 -- libp2p/stream_muxer/mplex/mplex_stream.py | 1 + libp2p/stream_muxer/muxer_multistream.py | 19 ++++----- libp2p/transport/typing.py | 9 ++++- libp2p/transport/upgrader.py | 11 ++---- tests/factories.py | 39 ++++++++++++++----- tests/network/test_swarm_conn.py | 2 + tests/stream_muxer/__init__.py | 0 tests/stream_muxer/conftest.py | 16 ++++++++ tests/stream_muxer/test_mplex_conn.py | 6 +++ 14 files changed, 96 insertions(+), 44 deletions(-) create mode 100644 tests/stream_muxer/__init__.py create mode 100644 tests/stream_muxer/conftest.py create mode 100644 tests/stream_muxer/test_mplex_conn.py diff --git a/libp2p/__init__.py b/libp2p/__init__.py index b4d2a9a..cbff1e4 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -17,8 +17,8 @@ from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTr import libp2p.security.secio.transport as secio from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex -from libp2p.stream_muxer.muxer_multistream import MuxerClassType from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.typing import TMuxerClass, TMuxerOptions, TSecurityOptions from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol @@ -74,8 +74,8 @@ def initialize_default_swarm( key_pair: KeyPair, id_opt: ID = None, transport_opt: Sequence[str] = None, - muxer_opt: Mapping[TProtocol, MuxerClassType] = None, - sec_opt: Mapping[TProtocol, ISecureTransport] = None, + muxer_opt: TMuxerOptions = None, + sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, ) -> Swarm: @@ -114,7 +114,7 @@ async def new_node( key_pair: KeyPair = None, swarm_opt: INetwork = None, transport_opt: Sequence[str] = None, - muxer_opt: Mapping[TProtocol, MuxerClassType] = None, + muxer_opt: Mapping[TProtocol, TMuxerClass] = None, sec_opt: Mapping[TProtocol, ISecureTransport] = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 5f0ccfd..912d3ef 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -141,4 +141,4 @@ class BasicHost(IHost): MultiselectCommunicator(net_stream) ) net_stream.set_protocol(protocol) - asyncio.ensure_future(handler(net_stream)) + await handler(net_stream) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 78e6ead..6714bb8 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -66,10 +66,19 @@ class SwarmConn(INetConn): await self.close() + async def _call_stream_handler(self, net_stream: NetStream) -> None: + try: + await self.swarm.common_stream_handler(net_stream) + # TODO: More exact exceptions + except Exception: + # TODO: Emit logs. + # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. 
+ self.remove_stream(net_stream) + async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: net_stream = await self._add_stream(muxed_stream) if self.swarm.common_stream_handler is not None: - await self.run_task(self.swarm.common_stream_handler(net_stream)) + await self.run_task(self._call_stream_handler(net_stream)) async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) @@ -97,4 +106,6 @@ class SwarmConn(INetConn): # TODO: Called by `Stream` whenever it is time to remove the stream. def remove_stream(self, stream: NetStream) -> None: + if stream not in self.streams: + return self.streams.remove(stream) diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 466d60a..06f4b8a 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -1,6 +1,5 @@ from abc import ABC from collections import OrderedDict -from typing import Mapping from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID @@ -9,6 +8,7 @@ from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_transport_interface import ISecureTransport +from libp2p.transport.typing import TSecurityOptions from libp2p.typing import TProtocol @@ -31,15 +31,14 @@ class SecurityMultistream(ABC): multiselect: Multiselect multiselect_client: MultiselectClient - def __init__( - self, secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] - ) -> None: + def __init__(self, secure_transports_by_protocol: TSecurityOptions = None) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multiselect_client = MultiselectClient() - for protocol, transport in secure_transports_by_protocol.items(): - self.add_transport(protocol, transport) + if secure_transports_by_protocol is not None: + for protocol, transport in secure_transports_by_protocol.items(): + self.add_transport(protocol, transport) def add_transport(self, protocol: TProtocol, transport: ISecureTransport) -> None: """ diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 6781fed..f660226 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -29,9 +29,6 @@ class Mplex(IMuxedConn): secured_conn: ISecureConn peer_id: ID - # TODO: `dataIn` in go implementation. Should be size of 8. - # TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies - # to let the `MplexStream`s know that EOF arrived (#235). next_channel_id: int streams: Dict[StreamID, MplexStream] streams_lock: asyncio.Lock diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 06b90fa..221e238 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -24,6 +24,7 @@ class MplexStream(IMuxedStream): close_lock: asyncio.Lock + # NOTE: `dataIn` is size of 8 in Go implementation. 
incoming_data: "asyncio.Queue[bytes]" event_local_closed: asyncio.Event diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 806c90d..d506749 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -1,5 +1,4 @@ from collections import OrderedDict -from typing import Mapping, Type from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID @@ -7,12 +6,11 @@ from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.security.secure_conn_interface import ISecureConn +from libp2p.transport.typing import TMuxerClass, TMuxerOptions from libp2p.typing import TProtocol from .abc import IMuxedConn -MuxerClassType = Type[IMuxedConn] - # FIXME: add negotiate timeout to `MuxerMultistream` DEFAULT_NEGOTIATE_TIMEOUT = 60 @@ -24,20 +22,19 @@ class MuxerMultistream: """ # NOTE: Can be changed to `typing.OrderedDict` since Python 3.7.2. - transports: "OrderedDict[TProtocol, MuxerClassType]" + transports: "OrderedDict[TProtocol, TMuxerClass]" multiselect: Multiselect multiselect_client: MultiselectClient - def __init__( - self, muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType] - ) -> None: + def __init__(self, muxer_transports_by_protocol: TMuxerOptions = None) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multiselect_client = MultiselectClient() - for protocol, transport in muxer_transports_by_protocol.items(): - self.add_transport(protocol, transport) + if muxer_transports_by_protocol is not None: + for protocol, transport in muxer_transports_by_protocol.items(): + self.add_transport(protocol, transport) - def add_transport(self, protocol: TProtocol, transport: MuxerClassType) -> None: + def add_transport(self, protocol: TProtocol, transport: TMuxerClass) -> None: """ Add a protocol and its corresponding transport to multistream-select(multiselect). 
The order that a protocol is added is exactly the precedence it is negotiated in @@ -51,7 +48,7 @@ class MuxerMultistream: self.transports[protocol] = transport self.multiselect.add_handler(protocol, None) - async def select_transport(self, conn: IRawConnection) -> MuxerClassType: + async def select_transport(self, conn: IRawConnection) -> TMuxerClass: """ Select a transport that both us and the node on the other end of conn support and agree on diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index 6d0047c..f9b31dc 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,4 +1,11 @@ from asyncio import StreamReader, StreamWriter -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Mapping, Type + +from libp2p.security.secure_transport_interface import ISecureTransport +from libp2p.stream_muxer.abc import IMuxedConn +from libp2p.typing import TProtocol THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] +TSecurityOptions = Mapping[TProtocol, ISecureTransport] +TMuxerClass = Type[IMuxedConn] +TMuxerOptions = Mapping[TProtocol, TMuxerClass] diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 233c4d5..877fd23 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -1,19 +1,16 @@ -from typing import Mapping - from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.security.secure_conn_interface import ISecureConn -from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.security_multistream import SecurityMultistream from libp2p.stream_muxer.abc import IMuxedConn -from libp2p.stream_muxer.muxer_multistream import MuxerClassType, MuxerMultistream +from libp2p.stream_muxer.muxer_multistream import MuxerMultistream from libp2p.transport.exceptions import ( HandshakeFailure, MuxerUpgradeFailure, SecurityUpgradeFailure, ) -from libp2p.typing import TProtocol +from libp2p.transport.typing import TMuxerOptions, TSecurityOptions from .listener_interface import IListener from .transport_interface import ITransport @@ -25,8 +22,8 @@ class TransportUpgrader: def __init__( self, - secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport], - muxer_transports_by_protocol: Mapping[TProtocol, MuxerClassType], + secure_transports_by_protocol: TSecurityOptions, + muxer_transports_by_protocol: TMuxerOptions, ): self.security_multistream = SecurityMultistream(secure_transports_by_protocol) self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) diff --git a/tests/factories.py b/tests/factories.py index af4d529..dcc9a85 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -15,6 +15,8 @@ from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio +from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex +from libp2p.transport.typing import TMuxerOptions from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( @@ -34,10 +36,10 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} -def SwarmFactory(is_secure: bool) -> Swarm: +def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm: 
key_pair = generate_new_rsa_identity() - sec_opt = security_transport_factory(False, key_pair) - return initialize_default_swarm(key_pair, sec_opt=sec_opt) + sec_opt = security_transport_factory(is_secure, key_pair) + return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt) class ListeningSwarmFactory(factory.Factory): @@ -45,17 +47,22 @@ class ListeningSwarmFactory(factory.Factory): model = Swarm @classmethod - async def create_and_listen(cls, is_secure: bool) -> Swarm: - swarm = SwarmFactory(is_secure) + async def create_and_listen( + cls, is_secure: bool, muxer_opt: TMuxerOptions = None + ) -> Swarm: + swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt) await swarm.listen(LISTEN_MADDR) return swarm @classmethod async def create_batch_and_listen( - cls, is_secure: bool, number: int + cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None ) -> Tuple[Swarm, ...]: return await asyncio.gather( - *[cls.create_and_listen(is_secure) for _ in range(number)] + *[ + cls.create_and_listen(is_secure, muxer_opt=muxer_opt) + for _ in range(number) + ] ) @@ -112,8 +119,12 @@ class PubsubFactory(factory.Factory): cache_size = None -async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: - swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2) +async def swarm_pair_factory( + is_secure: bool, muxer_opt: TMuxerOptions = None +) -> Tuple[Swarm, Swarm]: + swarms = await ListeningSwarmFactory.create_batch_and_listen( + is_secure, 2, muxer_opt=muxer_opt + ) await connect_swarm(swarms[0], swarms[1]) return swarms[0], swarms[1] @@ -130,7 +141,7 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: async def swarm_conn_pair_factory( - is_secure + is_secure: bool, muxer_opt: TMuxerOptions = None ) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: swarms = await swarm_pair_factory(is_secure) conn_0 = swarms[0].connections[swarms[1].get_peer_id()] @@ -138,6 +149,14 @@ async def swarm_conn_pair_factory( return conn_0, swarms[0], conn_1, swarms[1] +async def mplex_conn_pair_factory(is_secure): + muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} + conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( + is_secure, muxer_opt=muxer_opt + ) + return conn_0.conn, swarm_0, conn_1.conn, swarm_1 + + async def net_stream_pair_factory( is_secure: bool ) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index f9974e1..2abc7d0 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -41,3 +41,5 @@ async def test_swarm_conn_streams(swarm_conn_pair): assert len(await conn_0.get_streams()) == 1 conn_0.remove_stream(stream_0_1) assert len(await conn_0.get_streams()) == 0 + # Nothing happen if `stream_0_1` is not present or already removed. 
+ conn_0.remove_stream(stream_0_1) diff --git a/tests/stream_muxer/__init__.py b/tests/stream_muxer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py new file mode 100644 index 0000000..3695ae4 --- /dev/null +++ b/tests/stream_muxer/conftest.py @@ -0,0 +1,16 @@ +import asyncio + +import pytest + +from tests.factories import mplex_conn_pair_factory + + +@pytest.fixture +async def mplex_conn_pair(is_host_secure): + mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( + is_host_secure + ) + try: + yield mplex_conn_0, mplex_conn_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py new file mode 100644 index 0000000..a85d9f4 --- /dev/null +++ b/tests/stream_muxer/test_mplex_conn.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.mark.asyncio +async def test_mplex_conn(mplex_conn_pair): + conn_0, conn_1 = mplex_conn_pair From 02c55e5d1468ad5ddabca4e4a2a1458dd20c436f Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 18 Sep 2019 17:22:04 +0800 Subject: [PATCH 06/14] Add tests for `MplexConn` --- libp2p/host/basic_host.py | 1 - libp2p/stream_muxer/mplex/mplex.py | 1 - tests/network/test_net_stream.py | 3 -- tests/stream_muxer/conftest.py | 2 ++ tests/stream_muxer/test_mplex_conn.py | 44 +++++++++++++++++++++++++++ 5 files changed, 46 insertions(+), 5 deletions(-) diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 912d3ef..bfe202c 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -1,4 +1,3 @@ -import asyncio from typing import List, Sequence import multiaddr diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index f660226..10bf653 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -38,7 +38,6 @@ class Mplex(IMuxedConn): _tasks: List["asyncio.Future[Any]"] - # TODO: `generic_protocol_handler` should be refactored out of mplex conn. def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ create a new muxed connection diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py index 80bed6c..c748837 100644 --- a/tests/network/test_net_stream.py +++ b/tests/network/test_net_stream.py @@ -7,9 +7,6 @@ from tests.constants import MAX_READ_LEN DATA = b"data_123" -# TODO: Move `muxed_stream` specific(currently we are using `MplexStream`) tests to its -# own file, after `generic_protocol_handler` is refactored out of `Mplex`. 
- @pytest.mark.asyncio async def test_net_stream_read_write(net_stream_pair): diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py index 3695ae4..b05a016 100644 --- a/tests/stream_muxer/conftest.py +++ b/tests/stream_muxer/conftest.py @@ -10,6 +10,8 @@ async def mplex_conn_pair(is_host_secure): mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( is_host_secure ) + assert mplex_conn_0.initiator + assert not mplex_conn_1.initiator try: yield mplex_conn_0, mplex_conn_1 finally: diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index a85d9f4..6dc98ad 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -1,6 +1,50 @@ +import asyncio + import pytest @pytest.mark.asyncio async def test_mplex_conn(mplex_conn_pair): conn_0, conn_1 = mplex_conn_pair + + assert len(conn_0.streams) == 0 + assert len(conn_1.streams) == 0 + assert not conn_0.event_shutting_down.is_set() + assert not conn_1.event_shutting_down.is_set() + assert not conn_0.event_closed.is_set() + assert not conn_1.event_closed.is_set() + + # Test: Open a stream, and both side get 1 more stream. + stream_0 = await conn_0.open_stream() + await asyncio.sleep(0.01) + assert len(conn_0.streams) == 1 + assert len(conn_1.streams) == 1 + # Test: From another side. + stream_1 = await conn_1.open_stream() + await asyncio.sleep(0.01) + assert len(conn_0.streams) == 2 + assert len(conn_1.streams) == 2 + + # Close from one side. + await conn_0.close() + # Sleep for a while for both side to handle `close`. + await asyncio.sleep(0.01) + # Test: Both side is closed. + assert conn_0.event_shutting_down.is_set() + assert conn_0.event_closed.is_set() + assert conn_1.event_shutting_down.is_set() + assert conn_1.event_closed.is_set() + # Test: All streams should have been closed. + assert stream_0.event_remote_closed.is_set() + assert stream_0.event_reset.is_set() + assert stream_0.event_local_closed.is_set() + assert conn_0.streams is None + # Test: All streams on the other side are also closed. + assert stream_1.event_remote_closed.is_set() + assert stream_1.event_reset.is_set() + assert stream_1.event_local_closed.is_set() + assert conn_1.streams is None + + # Test: No effect to close more than once between two side. 
+ await conn_0.close() + await conn_1.close() From 313ae45b45fe8d43935278353de755fd1f1d56c2 Mon Sep 17 00:00:00 2001 From: mhchia Date: Wed, 18 Sep 2019 21:51:09 +0800 Subject: [PATCH 07/14] Add tests for `MplexStream` --- tests/factories.py | 19 ++- tests/network/test_net_stream.py | 2 - tests/stream_muxer/conftest.py | 13 +- tests/stream_muxer/test_mplex_stream.py | 182 ++++++++++++++++++++++++ 4 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 tests/stream_muxer/test_mplex_stream.py diff --git a/tests/factories.py b/tests/factories.py index dcc9a85..cecc656 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -16,6 +16,7 @@ from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.transport.typing import TMuxerOptions from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR @@ -149,7 +150,7 @@ async def swarm_conn_pair_factory( return conn_0, swarms[0], conn_1, swarms[1] -async def mplex_conn_pair_factory(is_secure): +async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( is_secure, muxer_opt=muxer_opt @@ -157,6 +158,22 @@ async def mplex_conn_pair_factory(is_secure): return conn_0.conn, swarm_0, conn_1.conn, swarm_1 +async def mplex_stream_pair_factory( + is_secure: bool +) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]: + mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( + is_secure + ) + stream_0 = await mplex_conn_0.open_stream() + await asyncio.sleep(0.01) + stream_1: MplexStream + async with mplex_conn_1.streams_lock: + if len(mplex_conn_1.streams) != 1: + raise Exception("Mplex should not have any stream upon connection") + stream_1 = tuple(mplex_conn_1.streams.values())[0] + return stream_0, swarm_0, stream_1, swarm_1 + + async def net_stream_pair_factory( is_secure: bool ) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py index c748837..9229069 100644 --- a/tests/network/test_net_stream.py +++ b/tests/network/test_net_stream.py @@ -53,11 +53,9 @@ async def test_net_stream_read_until_eof(net_stream_pair): @pytest.mark.asyncio async def test_net_stream_read_after_remote_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair - assert not stream_1.muxed_stream.event_remote_closed.is_set() await stream_0.write(DATA) await stream_0.close() await asyncio.sleep(0.01) - assert stream_1.muxed_stream.event_remote_closed.is_set() assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(StreamEOF): await stream_1.read(MAX_READ_LEN) diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py index b05a016..b1d6c11 100644 --- a/tests/stream_muxer/conftest.py +++ b/tests/stream_muxer/conftest.py @@ -2,7 +2,7 @@ import asyncio import pytest -from tests.factories import mplex_conn_pair_factory +from tests.factories import mplex_conn_pair_factory, mplex_stream_pair_factory @pytest.fixture @@ -16,3 +16,14 @@ async def mplex_conn_pair(is_host_secure): yield mplex_conn_0, mplex_conn_1 finally: await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + + +@pytest.fixture +async def 
mplex_stream_pair(is_host_secure): + mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory( + is_host_secure + ) + try: + yield mplex_stream_0, mplex_stream_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py new file mode 100644 index 0000000..e2bcb24 --- /dev/null +++ b/tests/stream_muxer/test_mplex_stream.py @@ -0,0 +1,182 @@ +import asyncio + +import pytest + +from libp2p.stream_muxer.mplex.exceptions import ( + MplexStreamClosed, + MplexStreamEOF, + MplexStreamReset, +) +from tests.constants import MAX_READ_LEN + +DATA = b"data_123" + + +@pytest.mark.asyncio +async def test_mplex_stream_read_write(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): + read_bytes = bytearray() + stream_0, stream_1 = mplex_stream_pair + + async def read_until_eof(): + read_bytes.extend(await stream_1.read()) + + task = asyncio.ensure_future(read_until_eof()) + + expected_data = bytearray() + + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() + await asyncio.sleep(0.01) + assert read_bytes == expected_data + + task.cancel() + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + assert not stream_1.event_remote_closed.is_set() + await stream_0.write(DATA) + await stream_0.close() + await asyncio.sleep(0.01) + assert stream_1.event_remote_closed.is_set() + assert (await stream_1.read(MAX_READ_LEN)) == DATA + with pytest.raises(MplexStreamEOF): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + with pytest.raises(MplexStreamReset): + await stream_0.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + with pytest.raises(MplexStreamReset): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.close() + await stream_0.reset() + # Sleep to let `stream_1` receive the message. 
+ await asyncio.sleep(0.01) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.close() + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_1.reset() + await asyncio.sleep(0.01) + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_both_close(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + # Flags are not set initially. + assert not stream_0.event_local_closed.is_set() + assert not stream_1.event_local_closed.is_set() + assert not stream_0.event_remote_closed.is_set() + assert not stream_1.event_remote_closed.is_set() + # Streams are present in their `mplex_conn`. + assert stream_0 in stream_0.muxed_conn.streams.values() + assert stream_1 in stream_1.muxed_conn.streams.values() + + # Test: Close one side. + await stream_0.close() + await asyncio.sleep(0.01) + + assert stream_0.event_local_closed.is_set() + assert not stream_1.event_local_closed.is_set() + assert not stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are still present in their `mplex_conn`. + assert stream_0 in stream_0.muxed_conn.streams.values() + assert stream_1 in stream_1.muxed_conn.streams.values() + + # Test: Close the other side. + await stream_1.close() + await asyncio.sleep(0.01) + # Both sides are closed. + assert stream_0.event_local_closed.is_set() + assert stream_1.event_local_closed.is_set() + assert stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are removed from their `mplex_conn`. + assert stream_0 not in stream_0.muxed_conn.streams.values() + assert stream_1 not in stream_1.muxed_conn.streams.values() + + # Test: Reset after both close. + await stream_0.reset() + + +@pytest.mark.asyncio +async def test_mplex_stream_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + await asyncio.sleep(0.01) + + # Both sides are closed. + assert stream_0.event_local_closed.is_set() + assert stream_1.event_local_closed.is_set() + assert stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are removed from their `mplex_conn`. + assert stream_0 not in stream_0.muxed_conn.streams.values() + assert stream_1 not in stream_1.muxed_conn.streams.values() + + # `close` should do nothing. + await stream_0.close() + await stream_1.close() + # `reset` should do nothing as well. 
+ await stream_0.reset() + await stream_1.reset() From 62b0bc4580265310d4ee6b1e60d1830b26fac15f Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 19 Sep 2019 14:10:50 +0800 Subject: [PATCH 08/14] Remove useless protocol_ids in logging --- libp2p/network/swarm.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 18218d6..6bdeb8f 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -137,11 +137,7 @@ class Swarm(INetwork): :param protocol_id: protocol id :return: net stream instance """ - logger.debug( - "attempting to open a stream to peer %s, over one of the protocols %s", - peer_id, - protocol_ids, - ) + logger.debug("attempting to open a stream to peer %s", peer_id) swarm_conn = await self.dial_peer(peer_id) From 8d2415a404904833614fcf0db47981376b3f3fe3 Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 23 Sep 2019 15:01:58 +0800 Subject: [PATCH 09/14] Move calls to `Notifee` inside `Swarm` --- libp2p/network/connection/swarm_connection.py | 17 ++++---- libp2p/network/stream/net_stream.py | 2 +- libp2p/network/swarm.py | 42 ++++++++++++++----- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 6714bb8..cf1dc9e 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -51,7 +51,7 @@ class SwarmConn(INetConn): except asyncio.CancelledError: pass # Schedule `self._notify_disconnected` to make it execute after `close` is finished. - asyncio.ensure_future(self._notify_disconnected()) + self._notify_disconnected() async def _handle_new_streams(self) -> None: while True: @@ -76,20 +76,18 @@ class SwarmConn(INetConn): self.remove_stream(net_stream) async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: - net_stream = await self._add_stream(muxed_stream) + net_stream = self._add_stream(muxed_stream) if self.swarm.common_stream_handler is not None: await self.run_task(self._call_stream_handler(net_stream)) - async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: + def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - for notifee in self.swarm.notifees: - await notifee.opened_stream(self.swarm, net_stream) + self.swarm.notify_opened_stream(net_stream) return net_stream - async def _notify_disconnected(self) -> None: - for notifee in self.swarm.notifees: - await notifee.disconnected(self.swarm, self) + def _notify_disconnected(self) -> None: + self.swarm.notify_disconnected(self) async def start(self) -> None: await self.run_task(self._handle_new_streams()) @@ -99,12 +97,11 @@ class SwarmConn(INetConn): async def new_stream(self) -> NetStream: muxed_stream = await self.conn.open_stream() - return await self._add_stream(muxed_stream) + return self._add_stream(muxed_stream) async def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) - # TODO: Called by `Stream` whenever it is time to remove the stream. 
     def remove_stream(self, stream: NetStream) -> None:
         if stream not in self.streams:
             return
diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py
index 018ef6d..625a288 100644
--- a/libp2p/network/stream/net_stream.py
+++ b/libp2p/network/stream/net_stream.py
@@ -68,5 +68,5 @@ class NetStream(INetStream):
         await self.muxed_stream.reset()
 
     # TODO: `remove`: Called by close and write when the stream is in specific states.
-    # It notify `ClosedStream` after `SwarmConn.remove_stream` is called.
+    # It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
     # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go  # noqa: E501
diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py
index 4fd3bbc..bed204e 100644
--- a/libp2p/network/swarm.py
+++ b/libp2p/network/swarm.py
@@ -223,8 +223,7 @@ class Swarm(INetwork):
                 await listener.listen(maddr)
 
                 # Call notifiers since event occurred
-                for notifee in self.notifees:
-                    await notifee.listen(self, maddr)
+                self.notify_listen(maddr)
 
                 return True
             except IOError:
@@ -234,13 +233,6 @@ class Swarm(INetwork):
         # No maddr succeeded
         return False
 
-    def register_notifee(self, notifee: INotifee) -> None:
-        """
-        :param notifee: object implementing Notifee interface
-        :return: true if notifee registered successfully, false otherwise
-        """
-        self.notifees.append(notifee)
-
     def add_router(self, router: IPeerRouting) -> None:
         self.router = router
 
@@ -279,8 +271,7 @@ class Swarm(INetwork):
         # Store muxed_conn with peer id
         self.connections[muxed_conn.peer_id] = swarm_conn
         # Call notifiers since event occurred
-        for notifee in self.notifees:
-            await notifee.connected(self, swarm_conn)
+        self.notify_connected(swarm_conn)
         await swarm_conn.start()
 
         return swarm_conn
@@ -294,3 +285,32 @@ class Swarm(INetwork):
         # TODO: Should be changed to remove the exact connection,
         # if we have several connections per peer in the future.
         del self.connections[peer_id]
+
+    # Notifee
+
+    # TODO: Remember the spawned notification tasks and clean them up when closing.
+ + def register_notifee(self, notifee: INotifee) -> None: + """ + :param notifee: object implementing Notifee interface + :return: true if notifee registered successfully, false otherwise + """ + self.notifees.append(notifee) + + def notify_opened_stream(self, stream: INetStream) -> None: + asyncio.gather( + *[notifee.opened_stream(self, stream) for notifee in self.notifees] + ) + + # TODO: `notify_closed_stream` + + def notify_connected(self, conn: INetConn) -> None: + asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees]) + + def notify_disconnected(self, conn: INetConn) -> None: + asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees]) + + def notify_listen(self, multiaddr: Multiaddr) -> None: + asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees]) + + # TODO: `notify_listen_close` From 92deae41dcdbaf6477c543474460a629545267ec Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 23 Sep 2019 15:46:50 +0800 Subject: [PATCH 10/14] Change `SwarmConn.conn` to `muxed_conn` --- .../network/connection/net_connection_interface.py | 2 +- libp2p/network/connection/swarm_connection.py | 12 ++++++------ libp2p/network/stream/net_stream.py | 4 +++- libp2p/network/swarm.py | 2 +- libp2p/pubsub/pubsub_notifee.py | 2 +- tests/factories.py | 2 +- tests/security/test_security_multistream.py | 4 ++-- 7 files changed, 15 insertions(+), 13 deletions(-) diff --git a/libp2p/network/connection/net_connection_interface.py b/libp2p/network/connection/net_connection_interface.py index c2c6285..e308ad6 100644 --- a/libp2p/network/connection/net_connection_interface.py +++ b/libp2p/network/connection/net_connection_interface.py @@ -7,7 +7,7 @@ from libp2p.stream_muxer.abc import IMuxedConn class INetConn(Closer): - conn: IMuxedConn + muxed_conn: IMuxedConn @abstractmethod async def new_stream(self) -> INetStream: diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index cf1dc9e..e25d75f 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -16,15 +16,15 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee class SwarmConn(INetConn): - conn: IMuxedConn + muxed_conn: IMuxedConn swarm: "Swarm" streams: Set[NetStream] event_closed: asyncio.Event _tasks: List["asyncio.Future[Any]"] - def __init__(self, conn: IMuxedConn, swarm: "Swarm") -> None: - self.conn = conn + def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: + self.muxed_conn = muxed_conn self.swarm = swarm self.streams = set() self.event_closed = asyncio.Event() @@ -37,7 +37,7 @@ class SwarmConn(INetConn): self.event_closed.set() self.swarm.remove_conn(self) - await self.conn.close() + await self.muxed_conn.close() # This is just for cleaning up state. The connection has already been closed. # We *could* optimize this but it really isn't worth it. @@ -56,7 +56,7 @@ class SwarmConn(INetConn): async def _handle_new_streams(self) -> None: while True: try: - stream = await self.conn.accept_stream() + stream = await self.muxed_conn.accept_stream() except MuxedConnUnavailable: # If there is anything wrong in the MuxedConn, # we should break the loop and close the connection. 
@@ -96,7 +96,7 @@ class SwarmConn(INetConn): self._tasks.append(asyncio.ensure_future(coro)) async def new_stream(self) -> NetStream: - muxed_stream = await self.conn.open_stream() + muxed_stream = await self.muxed_conn.open_stream() return self._add_stream(muxed_stream) async def get_streams(self) -> Tuple[NetStream, ...]: diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 625a288..0142721 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,3 +1,5 @@ +from typing import Optional + from libp2p.stream_muxer.abc import IMuxedStream from libp2p.stream_muxer.exceptions import ( MuxedStreamClosed, @@ -16,7 +18,7 @@ from .net_stream_interface import INetStream class NetStream(INetStream): muxed_stream: IMuxedStream - protocol_id: TProtocol + protocol_id: Optional[TProtocol] def __init__(self, muxed_stream: IMuxedStream) -> None: self.muxed_stream = muxed_stream diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index bed204e..9d507fb 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -279,7 +279,7 @@ class Swarm(INetwork): """ Simply remove the connection from Swarm's records, without closing the connection. """ - peer_id = swarm_conn.conn.peer_id + peer_id = swarm_conn.muxed_conn.peer_id if peer_id not in self.connections: return # TODO: Should be changed to remove the exact connection, diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 627152e..85c0bd8 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -36,7 +36,7 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - await self.initiator_peers_queue.put(conn.conn.peer_id) + await self.initiator_peers_queue.put(conn.muxed_conn.peer_id) async def disconnected(self, network: INetwork, conn: INetConn) -> None: pass diff --git a/tests/factories.py b/tests/factories.py index cecc656..b4e8be2 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -155,7 +155,7 @@ async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( is_secure, muxer_opt=muxer_opt ) - return conn_0.conn, swarm_0, conn_1.conn, swarm_1 + return conn_0.muxed_conn, swarm_0, conn_1.muxed_conn, swarm_1 async def mplex_stream_pair_factory( diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index 26d3140..a9fe031 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -53,8 +53,8 @@ async def perform_simple_test( node2_conn = node2.get_network().connections[peer_id_for_node(node1)] # Perform assertion - assertion_func(node1_conn.conn.secured_conn) - assertion_func(node2_conn.conn.secured_conn) + assertion_func(node1_conn.muxed_conn.secured_conn) + assertion_func(node2_conn.muxed_conn.secured_conn) # Success, terminate pending tasks. 
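As a quick illustration of the API after the two patches above: notifees now receive the wrapping `INetConn`, so the peer ID is reached through `conn.muxed_conn.peer_id`, mirroring `PubsubNotifee`. The sketch below is a minimal, duck-typed notifee; it deliberately does not subclass `INotifee`, and the class name `PeerLoggingNotifee` and the `print` logging are illustrative assumptions, not part of these patches. `Swarm.register_notifee` only appends the object and later invokes these coroutines via `asyncio.gather`, so any object exposing the matching methods works.

from multiaddr import Multiaddr

from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.network_interface import INetwork
from libp2p.network.stream.net_stream_interface import INetStream


class PeerLoggingNotifee:
    # Hypothetical example notifee; names and logging here are assumptions for illustration.

    async def connected(self, network: INetwork, conn: INetConn) -> None:
        # After this series, `conn` is an `INetConn`; the peer ID lives on its muxed connection.
        print("connected:", conn.muxed_conn.peer_id)

    async def disconnected(self, network: INetwork, conn: INetConn) -> None:
        print("disconnected:", conn.muxed_conn.peer_id)

    async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
        print("opened stream to:", stream.muxed_conn.peer_id)

    async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
        pass

    async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
        print("listening on:", multiaddr)


# Hypothetical usage: host.get_network().register_notifee(PeerLoggingNotifee())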
From 95ae718e3d135bac2e4f7cd709da6991315c393e Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 23 Sep 2019 16:01:22 +0800 Subject: [PATCH 11/14] Raise `ParseError` in `read_delim` --- libp2p/protocol_muxer/multiselect_communicator.py | 8 ++++---- libp2p/utils.py | 5 ++--- tests/interop/test_bindings.py | 1 - 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index a66a564..6f7b715 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -20,10 +20,10 @@ class MultiselectCommunicator(IMultiselectCommunicator): msg_bytes = encode_delim(msg_str.encode()) try: await self.read_writer.write(msg_bytes) - except IOException: + except IOException as error: raise MultiselectCommunicatorError( "fail to write to multiselect communicator" - ) + ) from error async def read(self) -> str: """ @@ -32,8 +32,8 @@ class MultiselectCommunicator(IMultiselectCommunicator): try: data = await read_delim(self.read_writer) # `IOException` includes `IncompleteReadError` and `StreamError` - except (ParseError, IOException, ValueError): + except (ParseError, IOException) as error: raise MultiselectCommunicatorError( "fail to read from multiselect communicator" - ) + ) from error return data.decode() diff --git a/libp2p/utils.py b/libp2p/utils.py index 0e15b56..4844f0e 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -73,9 +73,8 @@ def encode_delim(msg: bytes) -> bytes: async def read_delim(reader: Reader) -> bytes: msg_bytes = await read_varint_prefixed_bytes(reader) - # TODO: Investigate if it is possible to have empty `msg_bytes` - if len(msg_bytes) != 0 and msg_bytes[-1:] != b"\n": - raise ValueError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}') + if len(msg_bytes) == 0 or msg_bytes[-1:] != b"\n": + raise ParseError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}') return msg_bytes[:-1] diff --git a/tests/interop/test_bindings.py b/tests/interop/test_bindings.py index 1e78ff4..dc0a270 100644 --- a/tests/interop/test_bindings.py +++ b/tests/interop/test_bindings.py @@ -22,6 +22,5 @@ async def test_connect(hosts, p2pds): assert len(host.get_network().connections) == 1 # Test: `disconnect` from Go await p2pd.control.disconnect(host.get_id()) - # FIXME: Failed to handle disconnect await asyncio.sleep(0.01) assert len(host.get_network().connections) == 0 From 1bd18c84f24e63694bb81aad93667fa9dc43e7ce Mon Sep 17 00:00:00 2001 From: Kevin Mai-Husan Chia Date: Tue, 24 Sep 2019 12:33:14 +0800 Subject: [PATCH 12/14] Apply suggestions from code review Co-Authored-By: Alex Stokes --- libp2p/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index cbff1e4..9e452ca 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -114,8 +114,8 @@ async def new_node( key_pair: KeyPair = None, swarm_opt: INetwork = None, transport_opt: Sequence[str] = None, - muxer_opt: Mapping[TProtocol, TMuxerClass] = None, - sec_opt: Mapping[TProtocol, ISecureTransport] = None, + muxer_opt: TMuxerOptions = None, + sec_opt: TSecurityOptions = None, peerstore_opt: IPeerStore = None, disc_opt: IPeerRouting = None, ) -> BasicHost: From 37bee9fb160a64fab95246ca5fff6357f889faf4 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Sep 2019 12:51:59 +0800 Subject: [PATCH 13/14] PR feedback - Use `TMuxerOptions` and `TSecurityOptions` in libp2p/__init__.py - Remove the default 
value for `muxer_transports_by_protocol` in `MuxerMultistream` and `secure_transports_by_protocol` `SecureMultistream` --- libp2p/__init__.py | 5 ++--- libp2p/security/security_multistream.py | 7 +++---- libp2p/stream_muxer/muxer_multistream.py | 7 +++---- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 9e452ca..08caf25 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,5 +1,5 @@ import asyncio -from typing import Mapping, Sequence +from typing import Sequence from libp2p.crypto.keys import KeyPair from libp2p.crypto.rsa import create_new_key_pair @@ -15,10 +15,9 @@ from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio -from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.transport.tcp.tcp import TCP -from libp2p.transport.typing import TMuxerClass, TMuxerOptions, TSecurityOptions +from libp2p.transport.typing import TMuxerOptions, TSecurityOptions from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 06f4b8a..cff55af 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -31,14 +31,13 @@ class SecurityMultistream(ABC): multiselect: Multiselect multiselect_client: MultiselectClient - def __init__(self, secure_transports_by_protocol: TSecurityOptions = None) -> None: + def __init__(self, secure_transports_by_protocol: TSecurityOptions) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multiselect_client = MultiselectClient() - if secure_transports_by_protocol is not None: - for protocol, transport in secure_transports_by_protocol.items(): - self.add_transport(protocol, transport) + for protocol, transport in secure_transports_by_protocol.items(): + self.add_transport(protocol, transport) def add_transport(self, protocol: TProtocol, transport: ISecureTransport) -> None: """ diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index d506749..7f6ee07 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -26,13 +26,12 @@ class MuxerMultistream: multiselect: Multiselect multiselect_client: MultiselectClient - def __init__(self, muxer_transports_by_protocol: TMuxerOptions = None) -> None: + def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multiselect_client = MultiselectClient() - if muxer_transports_by_protocol is not None: - for protocol, transport in muxer_transports_by_protocol.items(): - self.add_transport(protocol, transport) + for protocol, transport in muxer_transports_by_protocol.items(): + self.add_transport(protocol, transport) def add_transport(self, protocol: TProtocol, transport: TMuxerClass) -> None: """ From 7405f078e655b19ff2302727647c8bcc5ba1b8ad Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 24 Sep 2019 13:22:25 +0800 Subject: [PATCH 14/14] Raise `read_delim` exception with different msgs Separate `len(msg_bytes) == 0` and `msg_bytes[-1:] != b"\n"`, to raise `ParseError` with different messages. 
---
 libp2p/utils.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/libp2p/utils.py b/libp2p/utils.py
index 4844f0e..39c79e5 100644
--- a/libp2p/utils.py
+++ b/libp2p/utils.py
@@ -73,8 +73,12 @@ def encode_delim(msg: bytes) -> bytes:
 
 
 async def read_delim(reader: Reader) -> bytes:
     msg_bytes = await read_varint_prefixed_bytes(reader)
-    if len(msg_bytes) == 0 or msg_bytes[-1:] != b"\n":
-        raise ParseError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}')
+    if len(msg_bytes) == 0:
+        raise ParseError("`len(msg_bytes)` should not be 0")
+    if msg_bytes[-1:] != b"\n":
+        raise ParseError(
+            f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes}'
+        )
     return msg_bytes[:-1]
 
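To make the end state of `read_delim` concrete, here is a standalone sketch of the framing rule the last two patches enforce: the varint length prefix (handled by `read_varint_prefixed_bytes`, not reproduced here) wraps a payload that must be non-empty and end with b"\n", and each violation now raises `ParseError` rather than `ValueError`. The local `ParseError` class and the `strip_delim` helper below are illustrative stand-ins to keep the sketch self-contained, not the library's actual code.

class ParseError(Exception):
    # Stand-in for the library's ParseError, used only to keep this sketch self-contained.
    pass


def strip_delim(msg_bytes: bytes) -> bytes:
    # Mirrors the two checks `read_delim` performs after PATCH 11 and PATCH 14.
    if len(msg_bytes) == 0:
        raise ParseError("`len(msg_bytes)` should not be 0")
    if msg_bytes[-1:] != b"\n":
        raise ParseError(f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}')
    return msg_bytes[:-1]


assert strip_delim(b"/multistream/1.0.0\n") == b"/multistream/1.0.0"
for bad in (b"", b"no-trailing-newline"):
    try:
        strip_delim(bad)
    except ParseError as error:
        # At the multiselect layer, `MultiselectCommunicator.read` wraps this
        # as `MultiselectCommunicatorError`.
        print("rejected", bad, "->", error)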