diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index b26dd3c..7b01d4b 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -1,8 +1,9 @@ import logging -from typing import List, Sequence +from typing import TYPE_CHECKING, List, Sequence import multiaddr +from libp2p.host.defaults import DEFAULT_HOST_PROTOCOLS from libp2p.host.exceptions import StreamFailure from libp2p.network.network_interface import INetwork from libp2p.network.stream.net_stream_interface import INetStream @@ -17,6 +18,9 @@ from libp2p.typing import StreamHandlerFn, TProtocol from .host_interface import IHost +if TYPE_CHECKING: + from collections import OrderedDict + # Upon host creation, host takes in options, # including the list of addresses on which to listen. # Host then parses these options and delegates to its Network instance, @@ -38,12 +42,16 @@ class BasicHost(IHost): multiselect: Multiselect multiselect_client: MultiselectClient - def __init__(self, network: INetwork) -> None: + def __init__( + self, + network: INetwork, + default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = DEFAULT_HOST_PROTOCOLS, + ) -> None: self._network = network self._network.set_stream_handler(self._swarm_stream_handler) self.peerstore = self._network.peerstore # Protocol muxing - self.multiselect = Multiselect() + self.multiselect = Multiselect(default_protocols) self.multiselect_client = MultiselectClient() def get_id(self) -> ID: diff --git a/libp2p/host/defaults.py b/libp2p/host/defaults.py new file mode 100644 index 0000000..ce9b6d6 --- /dev/null +++ b/libp2p/host/defaults.py @@ -0,0 +1,7 @@ +from collections import OrderedDict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from libp2p.typing import TProtocol, StreamHandlerFn + +DEFAULT_HOST_PROTOCOLS: "OrderedDict[TProtocol, StreamHandlerFn]" = OrderedDict() diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 88f7e37..72279c6 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -19,8 +19,8 @@ class Multiselect(IMultiselectMuxer): handlers: Dict[TProtocol, StreamHandlerFn] - def __init__(self) -> None: - self.handlers = {} + def __init__(self, default_handlers: Dict[TProtocol, StreamHandlerFn] = {}) -> None: + self.handlers = default_handlers def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None: """