Add "closed" and "shutting_down" events

This commit is contained in:
mhchia 2019-09-12 17:07:41 +08:00
parent 7cf0495f37
commit 5653b3f604
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
3 changed files with 45 additions and 9 deletions

View File

@ -5,7 +5,11 @@ class MuxedConnError(BaseLibp2pError):
pass pass
class MuxedConnShutdown(MuxedConnError): class MuxedConnShuttingDown(MuxedConnError):
pass
class MuxedConnClosed(MuxedConnError):
pass pass

View File

@ -1,6 +1,7 @@
from libp2p.stream_muxer.exceptions import ( from libp2p.stream_muxer.exceptions import (
MuxedConnError, MuxedConnError,
MuxedConnShutdown, MuxedConnShuttingDown,
MuxedConnClosed,
MuxedStreamClosed, MuxedStreamClosed,
MuxedStreamEOF, MuxedStreamEOF,
MuxedStreamReset, MuxedStreamReset,
@ -11,7 +12,11 @@ class MplexError(MuxedConnError):
pass pass
class MplexShutdown(MuxedConnShutdown): class MplexShuttingDown(MuxedConnShuttingDown):
pass
class MplexClosed(MuxedConnClosed):
pass pass

View File

@ -14,6 +14,7 @@ from libp2p.utils import (
) )
from .constants import HeaderTags from .constants import HeaderTags
from .exceptions import MplexClosed, MplexShuttingDown
from .datastructures import StreamID from .datastructures import StreamID
from .mplex_stream import MplexStream from .mplex_stream import MplexStream
@ -34,7 +35,8 @@ class Mplex(IMuxedConn):
streams: Dict[StreamID, MplexStream] streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock streams_lock: asyncio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]" new_stream_queue: "asyncio.Queue[IMuxedStream]"
shutdown: asyncio.Event event_shutting_down: asyncio.Event
event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"] _tasks: List["asyncio.Future[Any]"]
@ -58,7 +60,8 @@ class Mplex(IMuxedConn):
self.streams = {} self.streams = {}
self.streams_lock = asyncio.Lock() self.streams_lock = asyncio.Lock()
self.new_stream_queue = asyncio.Queue() self.new_stream_queue = asyncio.Queue()
self.shutdown = asyncio.Event() self.event_shutting_down = asyncio.Event()
self.event_closed = asyncio.Event()
self._tasks = [] self._tasks = []
@ -73,16 +76,20 @@ class Mplex(IMuxedConn):
""" """
close the stream muxer and underlying secured connection close the stream muxer and underlying secured connection
""" """
for task in self._tasks: # for task in self._tasks:
task.cancel() # task.cancel()
await self.secured_conn.close() await self.secured_conn.close()
# Set the `event_shutting_down`, to allow graceful shutdown.
self.event_shutting_down.set()
# Blocked until `close` is finally set.
# await self.event_closed.wait()
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
check connection is fully closed check connection is fully closed
:return: true if successful :return: true if successful
""" """
raise NotImplementedError() return self.event_closed.is_set()
def _get_next_channel_id(self) -> int: def _get_next_channel_id(self) -> int:
""" """
@ -112,11 +119,31 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream return stream
async def _wait_until_closed(self, coro) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_closed], return_when=asyncio.FIRST_COMPLETED
)
if task_wait_closed in done:
raise MplexClosed
return task_coro.result()
async def _wait_until_shutting_down(self, coro) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_shutting_down], return_when=asyncio.FIRST_COMPLETED
)
if task_wait_shutting_down in done:
raise MplexShuttingDown
return task_coro.result()
async def accept_stream(self) -> IMuxedStream: async def accept_stream(self) -> IMuxedStream:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
return await self.new_stream_queue.get() return await self._wait_until_closed(self.new_stream_queue.get())
async def send_message( async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID