diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index e6bf288..daabb64 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -5,6 +5,7 @@ from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.peer.id import id_b58_decode from .network_interface import INetwork +from .notifee_interface import INotifee from .stream.net_stream import NetStream from .connection.raw_connection import RawConnection @@ -175,8 +176,8 @@ class Swarm(INetwork): """ :param notifee: object implementing Notifee interface """ - # TODO: Add check to ensure notifee conforms to Notifee interface - self.notifees.append(notifee) + if isinstance(notifee, INotifee): + self.notifees.append(notifee) def add_transport(self, transport): # TODO: Support more than one transport @@ -195,6 +196,10 @@ def create_generic_protocol_handler(swarm): # Perform protocol muxing to determine protocol to use _, handler = await multiselect.negotiate(muxed_stream) + # Call notifiers since event occurred + for notifee in swarm.notifees: + await notifee.opened_stream(swarm, muxed_stream) + # Give to stream handler asyncio.ensure_future(handler(muxed_stream)) diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index fdf7368..37b287a 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -44,6 +44,27 @@ class MyNotifee(INotifee): async def listen_close(self, network, multiaddr): pass +class InvalidNotifee(): + # pylint: disable=too-many-instance-attributes, cell-var-from-loop + + def __init__(self): + pass + + async def opened_stream(self, network, stream): + assert False + + async def closed_stream(self, network, stream): + assert False + + async def connected(self, network, conn): + assert False + + async def disconnected(self, network, conn): + assert False + + async def listen(self, network, multiaddr): + assert False + async def perform_two_host_simple_set_up(): node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) @@ -61,6 +82,16 @@ async def perform_two_host_simple_set_up(): node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) return node_a, node_b +async def perform_two_host_simple_set_up_custom_handler(handler): + node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) + node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) + + node_b.set_stream_handler("/echo/1.0.0", handler) + + # Associate the peer with local ip address (see default parameters of Libp2p()) + node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) + return node_a, node_b + @pytest.mark.asyncio async def test_one_notifier(): node_a, node_b = await perform_two_host_simple_set_up() @@ -88,6 +119,47 @@ async def test_one_notifier(): # Success, terminate pending tasks. await cleanup() +@pytest.mark.asyncio +async def test_one_notifier_on_two_nodes(): + events_b = [] + + async def my_stream_handler(stream): + assert events_b == [["connectedb", stream.mplex_conn], \ + ["opened_streamb", stream]] + while True: + read_string = (await stream.read()).decode() + + resp = "ack:" + read_string + await stream.write(resp.encode()) + + node_a, node_b = await perform_two_host_simple_set_up_custom_handler(my_stream_handler) + + # Add notifee for node_a + events_a = [] + node_a.get_network().notify(MyNotifee(events_a, "a")) + + # Add notifee for node_b + node_b.get_network().notify(MyNotifee(events_b, "b")) + + stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) + + # Ensure the connected and opened_stream events were hit in MyNotifee obj + # and that stream passed into opened_stream matches the stream created on + # node_a + assert events_a == [["connecteda", stream.mplex_conn], \ + ["opened_streama", stream]] + + messages = ["hello", "hello"] + for message in messages: + await stream.write(message.encode()) + + response = (await stream.read()).decode() + + assert response == ("ack:" + message) + + # Success, terminate pending tasks. + await cleanup() + @pytest.mark.asyncio async def test_two_notifiers(): node_a, node_b = await perform_two_host_simple_set_up() @@ -150,3 +222,31 @@ async def test_ten_notifiers(): # Success, terminate pending tasks. await cleanup() + +@pytest.mark.asyncio +async def test_invalid_notifee(): + num_notifiers = 10 + + node_a, node_b = await perform_two_host_simple_set_up() + + # Add notifee for node_a + events_lst = [] + for i in range(num_notifiers): + events_lst.append([]) + node_a.get_network().notify(InvalidNotifee()) + + stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) + + # If this point is reached, this implies that the InvalidNotifee instance + # did not assert false, i.e. no functions of InvalidNotifee were called (which is correct + # given that InvalidNotifee should not have been added as a notifee) + messages = ["hello", "hello"] + for message in messages: + await stream.write(message.encode()) + + response = (await stream.read()).decode() + + assert response == ("ack:" + message) + + # Success, terminate pending tasks. + await cleanup()