diff --git a/tests/interop/conftest.py b/tests/interop/conftest.py index 76b11d6..7261ee7 100644 --- a/tests/interop/conftest.py +++ b/tests/interop/conftest.py @@ -1,5 +1,6 @@ import asyncio import sys +from typing import Union import pexpect import pytest @@ -7,7 +8,7 @@ import pytest from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory from tests.pubsub.configs import GOSSIPSUB_PARAMS -from .daemon import make_p2pd +from .daemon import Daemon, make_p2pd @pytest.fixture @@ -42,7 +43,7 @@ def is_gossipsub(): @pytest.fixture async def p2pds(num_p2pds, is_host_secure, is_gossipsub, unused_tcp_port_factory): - p2pds = await asyncio.gather( + p2pds: Union[Daemon, Exception] = await asyncio.gather( *[ make_p2pd( unused_tcp_port_factory(), @@ -51,8 +52,15 @@ async def p2pds(num_p2pds, is_host_secure, is_gossipsub, unused_tcp_port_factory is_gossipsub=is_gossipsub, ) for _ in range(num_p2pds) - ] + ], + return_exceptions=True, ) + p2pds_succeeded = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Daemon)) + if len(p2pds_succeeded) != len(p2pds): + # Not all succeeded. Close the succeeded ones and print the failed ones(exceptions). + await asyncio.gather(*[p2pd.close() for p2pd in p2pds_succeeded]) + exceptions = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Exception)) + raise Exception(f"not all p2pds succeed: first exception={exceptions[0]}") try: yield p2pds finally: diff --git a/tests/interop/daemon.py b/tests/interop/daemon.py index 607ec3a..754b563 100644 --- a/tests/interop/daemon.py +++ b/tests/interop/daemon.py @@ -38,6 +38,7 @@ class P2PDProcess: proc: asyncio.subprocess.Process cmd: str = str(P2PD_PATH) args: List[Any] + is_running: bool _tasks: List["asyncio.Future[Any]"] @@ -70,6 +71,8 @@ class P2PDProcess: # - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second # Referece: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501 self.args = args + self.is_running = False + self._tasks = [] async def wait_until_ready(self): @@ -117,8 +120,10 @@ class P2PDProcess: await self.start_printing_logs() async def close(self) -> None: - self.proc.terminate() - await self.proc.wait() + if self.is_running: + self.proc.terminate() + await self.proc.wait() + self.is_running = False for task in self._tasks: task.cancel()