Add "closed" and "shutting_down" events

This commit is contained in:
mhchia 2019-09-12 17:07:41 +08:00
parent 7cf0495f37
commit 5653b3f604
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
3 changed files with 45 additions and 9 deletions

View File

@ -5,7 +5,11 @@ class MuxedConnError(BaseLibp2pError):
pass
class MuxedConnShutdown(MuxedConnError):
class MuxedConnShuttingDown(MuxedConnError):
pass
class MuxedConnClosed(MuxedConnError):
pass

View File

@ -1,6 +1,7 @@
from libp2p.stream_muxer.exceptions import (
MuxedConnError,
MuxedConnShutdown,
MuxedConnShuttingDown,
MuxedConnClosed,
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamReset,
@ -11,7 +12,11 @@ class MplexError(MuxedConnError):
pass
class MplexShutdown(MuxedConnShutdown):
class MplexShuttingDown(MuxedConnShuttingDown):
pass
class MplexClosed(MuxedConnClosed):
pass

View File

@ -14,6 +14,7 @@ from libp2p.utils import (
)
from .constants import HeaderTags
from .exceptions import MplexClosed, MplexShuttingDown
from .datastructures import StreamID
from .mplex_stream import MplexStream
@ -34,7 +35,8 @@ class Mplex(IMuxedConn):
streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]"
shutdown: asyncio.Event
event_shutting_down: asyncio.Event
event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"]
@ -58,7 +60,8 @@ class Mplex(IMuxedConn):
self.streams = {}
self.streams_lock = asyncio.Lock()
self.new_stream_queue = asyncio.Queue()
self.shutdown = asyncio.Event()
self.event_shutting_down = asyncio.Event()
self.event_closed = asyncio.Event()
self._tasks = []
@ -73,16 +76,20 @@ class Mplex(IMuxedConn):
"""
close the stream muxer and underlying secured connection
"""
for task in self._tasks:
task.cancel()
# for task in self._tasks:
# task.cancel()
await self.secured_conn.close()
# Set the `event_shutting_down`, to allow graceful shutdown.
self.event_shutting_down.set()
# Blocked until `close` is finally set.
# await self.event_closed.wait()
def is_closed(self) -> bool:
"""
check connection is fully closed
:return: true if successful
"""
raise NotImplementedError()
return self.event_closed.is_set()
def _get_next_channel_id(self) -> int:
"""
@ -112,11 +119,31 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream
async def _wait_until_closed(self, coro) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_closed], return_when=asyncio.FIRST_COMPLETED
)
if task_wait_closed in done:
raise MplexClosed
return task_coro.result()
async def _wait_until_shutting_down(self, coro) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_shutting_down], return_when=asyncio.FIRST_COMPLETED
)
if task_wait_shutting_down in done:
raise MplexShuttingDown
return task_coro.result()
async def accept_stream(self) -> IMuxedStream:
"""
accepts a muxed stream opened by the other end
"""
return await self.new_stream_queue.get()
return await self._wait_until_closed(self.new_stream_queue.get())
async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID