diff --git a/examples/chat/chat.py b/examples/chat/chat.py index e8726b1..aebe094 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -6,16 +6,18 @@ import urllib.request import multiaddr from libp2p import new_node +from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.typing import TProtocol -PROTOCOL_ID = "/chat/1.0.0" +PROTOCOL_ID = TProtocol("/chat/1.0.0") -async def read_data(stream): +async def read_data(stream: INetStream) -> None: while True: - read_string = await stream.read() - if read_string is not None: - read_string = read_string.decode() + read_bytes = await stream.read() + if read_bytes is not None: + read_string = read_bytes.decode() if read_string != "\n": # Green console colour: \x1b[32m # Reset console colour: \x1b[0m @@ -23,14 +25,14 @@ async def read_data(stream): # FIXME(mhchia): Reconsider whether we should use a thread pool here. -async def write_data(stream): +async def write_data(stream: INetStream) -> None: loop = asyncio.get_event_loop() while True: line = await loop.run_in_executor(None, sys.stdin.readline) await stream.write(line.encode()) -async def run(port, destination, localhost): +async def run(port: int, destination: str, localhost: bool) -> None: if localhost: ip = "127.0.0.1" else: @@ -42,7 +44,7 @@ async def run(port, destination, localhost): if not destination: # its the server - async def stream_handler(stream): + async def stream_handler(stream: INetStream) -> None: asyncio.ensure_future(read_data(stream)) asyncio.ensure_future(write_data(stream)) @@ -73,7 +75,7 @@ async def run(port, destination, localhost): print("Connected to peer %s" % info.addrs[0]) -def main(): +def main() -> None: description = """ This program demonstrates a simple p2p chat application using libp2p. To use it, first run 'python ./chat -p ', where is the port number. diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 5c14224..a068468 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,7 +1,14 @@ import asyncio +from typing import Mapping, Sequence from Crypto.PublicKey import RSA +from libp2p.kademlia.storage import IStorage +from libp2p.network.network_interface import INetwork +from libp2p.peer.peerstore_interface import IPeerStore +from libp2p.routing.interfaces import IPeerRouting +from libp2p.security.secure_transport_interface import ISecureTransport + from .host.basic_host import BasicHost from .kademlia.network import KademliaServer from .network.swarm import Swarm @@ -11,9 +18,10 @@ from .routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from .security.insecure_security import InsecureTransport from .transport.tcp.tcp import TCP from .transport.upgrader import TransportUpgrader +from .typing import TProtocol -async def cleanup_done_tasks(): +async def cleanup_done_tasks() -> None: """ clean up asyncio done tasks to free up resources """ @@ -27,13 +35,15 @@ async def cleanup_done_tasks(): await asyncio.sleep(3) -def generate_id(): +def generate_id() -> ID: new_key = RSA.generate(2048, e=65537).publickey().export_key("DER") new_id = ID.from_pubkey(new_key) return new_id -def initialize_default_kademlia_router(ksize=20, alpha=3, id_opt=None, storage=None): +def initialize_default_kademlia_router( + ksize: int = 20, alpha: int = 3, id_opt: ID = None, storage: IStorage = None +) -> KadmeliaPeerRouter: """ initialize kadmelia router when no kademlia router is passed in :param ksize: The k parameter from the paper @@ -47,13 +57,21 @@ def initialize_default_kademlia_router(ksize=20, alpha=3, id_opt=None, storage=N id_opt = generate_id() node_id = id_opt.to_bytes() - server = KademliaServer(ksize=ksize, alpha=alpha, node_id=node_id, storage=storage) + # ignore type for Kademlia module + server = KademliaServer( # type: ignore + ksize=ksize, alpha=alpha, node_id=node_id, storage=storage + ) return KadmeliaPeerRouter(server) def initialize_default_swarm( - id_opt=None, transport_opt=None, muxer_opt=None, sec_opt=None, peerstore_opt=None, disc_opt=None -): + id_opt: ID = None, + transport_opt: Sequence[str] = None, + muxer_opt: Sequence[str] = None, + sec_opt: Mapping[TProtocol, ISecureTransport] = None, + peerstore_opt: IPeerStore = None, + disc_opt: IPeerRouting = None, +) -> Swarm: """ initialize swarm when no swarm is passed in :param id_opt: optional id for host @@ -75,7 +93,7 @@ def initialize_default_swarm( # TODO TransportUpgrader is not doing anything really # TODO parse muxer and sec to pass into TransportUpgrader muxer = muxer_opt or ["mplex/6.7.0"] - sec = sec_opt or {"insecure/1.0.0": InsecureTransport("insecure")} + sec = sec_opt or {TProtocol("insecure/1.0.0"): InsecureTransport("insecure")} upgrader = TransportUpgrader(sec, muxer) peerstore = peerstore_opt or PeerStore() @@ -86,14 +104,14 @@ def initialize_default_swarm( async def new_node( - swarm_opt=None, - id_opt=None, - transport_opt=None, - muxer_opt=None, - sec_opt=None, - peerstore_opt=None, - disc_opt=None, -): + swarm_opt: INetwork = None, + id_opt: ID = None, + transport_opt: Sequence[str] = None, + muxer_opt: Sequence[str] = None, + sec_opt: Mapping[TProtocol, ISecureTransport] = None, + peerstore_opt: IPeerStore = None, + disc_opt: IPeerRouting = None, +) -> BasicHost: """ create new libp2p node :param swarm_opt: optional swarm diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 661d83e..9419c98 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -8,7 +8,7 @@ from libp2p.peer.id import ID from libp2p.peer.peerinfo import PeerInfo from libp2p.peer.peerstore_interface import IPeerStore from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter -from libp2p.typing import StreamHandlerFn +from libp2p.typing import StreamHandlerFn, TProtocol from .host_interface import IHost @@ -66,7 +66,7 @@ class BasicHost(IHost): addrs.append(addr.encapsulate(p2p_part)) return addrs - def set_stream_handler(self, protocol_id: str, stream_handler: StreamHandlerFn) -> bool: + def set_stream_handler(self, protocol_id: TProtocol, stream_handler: StreamHandlerFn) -> bool: """ set stream handler for host :param protocol_id: protocol id used on stream @@ -77,7 +77,7 @@ class BasicHost(IHost): # protocol_id can be a list of protocol_ids # stream will decide which protocol_id to run on - async def new_stream(self, peer_id: ID, protocol_ids: Sequence[str]) -> INetStream: + async def new_stream(self, peer_id: ID, protocol_ids: Sequence[TProtocol]) -> INetStream: """ :param peer_id: peer_id that host is connecting :param protocol_id: protocol id that stream runs on diff --git a/libp2p/host/host_interface.py b/libp2p/host/host_interface.py index ce31dbc..667a437 100644 --- a/libp2p/host/host_interface.py +++ b/libp2p/host/host_interface.py @@ -7,7 +7,7 @@ from libp2p.network.network_interface import INetwork from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID from libp2p.peer.peerinfo import PeerInfo -from libp2p.typing import StreamHandlerFn +from libp2p.typing import StreamHandlerFn, TProtocol class IHost(ABC): @@ -37,7 +37,7 @@ class IHost(ABC): """ @abstractmethod - def set_stream_handler(self, protocol_id: str, stream_handler: StreamHandlerFn) -> bool: + def set_stream_handler(self, protocol_id: TProtocol, stream_handler: StreamHandlerFn) -> bool: """ set stream handler for host :param protocol_id: protocol id used on stream @@ -48,7 +48,7 @@ class IHost(ABC): # protocol_id can be a list of protocol_ids # stream will decide which protocol_id to run on @abstractmethod - async def new_stream(self, peer_id: ID, protocol_ids: Sequence[str]) -> INetStream: + async def new_stream(self, peer_id: ID, protocol_ids: Sequence[TProtocol]) -> INetStream: """ :param peer_id: peer_id that host is connecting :param protocol_ids: protocol ids that stream can run on diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index 468adfd..f1b53bc 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Dict, Sequence from multiaddr import Multiaddr from libp2p.peer.id import ID -from libp2p.peer.peerstore import PeerStore +from libp2p.peer.peerstore_interface import IPeerStore from libp2p.stream_muxer.abc import IMuxedConn from libp2p.transport.listener_interface import IListener -from libp2p.typing import StreamHandlerFn +from libp2p.typing import StreamHandlerFn, TProtocol from .stream.net_stream_interface import INetStream @@ -17,7 +17,7 @@ if TYPE_CHECKING: class INetwork(ABC): - peerstore: PeerStore + peerstore: IPeerStore connections: Dict[ID, IMuxedConn] listeners: Dict[str, IListener] @@ -38,7 +38,7 @@ class INetwork(ABC): """ @abstractmethod - def set_stream_handler(self, protocol_id: str, stream_handler: StreamHandlerFn) -> bool: + def set_stream_handler(self, protocol_id: TProtocol, stream_handler: StreamHandlerFn) -> bool: """ :param protocol_id: protocol id used on stream :param stream_handler: a stream handler instance @@ -46,7 +46,7 @@ class INetwork(ABC): """ @abstractmethod - async def new_stream(self, peer_id: ID, protocol_ids: Sequence[str]) -> INetStream: + async def new_stream(self, peer_id: ID, protocol_ids: Sequence[TProtocol]) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_ids: available protocol ids to use for stream diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index c5a7c2e..f4e078e 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,4 +1,5 @@ from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.typing import TProtocol from .net_stream_interface import INetStream @@ -7,20 +8,20 @@ class NetStream(INetStream): muxed_stream: IMuxedStream mplex_conn: IMuxedConn - protocol_id: str + protocol_id: TProtocol def __init__(self, muxed_stream: IMuxedStream) -> None: self.muxed_stream = muxed_stream self.mplex_conn = muxed_stream.mplex_conn self.protocol_id = None - def get_protocol(self) -> str: + def get_protocol(self) -> TProtocol: """ :return: protocol id that stream runs on """ return self.protocol_id - def set_protocol(self, protocol_id: str) -> None: + def set_protocol(self, protocol_id: TProtocol) -> None: """ :param protocol_id: protocol id that stream runs on :return: true if successful diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index d3ac2ff..6bf25ea 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from libp2p.stream_muxer.abc import IMuxedConn +from libp2p.typing import TProtocol class INetStream(ABC): @@ -8,13 +9,13 @@ class INetStream(ABC): mplex_conn: IMuxedConn @abstractmethod - def get_protocol(self) -> str: + def get_protocol(self) -> TProtocol: """ :return: protocol id that stream runs on """ @abstractmethod - def set_protocol(self, protocol_id: str) -> bool: + def set_protocol(self, protocol_id: TProtocol) -> bool: """ :param protocol_id: protocol id that stream runs on :return: true if successful diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 26cdc34..2fbc002 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -4,7 +4,7 @@ from typing import Callable, Dict, List, Sequence from multiaddr import Multiaddr from libp2p.peer.id import ID -from libp2p.peer.peerstore import PeerStore +from libp2p.peer.peerstore_interface import IPeerStore from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.routing.interfaces import IPeerRouting @@ -12,7 +12,7 @@ from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.upgrader import TransportUpgrader -from libp2p.typing import StreamHandlerFn +from libp2p.typing import StreamHandlerFn, TProtocol from .connection.raw_connection import RawConnection from .network_interface import INetwork @@ -25,7 +25,7 @@ from .typing import GenericProtocolHandlerFn class Swarm(INetwork): self_id: ID - peerstore: PeerStore + peerstore: IPeerStore upgrader: TransportUpgrader transport: ITransport router: IPeerRouting @@ -41,7 +41,7 @@ class Swarm(INetwork): def __init__( self, peer_id: ID, - peerstore: PeerStore, + peerstore: IPeerStore, upgrader: TransportUpgrader, transport: ITransport, router: IPeerRouting, @@ -68,7 +68,7 @@ class Swarm(INetwork): def get_peer_id(self) -> ID: return self.self_id - def set_stream_handler(self, protocol_id: str, stream_handler: StreamHandlerFn) -> bool: + def set_stream_handler(self, protocol_id: TProtocol, stream_handler: StreamHandlerFn) -> bool: """ :param protocol_id: protocol id used on stream :param stream_handler: a stream handler instance @@ -121,7 +121,7 @@ class Swarm(INetwork): return muxed_conn - async def new_stream(self, peer_id: ID, protocol_ids: Sequence[str]) -> NetStream: + async def new_stream(self, peer_id: ID, protocol_ids: Sequence[TProtocol]) -> NetStream: """ :param peer_id: peer_id of destination :param protocol_id: protocol id @@ -157,7 +157,7 @@ class Swarm(INetwork): return net_stream - async def listen(self, *multiaddrs: Sequence[Multiaddr]) -> bool: + async def listen(self, *multiaddrs: Multiaddr) -> bool: """ :param multiaddrs: one or many multiaddrs to start listening on :return: true if at least one success diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 5772ab0..c9b0010 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -1,6 +1,7 @@ from typing import Iterable, List, Sequence from libp2p.peer.id import ID +from libp2p.typing import TProtocol from .pb import rpc_pb2 from .pubsub import Pubsub @@ -9,15 +10,15 @@ from .pubsub_router_interface import IPubsubRouter class FloodSub(IPubsubRouter): - protocols: List[str] + protocols: List[TProtocol] pubsub: Pubsub - def __init__(self, protocols: Sequence[str]) -> None: + def __init__(self, protocols: Sequence[TProtocol]) -> None: self.protocols = list(protocols) self.pubsub = None - def get_protocols(self) -> List[str]: + def get_protocols(self) -> List[TProtocol]: """ :return: the list of protocols supported by the router """ @@ -31,7 +32,7 @@ class FloodSub(IPubsubRouter): """ self.pubsub = pubsub - def add_peer(self, peer_id: ID, protocol_id: str) -> None: + def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: """ Notifies the router that a new peer has been connected :param peer_id: id of peer to add @@ -43,7 +44,7 @@ class FloodSub(IPubsubRouter): :param peer_id: id of peer to remove """ - async def handle_rpc(self, rpc: rpc_pb2.ControlMessage, sender_peer_id: ID) -> None: + async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None: """ Invoked to process control messages in the RPC envelope. It is invoked after subscriptions and payload messages have been processed diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index e68e6e5..caaf58c 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -4,6 +4,7 @@ import random from typing import Any, Dict, Iterable, List, Sequence, Set from libp2p.peer.id import ID +from libp2p.typing import TProtocol from .mcache import MessageCache from .pb import rpc_pb2 @@ -13,7 +14,7 @@ from .pubsub_router_interface import IPubsubRouter class GossipSub(IPubsubRouter): - protocols: List[str] + protocols: List[TProtocol] pubsub: Pubsub degree: int @@ -38,7 +39,7 @@ class GossipSub(IPubsubRouter): def __init__( self, - protocols: Sequence[str], + protocols: Sequence[TProtocol], degree: int, degree_low: int, degree_high: int, @@ -79,7 +80,7 @@ class GossipSub(IPubsubRouter): # Interface functions - def get_protocols(self) -> List[str]: + def get_protocols(self) -> List[TProtocol]: """ :return: the list of protocols supported by the router """ @@ -97,7 +98,7 @@ class GossipSub(IPubsubRouter): # TODO: Start after delay asyncio.ensure_future(self.heartbeat()) - def add_peer(self, peer_id: ID, protocol_id: str) -> None: + def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: """ Notifies the router that a new peer has been connected :param peer_id: id of peer to add @@ -126,7 +127,7 @@ class GossipSub(IPubsubRouter): if peer_id in self.peers_gossipsub: self.peers_floodsub.remove(peer_id) - async def handle_rpc(self, rpc: rpc_pb2.Message, sender_peer_id: ID) -> None: + async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None: """ Invoked to process control messages in the RPC envelope. It is invoked after subscriptions and payload messages have been processed @@ -436,7 +437,7 @@ class GossipSub(IPubsubRouter): # RPC handlers - async def handle_ihave(self, ihave_msg: rpc_pb2.Message, sender_peer_id: ID) -> None: + async def handle_ihave(self, ihave_msg: rpc_pb2.ControlIHave, sender_peer_id: ID) -> None: """ Checks the seen set and requests unknown messages with an IWANT message. """ @@ -460,7 +461,7 @@ class GossipSub(IPubsubRouter): if msg_ids_wanted: await self.emit_iwant(msg_ids_wanted, sender_peer_id) - async def handle_iwant(self, iwant_msg: rpc_pb2.Message, sender_peer_id: ID) -> None: + async def handle_iwant(self, iwant_msg: rpc_pb2.ControlIWant, sender_peer_id: ID) -> None: """ Forwards all request messages that are present in mcache to the requesting peer. """ @@ -495,7 +496,7 @@ class GossipSub(IPubsubRouter): # 4) And write the packet to the stream await peer_stream.write(rpc_msg) - async def handle_graft(self, graft_msg: rpc_pb2.Message, sender_peer_id: ID) -> None: + async def handle_graft(self, graft_msg: rpc_pb2.ControlGraft, sender_peer_id: ID) -> None: topic: str = graft_msg.topicID # Add peer to mesh for topic @@ -506,7 +507,7 @@ class GossipSub(IPubsubRouter): # Respond with PRUNE if not subscribed to the topic await self.emit_prune(topic, sender_peer_id) - async def handle_prune(self, prune_msg: rpc_pb2.Message, sender_peer_id: ID) -> None: + async def handle_prune(self, prune_msg: rpc_pb2.ControlPrune, sender_peer_id: ID) -> None: topic: str = prune_msg.topicID # Remove peer from mesh for topic, if peer is in topic diff --git a/libp2p/pubsub/pb/rpc_pb2.pyi b/libp2p/pubsub/pb/rpc_pb2.pyi new file mode 100644 index 0000000..75c7c1e --- /dev/null +++ b/libp2p/pubsub/pb/rpc_pb2.pyi @@ -0,0 +1,322 @@ +# @generated by generate_proto_mypy_stubs.py. Do not edit! +import sys +from google.protobuf.descriptor import ( + Descriptor as google___protobuf___descriptor___Descriptor, + EnumDescriptor as google___protobuf___descriptor___EnumDescriptor, +) + +from google.protobuf.internal.containers import ( + RepeatedCompositeFieldContainer as google___protobuf___internal___containers___RepeatedCompositeFieldContainer, + RepeatedScalarFieldContainer as google___protobuf___internal___containers___RepeatedScalarFieldContainer, +) + +from google.protobuf.message import ( + Message as google___protobuf___message___Message, +) + +from typing import ( + Iterable as typing___Iterable, + List as typing___List, + Optional as typing___Optional, + Text as typing___Text, + Tuple as typing___Tuple, + cast as typing___cast, +) + +from typing_extensions import ( + Literal as typing_extensions___Literal, +) + + +class RPC(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + class SubOpts(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + subscribe = ... # type: bool + topicid = ... # type: typing___Text + + def __init__(self, + *, + subscribe : typing___Optional[bool] = None, + topicid : typing___Optional[typing___Text] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> RPC.SubOpts: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"subscribe",u"topicid"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"subscribe",u"topicid"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"subscribe",b"subscribe",u"topicid",b"topicid"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"subscribe",b"subscribe",u"topicid",b"topicid"]) -> None: ... + + + @property + def subscriptions(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[RPC.SubOpts]: ... + + @property + def publish(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[Message]: ... + + @property + def control(self) -> ControlMessage: ... + + def __init__(self, + *, + subscriptions : typing___Optional[typing___Iterable[RPC.SubOpts]] = None, + publish : typing___Optional[typing___Iterable[Message]] = None, + control : typing___Optional[ControlMessage] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> RPC: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"control"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"control",u"publish",u"subscriptions"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"control",b"control"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"control",b"control",u"publish",b"publish",u"subscriptions",b"subscriptions"]) -> None: ... + +class Message(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + from_id = ... # type: bytes + data = ... # type: bytes + seqno = ... # type: bytes + topicIDs = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] + signature = ... # type: bytes + key = ... # type: bytes + + def __init__(self, + *, + from_id : typing___Optional[bytes] = None, + data : typing___Optional[bytes] = None, + seqno : typing___Optional[bytes] = None, + topicIDs : typing___Optional[typing___Iterable[typing___Text]] = None, + signature : typing___Optional[bytes] = None, + key : typing___Optional[bytes] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> Message: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"data",u"from_id",u"key",u"seqno",u"signature"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"data",u"from_id",u"key",u"seqno",u"signature",u"topicIDs"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"data",b"data",u"from_id",b"from_id",u"key",b"key",u"seqno",b"seqno",u"signature",b"signature"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"data",b"data",u"from_id",b"from_id",u"key",b"key",u"seqno",b"seqno",u"signature",b"signature",u"topicIDs",b"topicIDs"]) -> None: ... + +class ControlMessage(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + + @property + def ihave(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[ControlIHave]: ... + + @property + def iwant(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[ControlIWant]: ... + + @property + def graft(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[ControlGraft]: ... + + @property + def prune(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[ControlPrune]: ... + + def __init__(self, + *, + ihave : typing___Optional[typing___Iterable[ControlIHave]] = None, + iwant : typing___Optional[typing___Iterable[ControlIWant]] = None, + graft : typing___Optional[typing___Iterable[ControlGraft]] = None, + prune : typing___Optional[typing___Iterable[ControlPrune]] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> ControlMessage: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def ClearField(self, field_name: typing_extensions___Literal[u"graft",u"ihave",u"iwant",u"prune"]) -> None: ... + else: + def ClearField(self, field_name: typing_extensions___Literal[u"graft",b"graft",u"ihave",b"ihave",u"iwant",b"iwant",u"prune",b"prune"]) -> None: ... + +class ControlIHave(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + topicID = ... # type: typing___Text + messageIDs = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] + + def __init__(self, + *, + topicID : typing___Optional[typing___Text] = None, + messageIDs : typing___Optional[typing___Iterable[typing___Text]] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> ControlIHave: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"topicID"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"messageIDs",u"topicID"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"messageIDs",b"messageIDs",u"topicID",b"topicID"]) -> None: ... + +class ControlIWant(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + messageIDs = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] + + def __init__(self, + *, + messageIDs : typing___Optional[typing___Iterable[typing___Text]] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> ControlIWant: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def ClearField(self, field_name: typing_extensions___Literal[u"messageIDs"]) -> None: ... + else: + def ClearField(self, field_name: typing_extensions___Literal[u"messageIDs",b"messageIDs"]) -> None: ... + +class ControlGraft(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + topicID = ... # type: typing___Text + + def __init__(self, + *, + topicID : typing___Optional[typing___Text] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> ControlGraft: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"topicID"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"topicID"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> None: ... + +class ControlPrune(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + topicID = ... # type: typing___Text + + def __init__(self, + *, + topicID : typing___Optional[typing___Text] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> ControlPrune: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"topicID"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"topicID"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"topicID",b"topicID"]) -> None: ... + +class TopicDescriptor(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + class AuthOpts(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + class AuthMode(int): + DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ... + @classmethod + def Name(cls, number: int) -> str: ... + @classmethod + def Value(cls, name: str) -> TopicDescriptor.AuthOpts.AuthMode: ... + @classmethod + def keys(cls) -> typing___List[str]: ... + @classmethod + def values(cls) -> typing___List[TopicDescriptor.AuthOpts.AuthMode]: ... + @classmethod + def items(cls) -> typing___List[typing___Tuple[str, TopicDescriptor.AuthOpts.AuthMode]]: ... + NONE = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 0) + KEY = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 1) + WOT = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 2) + NONE = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 0) + KEY = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 1) + WOT = typing___cast(TopicDescriptor.AuthOpts.AuthMode, 2) + + mode = ... # type: TopicDescriptor.AuthOpts.AuthMode + keys = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[bytes] + + def __init__(self, + *, + mode : typing___Optional[TopicDescriptor.AuthOpts.AuthMode] = None, + keys : typing___Optional[typing___Iterable[bytes]] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> TopicDescriptor.AuthOpts: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"mode"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"keys",u"mode"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"mode",b"mode"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"keys",b"keys",u"mode",b"mode"]) -> None: ... + + class EncOpts(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + class EncMode(int): + DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ... + @classmethod + def Name(cls, number: int) -> str: ... + @classmethod + def Value(cls, name: str) -> TopicDescriptor.EncOpts.EncMode: ... + @classmethod + def keys(cls) -> typing___List[str]: ... + @classmethod + def values(cls) -> typing___List[TopicDescriptor.EncOpts.EncMode]: ... + @classmethod + def items(cls) -> typing___List[typing___Tuple[str, TopicDescriptor.EncOpts.EncMode]]: ... + NONE = typing___cast(TopicDescriptor.EncOpts.EncMode, 0) + SHAREDKEY = typing___cast(TopicDescriptor.EncOpts.EncMode, 1) + WOT = typing___cast(TopicDescriptor.EncOpts.EncMode, 2) + NONE = typing___cast(TopicDescriptor.EncOpts.EncMode, 0) + SHAREDKEY = typing___cast(TopicDescriptor.EncOpts.EncMode, 1) + WOT = typing___cast(TopicDescriptor.EncOpts.EncMode, 2) + + mode = ... # type: TopicDescriptor.EncOpts.EncMode + keyHashes = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[bytes] + + def __init__(self, + *, + mode : typing___Optional[TopicDescriptor.EncOpts.EncMode] = None, + keyHashes : typing___Optional[typing___Iterable[bytes]] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> TopicDescriptor.EncOpts: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"mode"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"keyHashes",u"mode"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"mode",b"mode"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"keyHashes",b"keyHashes",u"mode",b"mode"]) -> None: ... + + name = ... # type: typing___Text + + @property + def auth(self) -> TopicDescriptor.AuthOpts: ... + + @property + def enc(self) -> TopicDescriptor.EncOpts: ... + + def __init__(self, + *, + name : typing___Optional[typing___Text] = None, + auth : typing___Optional[TopicDescriptor.AuthOpts] = None, + enc : typing___Optional[TopicDescriptor.EncOpts] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> TopicDescriptor: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def HasField(self, field_name: typing_extensions___Literal[u"auth",u"enc",u"name"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"auth",u"enc",u"name"]) -> None: ... + else: + def HasField(self, field_name: typing_extensions___Literal[u"auth",b"auth",u"enc",b"enc",u"name",b"name"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"auth",b"auth",u"enc",b"enc",u"name",b"name"]) -> None: ... diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index e715dba..6600c25 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,7 +1,18 @@ import asyncio import logging import time -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, NamedTuple, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + NamedTuple, + Tuple, + Union, + cast, +) from lru import LRU @@ -9,6 +20,7 @@ from libp2p.exceptions import ValidationError from libp2p.host.host_interface import IHost from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID +from libp2p.typing import TProtocol from .pb import rpc_pb2 from .pubsub_notifee import PubsubNotifee @@ -31,7 +43,9 @@ AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] -TopicValidator = NamedTuple("TopicValidator", (("validator", ValidatorFn), ("is_async", bool))) +class TopicValidator(NamedTuple): + validator: ValidatorFn + is_async: bool class Pubsub: @@ -43,7 +57,7 @@ class Pubsub: peer_queue: "asyncio.Queue[ID]" - protocols: List[str] + protocols: List[TProtocol] incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]" outgoing_messages: "asyncio.Queue[rpc_pb2.Message]" @@ -192,7 +206,7 @@ class Pubsub: Get all validators corresponding to the topics in the message. :param msg: the message published to the topic """ - return ( + return tuple( self.topic_validators[topic] for topic in msg.topicIDs if topic in self.topic_validators ) @@ -301,9 +315,7 @@ class Pubsub: # Create subscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() - packet.subscriptions.extend( - [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id.encode("utf-8"))] - ) + packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]) # Send out subscribe message to all peers await self.message_all_peers(packet.SerializeToString()) @@ -328,9 +340,7 @@ class Pubsub: # Create unsubscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() - packet.subscriptions.extend( - [rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id.encode("utf-8"))] - ) + packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)]) # Send out unsubscribe message to all peers await self.message_all_peers(packet.SerializeToString()) @@ -374,12 +384,14 @@ class Pubsub: :param msg: the message. """ sync_topic_validators = [] - async_topic_validator_futures = [] + async_topic_validator_futures: List[Awaitable[bool]] = [] for topic_validator in self.get_msg_validators(msg): if topic_validator.is_async: - async_topic_validator_futures.append(topic_validator.validator(msg_forwarder, msg)) + async_topic_validator_futures.append( + cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg)) + ) else: - sync_topic_validators.append(topic_validator.validator) + sync_topic_validators.append(cast(SyncValidatorFn, topic_validator.validator)) for validator in sync_topic_validators: if not validator(msg_forwarder, msg): diff --git a/libp2p/pubsub/pubsub_router_interface.py b/libp2p/pubsub/pubsub_router_interface.py index f2c01a3..5534c29 100644 --- a/libp2p/pubsub/pubsub_router_interface.py +++ b/libp2p/pubsub/pubsub_router_interface.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List from libp2p.peer.id import ID +from libp2p.typing import TProtocol from .pb import rpc_pb2 @@ -11,7 +12,7 @@ if TYPE_CHECKING: class IPubsubRouter(ABC): @abstractmethod - def get_protocols(self) -> List[str]: + def get_protocols(self) -> List[TProtocol]: """ :return: the list of protocols supported by the router """ @@ -25,7 +26,7 @@ class IPubsubRouter(ABC): """ @abstractmethod - def add_peer(self, peer_id: ID, protocol_id: str) -> None: + def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: """ Notifies the router that a new peer has been connected :param peer_id: id of peer to add @@ -39,7 +40,7 @@ class IPubsubRouter(ABC): """ @abstractmethod - async def handle_rpc(self, rpc: rpc_pb2.ControlMessage, sender_peer_id: ID) -> None: + async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None: """ Invoked to process control messages in the RPC envelope. It is invoked after subscriptions and payload messages have been processed diff --git a/libp2p/routing/kademlia/kademlia_peer_router.py b/libp2p/routing/kademlia/kademlia_peer_router.py index 061bdda..59eaa1e 100644 --- a/libp2p/routing/kademlia/kademlia_peer_router.py +++ b/libp2p/routing/kademlia/kademlia_peer_router.py @@ -22,7 +22,8 @@ class KadmeliaPeerRouter(IPeerRouting): """ # switching peer_id to xor_id used by kademlia as node_id xor_id = peer_id.xor_id - value = await self.server.get(xor_id) + # ignore type for kad + value = await self.server.get(xor_id) # type: ignore return decode_peerinfo(value) @@ -36,5 +37,6 @@ def decode_peerinfo(encoded: Union[bytes, str]) -> KadPeerInfo: ip = lines[1] port = lines[2] peer_id = lines[3] - peer_info = create_kad_peerinfo(peer_id, ip, port) + # ignore typing for kad + peer_info = create_kad_peerinfo(peer_id, ip, port) # type: ignore return peer_info diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index a6e5b48..2fa3292 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -22,7 +22,6 @@ class Mplex(IMuxedConn): secured_conn: ISecureConn raw_conn: IRawConnection initiator: bool - generic_protocol_handler = None peer_id: ID buffers: Dict[int, "asyncio.Queue[bytes]"] stream_queue: "asyncio.Queue[int]" diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 02d981b..b400ec1 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -15,9 +15,8 @@ from libp2p.transport.typing import THandler class TCPListener(IListener): multiaddrs: List[Multiaddr] server = None - handler = None - def __init__(self, handler_function: THandler = None) -> None: + def __init__(self, handler_function: THandler) -> None: self.multiaddrs = [] self.server = None self.handler = handler_function diff --git a/libp2p/transport/typing.py b/libp2p/transport/typing.py index 147fe11..6d0047c 100644 --- a/libp2p/transport/typing.py +++ b/libp2p/transport/typing.py @@ -1,4 +1,4 @@ from asyncio import StreamReader, StreamWriter -from typing import Callable +from typing import Awaitable, Callable -THandler = Callable[[StreamReader, StreamWriter], None] +THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]] diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 77bda39..9b2ddba 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -1,4 +1,4 @@ -from typing import Dict, Sequence +from typing import Mapping, Sequence from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.typing import GenericProtocolHandlerFn @@ -17,7 +17,9 @@ class TransportUpgrader: security_multistream: SecurityMultistream muxer: Sequence[str] - def __init__(self, secOpt: Dict[TProtocol, ISecureTransport], muxerOpt: Sequence[str]) -> None: + def __init__( + self, secOpt: Mapping[TProtocol, ISecureTransport], muxerOpt: Sequence[str] + ) -> None: # Store security option self.security_multistream = SecurityMultistream() for key in secOpt: @@ -45,7 +47,7 @@ class TransportUpgrader: @staticmethod def upgrade_connection( - conn: IRawConnection, generic_protocol_handler: GenericProtocolHandlerFn, peer_id: ID + conn: ISecureConn, generic_protocol_handler: GenericProtocolHandlerFn, peer_id: ID ) -> Mplex: """ Upgrade raw connection to muxed connection diff --git a/libp2p/typing.py b/libp2p/typing.py index 4b5f5f2..08631dc 100644 --- a/libp2p/typing.py +++ b/libp2p/typing.py @@ -1,11 +1,13 @@ -from typing import Awaitable, Callable, NewType, Union +from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union from libp2p.network.connection.raw_connection_interface import IRawConnection -from libp2p.network.stream.net_stream_interface import INetStream -from libp2p.stream_muxer.abc import IMuxedStream + +if TYPE_CHECKING: + from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401 + from libp2p.stream_muxer.abc import IMuxedStream # noqa: F401 TProtocol = NewType("TProtocol", str) -StreamHandlerFn = Callable[[INetStream], Awaitable[None]] +StreamHandlerFn = Callable[["INetStream"], Awaitable[None]] -NegotiableTransport = Union[IMuxedStream, IRawConnection] +NegotiableTransport = Union["IMuxedStream", IRawConnection] diff --git a/mypy.ini b/mypy.ini index c062136..3da7cfa 100644 --- a/mypy.ini +++ b/mypy.ini @@ -10,3 +10,6 @@ disallow_untyped_calls = True warn_redundant_casts = True warn_unused_configs = True strict_equality = True + +[mypy-libp2p.kademlia.*] +ignore_errors = True diff --git a/tox.ini b/tox.ini index 8e77c17..c735abb 100644 --- a/tox.ini +++ b/tox.ini @@ -20,7 +20,9 @@ include_trailing_comma=True force_grid_wrap=0 use_parentheses=True line_length=100 -skip_glob=*_pb2*.py +skip_glob= + *_pb2*.py + *.pyi [testenv] deps = @@ -36,8 +38,7 @@ basepython = basepython = python3 extras = dev commands = - # NOTE: disabling `mypy` until we get typing sorted in this repo - # mypy -p libp2p -p examples --config-file {toxinidir}/mypy.ini - black --check libp2p tests examples setup.py - isort --recursive --check-only libp2p tests examples setup.py - flake8 libp2p tests examples setup.py + mypy -p libp2p -p examples --config-file {toxinidir}/mypy.ini + black --check libp2p tests examples setup.py + isort --recursive --check-only libp2p tests examples setup.py + flake8 libp2p tests examples setup.py