diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index 37b287a..108f04a 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -50,19 +50,19 @@ class InvalidNotifee(): def __init__(self): pass - async def opened_stream(self, network, stream): + async def opened_stream(self): assert False - async def closed_stream(self, network, stream): + async def closed_stream(self): assert False - async def connected(self, network, conn): + async def connected(self): assert False - async def disconnected(self, network, conn): + async def disconnected(self): assert False - async def listen(self, network, multiaddr): + async def listen(self): assert False async def perform_two_host_simple_set_up(): @@ -98,7 +98,7 @@ async def test_one_notifier(): # Add notifee for node_a events = [] - node_a.get_network().notify(MyNotifee(events, "0")) + assert node_a.get_network().notify(MyNotifee(events, "0")) stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) @@ -122,8 +122,11 @@ async def test_one_notifier(): @pytest.mark.asyncio async def test_one_notifier_on_two_nodes(): events_b = [] - + async def my_stream_handler(stream): + # Ensure the connected and opened_stream events were hit in Notifee obj + # and that the stream passed into opened_stream matches the stream created on + # node_b assert events_b == [["connectedb", stream.mplex_conn], \ ["opened_streamb", stream]] while True: @@ -136,10 +139,10 @@ async def test_one_notifier_on_two_nodes(): # Add notifee for node_a events_a = [] - node_a.get_network().notify(MyNotifee(events_a, "a")) + assert node_a.get_network().notify(MyNotifee(events_a, "a")) # Add notifee for node_b - node_b.get_network().notify(MyNotifee(events_b, "b")) + assert node_b.get_network().notify(MyNotifee(events_b, "b")) stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) @@ -166,10 +169,10 @@ async def test_two_notifiers(): # Add notifee for node_a events0 = [] - node_a.get_network().notify(MyNotifee(events0, "0")) + assert node_a.get_network().notify(MyNotifee(events0, "0")) events1 = [] - node_a.get_network().notify(MyNotifee(events1, "1")) + assert node_a.get_network().notify(MyNotifee(events1, "1")) stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) @@ -201,7 +204,7 @@ async def test_ten_notifiers(): events_lst = [] for i in range(num_notifiers): events_lst.append([]) - node_a.get_network().notify(MyNotifee(events_lst[i], str(i))) + assert node_a.get_network().notify(MyNotifee(events_lst[i], str(i))) stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) @@ -223,6 +226,54 @@ async def test_ten_notifiers(): # Success, terminate pending tasks. await cleanup() +@pytest.mark.asyncio +async def test_ten_notifiers_on_two_nodes(): + num_notifiers = 10 + events_lst_b = [] + + async def my_stream_handler(stream): + # Ensure the connected and opened_stream events were hit in all Notifee objs + # and that the stream passed into opened_stream matches the stream created on + # node_b + for i in range(num_notifiers): + assert events_lst_b[i] == [["connectedb" + str(i), stream.mplex_conn], \ + ["opened_streamb" + str(i), stream]] + while True: + read_string = (await stream.read()).decode() + + resp = "ack:" + read_string + await stream.write(resp.encode()) + + node_a, node_b = await perform_two_host_simple_set_up_custom_handler(my_stream_handler) + + # Add notifee for node_a and node_b + events_lst_a = [] + for i in range(num_notifiers): + events_lst_a.append([]) + events_lst_b.append([]) + assert node_a.get_network().notify(MyNotifee(events_lst_a[i], "a" + str(i))) + assert node_b.get_network().notify(MyNotifee(events_lst_b[i], "b" + str(i))) + + stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"]) + + # Ensure the connected and opened_stream events were hit in all Notifee objs + # and that the stream passed into opened_stream matches the stream created on + # node_a + for i in range(num_notifiers): + assert events_lst_a[i] == [["connecteda" + str(i), stream.mplex_conn], \ + ["opened_streama" + str(i), stream]] + + messages = ["hello", "hello"] + for message in messages: + await stream.write(message.encode()) + + response = (await stream.read()).decode() + + assert response == ("ack:" + message) + + # Success, terminate pending tasks. + await cleanup() + @pytest.mark.asyncio async def test_invalid_notifee(): num_notifiers = 10 @@ -231,9 +282,9 @@ async def test_invalid_notifee(): # Add notifee for node_a events_lst = [] - for i in range(num_notifiers): + for _ in range(num_notifiers): events_lst.append([]) - node_a.get_network().notify(InvalidNotifee()) + assert not node_a.get_network().notify(InvalidNotifee()) stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])