Do p2pd.close if not all of them succeed

This commit is contained in:
mhchia 2019-09-04 18:25:51 +08:00
parent 51d547ccc5
commit db0da8083a
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
2 changed files with 18 additions and 5 deletions

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import sys import sys
from typing import Union
import pexpect import pexpect
import pytest import pytest
@ -7,7 +8,7 @@ import pytest
from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
from tests.pubsub.configs import GOSSIPSUB_PARAMS from tests.pubsub.configs import GOSSIPSUB_PARAMS
from .daemon import make_p2pd from .daemon import Daemon, make_p2pd
@pytest.fixture @pytest.fixture
@ -42,7 +43,7 @@ def is_gossipsub():
@pytest.fixture @pytest.fixture
async def p2pds(num_p2pds, is_host_secure, is_gossipsub, unused_tcp_port_factory): async def p2pds(num_p2pds, is_host_secure, is_gossipsub, unused_tcp_port_factory):
p2pds = await asyncio.gather( p2pds: Union[Daemon, Exception] = await asyncio.gather(
*[ *[
make_p2pd( make_p2pd(
unused_tcp_port_factory(), unused_tcp_port_factory(),
@ -51,8 +52,15 @@ async def p2pds(num_p2pds, is_host_secure, is_gossipsub, unused_tcp_port_factory
is_gossipsub=is_gossipsub, is_gossipsub=is_gossipsub,
) )
for _ in range(num_p2pds) for _ in range(num_p2pds)
] ],
return_exceptions=True,
) )
p2pds_succeeded = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Daemon))
if len(p2pds_succeeded) != len(p2pds):
# Not all succeeded. Close the succeeded ones and print the failed ones(exceptions).
await asyncio.gather(*[p2pd.close() for p2pd in p2pds_succeeded])
exceptions = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Exception))
raise Exception(f"not all p2pds succeed: first exception={exceptions[0]}")
try: try:
yield p2pds yield p2pds
finally: finally:

View File

@ -38,6 +38,7 @@ class P2PDProcess:
proc: asyncio.subprocess.Process proc: asyncio.subprocess.Process
cmd: str = str(P2PD_PATH) cmd: str = str(P2PD_PATH)
args: List[Any] args: List[Any]
is_running: bool
_tasks: List["asyncio.Future[Any]"] _tasks: List["asyncio.Future[Any]"]
@ -70,6 +71,8 @@ class P2PDProcess:
# - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second # - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second
# Referece: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501 # Referece: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501
self.args = args self.args = args
self.is_running = False
self._tasks = [] self._tasks = []
async def wait_until_ready(self): async def wait_until_ready(self):
@ -117,8 +120,10 @@ class P2PDProcess:
await self.start_printing_logs() await self.start_printing_logs()
async def close(self) -> None: async def close(self) -> None:
if self.is_running:
self.proc.terminate() self.proc.terminate()
await self.proc.wait() await self.proc.wait()
self.is_running = False
for task in self._tasks: for task in self._tasks:
task.cancel() task.cancel()