Fix: Change the event.close to event.set

And add missing parts.
This commit is contained in:
mhchia 2019-09-06 21:35:15 +08:00
parent 1cd969a2d5
commit 6c1f77dc1a
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
7 changed files with 19 additions and 21 deletions

View File

@ -3,7 +3,6 @@ import logging
import time import time
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
Awaitable, Awaitable,
Callable, Callable,
Dict, Dict,

View File

@ -166,8 +166,11 @@ class Mplex(IMuxedConn):
if channel_id is not None and flag is not None and message is not None: if channel_id is not None and flag is not None and message is not None:
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
is_stream_id_seen: bool is_stream_id_seen: bool
stream: MplexStream
async with self.streams_lock: async with self.streams_lock:
is_stream_id_seen = stream_id in self.streams is_stream_id_seen = stream_id in self.streams
if is_stream_id_seen:
stream = self.streams[stream_id]
# Other consequent stream message should wait until the stream get accepted # Other consequent stream message should wait until the stream get accepted
# TODO: Handle more tags, and refactor `HeaderTags` # TODO: Handle more tags, and refactor `HeaderTags`
if flag == HeaderTags.NewStream.value: if flag == HeaderTags.NewStream.value:
@ -185,8 +188,6 @@ class Mplex(IMuxedConn):
# before. It is abnormal. Possibly disconnect? # before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this. # TODO: Warn and emit logs about this.
continue continue
async with self.streams_lock:
stream = self.streams[stream_id]
await stream.incoming_data.put(message) await stream.incoming_data.put(message)
elif flag in ( elif flag in (
HeaderTags.CloseInitiator.value, HeaderTags.CloseInitiator.value,
@ -194,15 +195,17 @@ class Mplex(IMuxedConn):
): ):
if not is_stream_id_seen: if not is_stream_id_seen:
continue continue
stream: MplexStream # NOTE: If remote is already closed, then return: Technically a bug
async with self.streams_lock: # on the other side. We should consider killing the connection.
stream = self.streams[stream_id] async with stream.close_lock:
if stream.event_remote_closed.is_set():
continue
is_local_closed: bool is_local_closed: bool
async with stream.close_lock: async with stream.close_lock:
stream.event_remote_closed.set() stream.event_remote_closed.set()
is_local_closed = stream.event_local_closed.is_set() is_local_closed = stream.event_local_closed.is_set()
# If local is also closed, both sides are closed. Then, we should clean up # If local is also closed, both sides are closed. Then, we should clean up
# this stream. # the entry of this stream, to avoid others from accessing it.
if is_local_closed: if is_local_closed:
async with self.streams_lock: async with self.streams_lock:
del self.streams[stream_id] del self.streams[stream_id]
@ -213,24 +216,21 @@ class Mplex(IMuxedConn):
if not is_stream_id_seen: if not is_stream_id_seen:
# This is *ok*. We forget the stream on reset. # This is *ok*. We forget the stream on reset.
continue continue
stream: MplexStream
async with self.streams_lock:
stream = self.streams[stream_id]
async with stream.close_lock: async with stream.close_lock:
if not stream.event_remote_closed.is_set(): if not stream.event_remote_closed.is_set():
# TODO: Why? Only if remote is not closed before then reset.
stream.event_reset.set() stream.event_reset.set()
stream.event_remote_closed.set() stream.event_remote_closed.set()
# If local is not closed, we should close it.
if not stream.event_local_closed.is_set(): if not stream.event_local_closed.is_set():
stream.event_local_closed.close() stream.event_local_closed.set()
async with self.streams_lock: async with self.streams_lock:
del self.streams[stream_id] del self.streams[stream_id]
else: else:
# TODO: logging # TODO: logging
print(f"message with unknown header on stream {stream_id}")
if is_stream_id_seen: if is_stream_id_seen:
async with self.streams_lock: await stream.reset()
stream = self.streams[stream_id]
await stream.reset()
# Force context switch # Force context switch
await asyncio.sleep(0) await asyncio.sleep(0)

View File

@ -4,8 +4,8 @@ from typing import TYPE_CHECKING
from libp2p.stream_muxer.abc import IMuxedStream from libp2p.stream_muxer.abc import IMuxedStream
from .constants import HeaderTags from .constants import HeaderTags
from .exceptions import MplexStreamReset, MplexStreamEOF
from .datastructures import StreamID from .datastructures import StreamID
from .exceptions import MplexStreamEOF, MplexStreamReset
if TYPE_CHECKING: if TYPE_CHECKING:
from libp2p.stream_muxer.mplex.mplex import Mplex from libp2p.stream_muxer.mplex.mplex import Mplex
@ -55,7 +55,7 @@ class MplexStream(IMuxedStream):
return self.stream_id.is_initiator return self.stream_id.is_initiator
async def _wait_for_data(self) -> None: async def _wait_for_data(self) -> None:
done, pending = await asyncio.wait( done, pending = await asyncio.wait( # type: ignore
[ [
self.event_reset.wait(), self.event_reset.wait(),
self.event_remote_closed.wait(), self.event_remote_closed.wait(),

View File

@ -2,8 +2,8 @@ import multiaddr
import pytest import pytest
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.utils import cleanup, set_up_nodes_by_transport_opt
from tests.constants import MAX_READ_LEN from tests.constants import MAX_READ_LEN
from tests.utils import cleanup, set_up_nodes_by_transport_opt
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -16,8 +16,8 @@ from libp2p import initialize_default_swarm, new_node
from libp2p.crypto.rsa import create_new_key_pair from libp2p.crypto.rsa import create_new_key_pair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.network.notifee_interface import INotifee from libp2p.network.notifee_interface import INotifee
from tests.utils import cleanup, perform_two_host_set_up
from tests.constants import MAX_READ_LEN from tests.constants import MAX_READ_LEN
from tests.utils import cleanup, perform_two_host_set_up
ACK = "ack:" ACK = "ack:"

View File

@ -1,7 +1,7 @@
import pytest import pytest
from libp2p.protocol_muxer.exceptions import MultiselectClientError from libp2p.protocol_muxer.exceptions import MultiselectClientError
from tests.utils import cleanup, set_up_nodes_by_transport_opt, echo_stream_handler from tests.utils import cleanup, echo_stream_handler, set_up_nodes_by_transport_opt
# TODO: Add tests for multiple streams being opened on different # TODO: Add tests for multiple streams being opened on different
# protocols through the same connection # protocols through the same connection

View File

@ -5,7 +5,6 @@ import multiaddr
from libp2p import new_node from libp2p import new_node
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.constants import MAX_READ_LEN from tests.constants import MAX_READ_LEN