diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 4ae9e4a..87b039f 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -55,43 +55,45 @@ class MplexStream(IMuxedStream): return self.stream_id.is_initiator async def _wait_for_data(self) -> None: + task_event_reset = asyncio.ensure_future(self.event_reset.wait()) + task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get()) + task_event_remote_closed = asyncio.ensure_future( + self.event_remote_closed.wait() + ) done, pending = await asyncio.wait( # type: ignore - [ - self.event_reset.wait(), - self.incoming_data.get(), - self.event_remote_closed.wait(), - ], + [task_event_reset, task_incoming_data_get, task_event_remote_closed], return_when=asyncio.FIRST_COMPLETED, ) for fut in pending: fut.cancel() - if self.event_reset.is_set(): - raise MplexStreamReset + if task_event_reset in done: + if self.event_reset.is_set(): + raise MplexStreamReset + else: + # However, it is abnormal that `Event.wait` is unblocked without any of the flag + # is set. The task is probably cancelled. + raise Exception( + "Should not enter here. " + f"It is probably because {task_event_remote_closed} is cancelled." + ) - if len(done) != 1: - raise Exception(f"Should be exactly 1 job in {done}.") - done_task = tuple(done)[0] - # NOTE: Ignore type check because the typeshed for `asyncio.Task` does not - # have the field `_coro`. - coro_qualname = done_task._coro.__qualname__ # type: ignore - # If `qualname == "Queue.get"` then there is incoming data. We can add it to the buffer. - if coro_qualname == "Queue.get": - data = done_task.result() + if task_incoming_data_get in done: + data = task_incoming_data_get.result() self._buf.extend(data) return - if self.event_remote_closed.is_set(): - raise MplexStreamEOF + if task_event_remote_closed in done: + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + else: + # However, it is abnormal that `Event.wait` is unblocked without any of the flag + # is set. The task is probably cancelled. + raise Exception( + "Should not enter here. " + f"It is probably because {task_event_remote_closed} is cancelled." + ) - # If the task is not `Queue.get`, then it must be `Event.wait`. - # However, it is abnormal that `Event.wait` is unblocked without any of the event - # (remote_closed and reset) is set. Then it is highly possible that the task - # is cancelled. - raise Exception( - "Should not enter here. " - f"It is highly possible that `done_task` is cancelled. `done_task`={done_task}" - ) # TODO: Handle timeout when deadline is used. async def _read_until_eof(self) -> bytes: