Merge pull request #129 from zaibon/peer_routing

[WIP] kademlia dht
This commit is contained in:
Alex Haynes 2019-04-17 21:50:03 -04:00 committed by GitHub
commit cce226c714
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 501 additions and 269 deletions

View File

@ -23,6 +23,7 @@ async def cleanup_done_tasks():
# Some sleep necessary to context switch
await asyncio.sleep(3)
def initialize_default_swarm(
id_opt=None, transport_opt=None,
muxer_opt=None, sec_opt=None, peerstore_opt=None):
@ -54,6 +55,7 @@ def initialize_default_swarm(
return swarm_opt
async def new_node(
swarm_opt=None, id_opt=None, transport_opt=None,
muxer_opt=None, sec_opt=None, peerstore_opt=None):

View File

@ -1,16 +0,0 @@
from abc import ABC, abstractmethod
# pylint: disable=too-few-public-methods
class IAdvertiser(ABC):
def __init__(self):
pass
@abstractmethod
def advertise(self, service):
"""
Advertise providing a specific service to the network
:param service: service that you provide
:raise Exception: network error
"""

View File

@ -1,17 +0,0 @@
from abc import ABC, abstractmethod
# pylint: disable=too-few-public-methods
class IDiscoverer(ABC):
def __init__(self):
pass
@abstractmethod
def find_peers(self, service):
"""
Find peers on the network providing a particular service
:param service: service that peers must provide
:return: PeerInfo generator that yields PeerInfo objects for discovered peers
:raise Exception: network error
"""

View File

@ -2,4 +2,4 @@
Kademlia is a Python implementation of the Kademlia protocol which
utilizes the asyncio library.
"""
__version__ = "1.1"
__version__ = "2.0"

View File

@ -1,17 +1,21 @@
from collections import Counter
import logging
from .kademlia.node import Node, NodeHeap
from .kademlia.utils import gather_dict
log = logging.getLogger(__name__)
from .node import Node, NodeHeap
from .utils import gather_dict
log = logging.getLogger(__name__) # pylint: disable=invalid-name
# pylint: disable=too-few-public-methods
class SpiderCrawl:
"""
Crawl the network and look for given 160-bit keys.
"""
def __init__(self, protocol, node, peers, ksize, alpha):
# pylint: disable=too-many-arguments
"""
Create a new C{SpiderCrawl}er.
@ -29,7 +33,7 @@ class SpiderCrawl:
self.alpha = alpha
self.node = node
self.nearest = NodeHeap(self.node, self.ksize)
self.lastIDsCrawled = []
self.last_ids_crawled = []
log.info("creating spider with peers: %s", peers)
self.nearest.push(peers)
@ -38,7 +42,7 @@ class SpiderCrawl:
Get either a value or list of nodes.
Args:
rpcmethod: The protocol's callfindValue or callFindNode.
rpcmethod: The protocol's call_find_value or call_find_node.
The process:
1. calls find_* to current ALPHA nearest not already queried nodes,
@ -51,75 +55,76 @@ class SpiderCrawl:
"""
log.info("crawling network with nearest: %s", str(tuple(self.nearest)))
count = self.alpha
if self.nearest.getIDs() == self.lastIDsCrawled:
if self.nearest.get_ids() == self.last_ids_crawled:
count = len(self.nearest)
self.lastIDsCrawled = self.nearest.getIDs()
self.last_ids_crawled = self.nearest.get_ids()
ds = {}
for peer in self.nearest.getUncontacted()[:count]:
ds[peer.id] = rpcmethod(peer, self.node)
self.nearest.markContacted(peer)
found = await gather_dict(ds)
return await self._nodesFound(found)
dicts = {}
for peer in self.nearest.get_uncontacted()[:count]:
dicts[peer.id] = rpcmethod(peer, self.node)
self.nearest.mark_contacted(peer)
found = await gather_dict(dicts)
return await self._nodes_found(found)
async def _nodesFound(self, responses):
async def _nodes_found(self, responses):
raise NotImplementedError
class ValueSpiderCrawl(SpiderCrawl):
def __init__(self, protocol, node, peers, ksize, alpha):
# pylint: disable=too-many-arguments
SpiderCrawl.__init__(self, protocol, node, peers, ksize, alpha)
# keep track of the single nearest node without value - per
# section 2.3 so we can set the key there if found
self.nearestWithoutValue = NodeHeap(self.node, 1)
self.nearest_without_value = NodeHeap(self.node, 1)
async def find(self):
"""
Find either the closest nodes or the value requested.
"""
return await self._find(self.protocol.callFindValue)
return await self._find(self.protocol.call_find_value)
async def _nodesFound(self, responses):
async def _nodes_found(self, responses):
"""
Handle the result of an iteration in _find.
"""
toremove = []
foundValues = []
found_values = []
for peerid, response in responses.items():
response = RPCFindResponse(response)
if not response.happened():
toremove.append(peerid)
elif response.hasValue():
foundValues.append(response.getValue())
elif response.has_value():
found_values.append(response.get_value())
else:
peer = self.nearest.getNodeById(peerid)
self.nearestWithoutValue.push(peer)
self.nearest.push(response.getNodeList())
peer = self.nearest.get_node(peerid)
self.nearest_without_value.push(peer)
self.nearest.push(response.get_node_list())
self.nearest.remove(toremove)
if len(foundValues) > 0:
return await self._handleFoundValues(foundValues)
if self.nearest.allBeenContacted():
if found_values:
return await self._handle_found_values(found_values)
if self.nearest.have_contacted_all():
# not found!
return None
return await self.find()
async def _handleFoundValues(self, values):
async def _handle_found_values(self, values):
"""
We got some values! Exciting. But let's make sure
they're all the same or freak out a little bit. Also,
make sure we tell the nearest node that *didn't* have
the value to store it.
"""
valueCounts = Counter(values)
if len(valueCounts) != 1:
value_counts = Counter(values)
if len(value_counts) != 1:
log.warning("Got multiple values for key %i: %s",
self.node.long_id, str(values))
value = valueCounts.most_common(1)[0][0]
value = value_counts.most_common(1)[0][0]
peerToSaveTo = self.nearestWithoutValue.popleft()
if peerToSaveTo is not None:
await self.protocol.callStore(peerToSaveTo, self.node.id, value)
peer = self.nearest_without_value.popleft()
if peer:
await self.protocol.call_store(peer, self.node.id, value)
return value
@ -128,9 +133,9 @@ class NodeSpiderCrawl(SpiderCrawl):
"""
Find the closest nodes.
"""
return await self._find(self.protocol.callFindNode)
return await self._find(self.protocol.call_find_node)
async def _nodesFound(self, responses):
async def _nodes_found(self, responses):
"""
Handle the result of an iteration in _find.
"""
@ -140,10 +145,10 @@ class NodeSpiderCrawl(SpiderCrawl):
if not response.happened():
toremove.append(peerid)
else:
self.nearest.push(response.getNodeList())
self.nearest.push(response.get_node_list())
self.nearest.remove(toremove)
if self.nearest.allBeenContacted():
if self.nearest.have_contacted_all():
return list(self.nearest)
return await self.find()
@ -166,13 +171,13 @@ class RPCFindResponse:
"""
return self.response[0]
def hasValue(self):
def has_value(self):
return isinstance(self.response[1], dict)
def getValue(self):
def get_value(self):
return self.response[1]['value']
def getNodeList(self):
def get_node_list(self):
"""
Get the node list in the response. If there's no value, this should
be set.
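For reference, the tuple RPCFindResponse wraps is rpcudp's (success, payload) pair; the three shapes the crawler distinguishes look like this (payload values illustrative):
# (True,  {'value': b'...'})            -> has_value() is True (rpc_find_value hit)
# (True,  [(node_id, ip, port), ...])   -> get_node_list() parses this
# (False, None)                         -> happened() is False; the peer timed out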

View File

@ -6,16 +6,17 @@ import pickle
import asyncio
import logging
from .kademlia.protocol import KademliaProtocol
from .kademlia.utils import digest
from .kademlia.storage import ForgetfulStorage
from .kademlia.node import Node
from .kademlia.crawling import ValueSpiderCrawl
from .kademlia.crawling import NodeSpiderCrawl
from .protocol import KademliaProtocol
from .utils import digest
from .storage import ForgetfulStorage
from .node import Node
from .crawling import ValueSpiderCrawl
from .crawling import NodeSpiderCrawl
log = logging.getLogger(__name__)
log = logging.getLogger(__name__) # pylint: disable=invalid-name
# pylint: disable=too-many-instance-attributes
class Server:
"""
High level view of a node instance. This is the object that should be
@ -57,7 +58,7 @@ class Server:
def _create_protocol(self):
return self.protocol_class(self.node, self.storage, self.ksize)
def listen(self, port, interface='0.0.0.0'):
async def listen(self, port, interface='0.0.0.0'):
"""
Start listening on the given port.
@ -68,7 +69,7 @@ class Server:
local_addr=(interface, port))
log.info("Node %i listening on %s:%i",
self.node.long_id, interface, port)
self.transport, self.protocol = loop.run_until_complete(listen)
self.transport, self.protocol = await listen
# finally, schedule refreshing table
self.refresh_table()
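Since listen is now a coroutine, callers must await it (or drive it with the event loop) instead of calling it synchronously; a minimal usage sketch:
import asyncio
from libp2p.kademlia.network import Server

async def main():
    server = Server()
    await server.listen(8468)  # was a blocking call before this change

asyncio.get_event_loop().run_until_complete(main())
# keep the loop running afterwards so the node stays reachable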
@ -83,22 +84,22 @@ class Server:
Refresh buckets that haven't had any lookups in the last hour
(per section 2.3 of the paper).
"""
ds = []
for node_id in self.protocol.getRefreshIDs():
results = []
for node_id in self.protocol.get_refresh_ids():
node = Node(node_id)
nearest = self.protocol.router.findNeighbors(node, self.alpha)
nearest = self.protocol.router.find_neighbors(node, self.alpha)
spider = NodeSpiderCrawl(self.protocol, node, nearest,
self.ksize, self.alpha)
ds.append(spider.find())
results.append(spider.find())
# do our crawling
await asyncio.gather(*ds)
await asyncio.gather(*results)
# now republish keys older than one hour
for dkey, value in self.storage.iteritemsOlderThan(3600):
for dkey, value in self.storage.iter_older_than(3600):
await self.set_digest(dkey, value)
def bootstrappableNeighbors(self):
def bootstrappable_neighbors(self):
"""
Get a :class:`list` of (ip, port) :class:`tuple` pairs suitable for
use as an argument to the bootstrap method.
@ -108,7 +109,7 @@ class Server:
storing them if this server is going down for a while. When it comes
back up, the list of nodes can be used to bootstrap.
"""
neighbors = self.protocol.router.findNeighbors(self.node)
neighbors = self.protocol.router.find_neighbors(self.node)
return [tuple(n)[-2:] for n in neighbors]
async def bootstrap(self, addrs):
@ -145,8 +146,8 @@ class Server:
if self.storage.get(dkey) is not None:
return self.storage.get(dkey)
node = Node(dkey)
nearest = self.protocol.router.findNeighbors(node)
if len(nearest) == 0:
nearest = self.protocol.router.find_neighbors(node)
if not nearest:
log.warning("There are no known neighbors to get key %s", key)
return None
spider = ValueSpiderCrawl(self.protocol, node, nearest,
@ -172,8 +173,8 @@ class Server:
"""
node = Node(dkey)
nearest = self.protocol.router.findNeighbors(node)
if len(nearest) == 0:
nearest = self.protocol.router.find_neighbors(node)
if not nearest:
log.warning("There are no known neighbors to set key %s",
dkey.hex())
return False
@ -184,14 +185,14 @@ class Server:
log.info("setting '%s' on %s", dkey.hex(), list(map(str, nodes)))
# if this node is close too, then store here as well
biggest = max([n.distanceTo(node) for n in nodes])
if self.node.distanceTo(node) < biggest:
biggest = max([n.distance_to(node) for n in nodes])
if self.node.distance_to(node) < biggest:
self.storage[dkey] = value
ds = [self.protocol.callStore(n, dkey, value) for n in nodes]
results = [self.protocol.call_store(n, dkey, value) for n in nodes]
# return true only if at least one store call succeeded
return any(await asyncio.gather(*ds))
return any(await asyncio.gather(*results))
def saveState(self, fname):
def save_state(self, fname):
"""
Save the state of this node (the alpha/ksize/id/immediate neighbors)
to a cache file with the given fname.
@ -201,29 +202,29 @@ class Server:
'ksize': self.ksize,
'alpha': self.alpha,
'id': self.node.id,
'neighbors': self.bootstrappableNeighbors()
'neighbors': self.bootstrappable_neighbors()
}
if len(data['neighbors']) == 0:
if not data['neighbors']:
log.warning("No known neighbors, so not writing to cache.")
return
with open(fname, 'wb') as f:
pickle.dump(data, f)
with open(fname, 'wb') as file:
pickle.dump(data, file)
@classmethod
def loadState(self, fname):
def load_state(cls, fname):
"""
Load the state of this node (the alpha/ksize/id/immediate neighbors)
from a cache file with the given fname.
"""
log.info("Loading state from %s", fname)
with open(fname, 'rb') as f:
data = pickle.load(f)
s = Server(data['ksize'], data['alpha'], data['id'])
if len(data['neighbors']) > 0:
s.bootstrap(data['neighbors'])
return s
with open(fname, 'rb') as file:
data = pickle.load(file)
svr = Server(data['ksize'], data['alpha'], data['id'])
if data['neighbors']:
# bootstrap is a coroutine; schedule it rather than discarding it
asyncio.ensure_future(svr.bootstrap(data['neighbors']))
return svr
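After the rename, the state round-trip reads (load_state is a classmethod now that its first parameter is cls):
server.save_state('node.state')
restored = Server.load_state('node.state')  # schedules a bootstrap from the saved neighbors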
def saveStateRegularly(self, fname, frequency=600):
def save_state_regularly(self, fname, frequency=600):
"""
Save the state of node with a given regularity to the given
filename.
@ -233,10 +234,10 @@ class Server:
frequency: Frequency in seconds that the state should be saved.
By default, 10 minutes.
"""
self.saveState(fname)
self.save_state(fname)
loop = asyncio.get_event_loop()
self.save_state_loop = loop.call_later(frequency,
self.saveStateRegularly,
self.save_state_regularly,
fname,
frequency)
@ -246,13 +247,11 @@ def check_dht_value_type(value):
Checks to see if the type of the value is a valid type for
placing in the dht.
"""
typeset = set(
[
int,
float,
bool,
str,
bytes,
]
)
return type(value) in typeset
typeset = [
int,
float,
bool,
str,
bytes
]
return type(value) in typeset # pylint: disable=unidiomatic-typecheck
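The exact type(...) membership test (hence the unidiomatic-typecheck pragma) is stricter than isinstance: a subclass instance is rejected even though isinstance would accept it, presumably because subclasses may not round-trip cleanly through serialization. A small illustration, assuming check_dht_value_type as defined above:
class FancyStr(str):
    pass

check_dht_value_type("plain")       # True: exact str
check_dht_value_type(FancyStr(""))  # False: isinstance would have said True
check_dht_value_type(True)          # True: bool is listed explicitly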

View File

@ -4,15 +4,15 @@ import heapq
class Node:
def __init__(self, node_id, ip=None, port=None):
self.id = node_id
self.ip = ip
self.id = node_id # pylint: disable=invalid-name
self.ip = ip # pylint: disable=invalid-name
self.port = port
self.long_id = int(node_id.hex(), 16)
def sameHomeAs(self, node):
def same_home_as(self, node):
return self.ip == node.ip and self.port == node.port
def distanceTo(self, node):
def distance_to(self, node):
"""
Get the distance between this node and another.
"""
@ -47,7 +47,7 @@ class NodeHeap:
self.contacted = set()
self.maxsize = maxsize
def remove(self, peerIDs):
def remove(self, peers):
"""
Remove a list of peer ids from this heap. Note that while this
heap retains a constant visible size (based on the iterator), it's
@ -55,34 +55,32 @@ class NodeHeap:
removal of nodes may not change the visible size as previously added
nodes suddenly become visible.
"""
peerIDs = set(peerIDs)
if len(peerIDs) == 0:
peers = set(peers)
if not peers:
return
nheap = []
for distance, node in self.heap:
if node.id not in peerIDs:
if node.id not in peers:
heapq.heappush(nheap, (distance, node))
self.heap = nheap
def getNodeById(self, node_id):
def get_node(self, node_id):
for _, node in self.heap:
if node.id == node_id:
return node
return None
def allBeenContacted(self):
return len(self.getUncontacted()) == 0
def have_contacted_all(self):
return len(self.get_uncontacted()) == 0
def getIDs(self):
def get_ids(self):
return [n.id for n in self]
def markContacted(self, node):
def mark_contacted(self, node):
self.contacted.add(node.id)
def popleft(self):
if len(self) > 0:
return heapq.heappop(self.heap)[1]
return None
return heapq.heappop(self.heap)[1] if self else None
def push(self, nodes):
"""
@ -95,7 +93,7 @@ class NodeHeap:
for node in nodes:
if node not in self:
distance = self.node.distanceTo(node)
distance = self.node.distance_to(node)
heapq.heappush(self.heap, (distance, node))
def __len__(self):
@ -106,10 +104,10 @@ class NodeHeap:
return iter(map(itemgetter(1), nodes))
def __contains__(self, node):
for _, n in self.heap:
if node.id == n.id:
for _, other in self.heap:
if node.id == other.id:
return True
return False
def getUncontacted(self):
def get_uncontacted(self):
return [n for n in self if n.id not in self.contacted]

View File

@ -4,41 +4,56 @@ import logging
from rpcudp.protocol import RPCProtocol
from .kademlia.node import Node
from .kademlia.routing import RoutingTable
from .kademlia.utils import digest
from .node import Node
from .routing import RoutingTable
from .utils import digest
log = logging.getLogger(__name__)
log = logging.getLogger(__name__) # pylint: disable=invalid-name
class KademliaProtocol(RPCProtocol):
def __init__(self, sourceNode, storage, ksize):
"""
There are four main RPCs in the Kademlia protocol
PING, STORE, FIND_NODE, FIND_VALUE
PING probes if a node is still online
STORE instructs a node to store (key, value)
FIND_NODE takes a 160-bit ID and gets back
(ip, udp_port, node_id) for k closest nodes to target
FIND_VALUE behaves like FIND_NODE unless a value is stored
"""
def __init__(self, source_node, storage, ksize):
RPCProtocol.__init__(self)
self.router = RoutingTable(self, ksize, sourceNode)
self.router = RoutingTable(self, ksize, source_node)
self.storage = storage
self.sourceNode = sourceNode
self.source_node = source_node
def getRefreshIDs(self):
def get_refresh_ids(self):
"""
Get ids to search for to keep old buckets up to date.
"""
ids = []
for bucket in self.router.getLonelyBuckets():
for bucket in self.router.lonely_buckets():
rid = random.randint(*bucket.range).to_bytes(20, byteorder='big')
ids.append(rid)
return ids
def rpc_stun(self, sender):
def rpc_add_provider(self, sender, nodeid, key):
pass
def rpc_get_providers(self, sender, nodeid, key):
pass
def rpc_stun(self, sender): # pylint: disable=no-self-use
return sender
def rpc_ping(self, sender, nodeid):
source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source)
return self.sourceNode.id
self.welcome_if_new(source)
return self.source_node.id
def rpc_store(self, sender, nodeid, key, value):
source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source)
self.welcome_if_new(source)
log.debug("got a store request from %s, storing '%s'='%s'",
sender, key.hex(), value)
self.storage[key] = value
@ -48,42 +63,42 @@ class KademliaProtocol(RPCProtocol):
log.info("finding neighbors of %i in local table",
int(nodeid.hex(), 16))
source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source)
self.welcome_if_new(source)
node = Node(key)
neighbors = self.router.findNeighbors(node, exclude=source)
neighbors = self.router.find_neighbors(node, exclude=source)
return list(map(tuple, neighbors))
def rpc_find_value(self, sender, nodeid, key):
source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source)
self.welcome_if_new(source)
value = self.storage.get(key, None)
if value is None:
return self.rpc_find_node(sender, nodeid, key)
return {'value': value}
async def callFindNode(self, nodeToAsk, nodeToFind):
address = (nodeToAsk.ip, nodeToAsk.port)
result = await self.find_node(address, self.sourceNode.id,
nodeToFind.id)
return self.handleCallResponse(result, nodeToAsk)
async def call_find_node(self, node_to_ask, node_to_find):
address = (node_to_ask.ip, node_to_ask.port)
result = await self.find_node(address, self.source_node.id,
node_to_find.id)
return self.handle_call_response(result, node_to_ask)
async def callFindValue(self, nodeToAsk, nodeToFind):
address = (nodeToAsk.ip, nodeToAsk.port)
result = await self.find_value(address, self.sourceNode.id,
nodeToFind.id)
return self.handleCallResponse(result, nodeToAsk)
async def call_find_value(self, node_to_ask, node_to_find):
address = (node_to_ask.ip, node_to_ask.port)
result = await self.find_value(address, self.source_node.id,
node_to_find.id)
return self.handle_call_response(result, node_to_ask)
async def callPing(self, nodeToAsk):
address = (nodeToAsk.ip, nodeToAsk.port)
result = await self.ping(address, self.sourceNode.id)
return self.handleCallResponse(result, nodeToAsk)
async def call_ping(self, node_to_ask):
address = (node_to_ask.ip, node_to_ask.port)
result = await self.ping(address, self.source_node.id)
return self.handle_call_response(result, node_to_ask)
async def callStore(self, nodeToAsk, key, value):
address = (nodeToAsk.ip, nodeToAsk.port)
result = await self.store(address, self.sourceNode.id, key, value)
return self.handleCallResponse(result, nodeToAsk)
async def call_store(self, node_to_ask, key, value):
address = (node_to_ask.ip, node_to_ask.port)
result = await self.store(address, self.source_node.id, key, value)
return self.handle_call_response(result, node_to_ask)
def welcomeIfNewNode(self, node):
def welcome_if_new(self, node):
"""
Given a new node, send it all the keys/values it should be storing,
then add it to the routing table.
@ -97,32 +112,32 @@ class KademliaProtocol(RPCProtocol):
is closer than the closest in that list, then store the key/value
on the new node (per section 2.5 of the paper)
"""
if not self.router.isNewNode(node):
if not self.router.is_new_node(node):
return
log.info("never seen %s before, adding to router", node)
for key, value in self.storage.items():
for key, value in self.storage:
keynode = Node(digest(key))
neighbors = self.router.findNeighbors(keynode)
if len(neighbors) > 0:
last = neighbors[-1].distanceTo(keynode)
newNodeClose = node.distanceTo(keynode) < last
first = neighbors[0].distanceTo(keynode)
thisNodeClosest = self.sourceNode.distanceTo(keynode) < first
if len(neighbors) == 0 or (newNodeClose and thisNodeClosest):
asyncio.ensure_future(self.callStore(node, key, value))
self.router.addContact(node)
neighbors = self.router.find_neighbors(keynode)
if neighbors:
last = neighbors[-1].distance_to(keynode)
new_node_close = node.distance_to(keynode) < last
first = neighbors[0].distance_to(keynode)
this_closest = self.source_node.distance_to(keynode) < first
if not neighbors or (new_node_close and this_closest):
asyncio.ensure_future(self.call_store(node, key, value))
self.router.add_contact(node)
def handleCallResponse(self, result, node):
def handle_call_response(self, result, node):
"""
If we get a response, add the node to the routing table. If
we get no response, make sure it's removed from the routing table.
"""
if not result[0]:
log.warning("no response from %s, removing from router", node)
self.router.removeContact(node)
self.router.remove_contact(node)
return result
log.info("got successful response from %s", node)
self.welcomeIfNewNode(node)
self.welcome_if_new(node)
return result
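All the call_*/rpc_* pairs above lean on rpcudp's convention: awaiting self.some_name(address, ...) on an RPCProtocol sends a datagram that the remote side dispatches to its rpc_some_name handler, and the awaited result is the (success, payload) pair that handle_call_response receives. A stripped-down sketch of that mechanism, independent of the DHT logic (port number arbitrary):
import asyncio
from rpcudp.protocol import RPCProtocol

class Echo(RPCProtocol):
    def rpc_echo(self, sender, payload):  # runs on the receiving node
        return payload

async def demo():
    loop = asyncio.get_event_loop()
    _, proto = await loop.create_datagram_endpoint(
        Echo, local_addr=('127.0.0.1', 8470))
    # attribute access becomes a remote call; result is (success, value)
    ok, value = await proto.echo(('127.0.0.1', 8470), b'hi')
    assert ok and value == b'hi'

asyncio.get_event_loop().run_until_complete(demo())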

View File

@ -4,22 +4,28 @@ import operator
import asyncio
from collections import OrderedDict
from .kademlia.utils import OrderedSet, sharedPrefix, bytesToBitString
from .utils import OrderedSet, shared_prefix, bytes_to_bit_string
class KBucket:
"""
each node keeps a list of (ip, udp_port, node_id) triples
for nodes of distance between 2^i and 2^(i+1); this list
is the k-bucket. each k-bucket implements a last-seen
eviction policy, except that live nodes are never removed
"""
def __init__(self, rangeLower, rangeUpper, ksize):
self.range = (rangeLower, rangeUpper)
self.nodes = OrderedDict()
self.replacementNodes = OrderedSet()
self.touchLastUpdated()
self.replacement_nodes = OrderedSet()
self.touch_last_updated()
self.ksize = ksize
def touchLastUpdated(self):
self.lastUpdated = time.monotonic()
def touch_last_updated(self):
self.last_updated = time.monotonic()
def getNodes(self):
def get_nodes(self):
return list(self.nodes.values())
def split(self):
@ -31,23 +37,23 @@ class KBucket:
bucket.nodes[node.id] = node
return (one, two)
def removeNode(self, node):
def remove_node(self, node):
if node.id not in self.nodes:
return
# delete node, and see if we can add a replacement
del self.nodes[node.id]
if len(self.replacementNodes) > 0:
newnode = self.replacementNodes.pop()
if self.replacement_nodes:
newnode = self.replacement_nodes.pop()
self.nodes[newnode.id] = newnode
def hasInRange(self, node):
def has_in_range(self, node):
return self.range[0] <= node.long_id <= self.range[1]
def isNewNode(self, node):
def is_new_node(self, node):
return node.id not in self.nodes
def addNode(self, node):
def add_node(self, node):
"""
Add a C{Node} to the C{KBucket}. Return True if successful,
False if the bucket is full.
@ -61,14 +67,14 @@ class KBucket:
elif len(self) < self.ksize:
self.nodes[node.id] = node
else:
self.replacementNodes.push(node)
self.replacement_nodes.push(node)
return False
return True
def depth(self):
vals = self.nodes.values()
sp = sharedPrefix([bytesToBitString(n.id) for n in vals])
return len(sp)
sprefix = shared_prefix([bytes_to_bit_string(n.id) for n in vals])
return len(sprefix)
def head(self):
return list(self.nodes.values())[0]
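To make the bucket semantics concrete, a toy illustration using only what this hunk defines (SimpleNamespace stands in for a real Node; the import path assumes the post-move layout):
from types import SimpleNamespace
from libp2p.kademlia.routing import KBucket

bucket = KBucket(0, 2 ** 160, ksize=20)   # one bucket spans the whole id space
node = SimpleNamespace(long_id=12345)
assert bucket.has_in_range(node)
left, right = bucket.split()              # halves the range into two buckets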
@ -82,11 +88,11 @@ class KBucket:
class TableTraverser:
def __init__(self, table, startNode):
index = table.getBucketFor(startNode)
table.buckets[index].touchLastUpdated()
self.currentNodes = table.buckets[index].getNodes()
self.leftBuckets = table.buckets[:index]
self.rightBuckets = table.buckets[(index + 1):]
index = table.get_bucket_for(startNode)
table.buckets[index].touch_last_updated()
self.current_nodes = table.buckets[index].get_nodes()
self.left_buckets = table.buckets[:index]
self.right_buckets = table.buckets[(index + 1):]
self.left = True
def __iter__(self):
@ -96,16 +102,16 @@ class TableTraverser:
"""
Pop an item from the left subtree, then right, then left, etc.
"""
if len(self.currentNodes) > 0:
return self.currentNodes.pop()
if self.current_nodes:
return self.current_nodes.pop()
if self.left and len(self.leftBuckets) > 0:
self.currentNodes = self.leftBuckets.pop().getNodes()
if self.left and self.left_buckets:
self.current_nodes = self.left_buckets.pop().get_nodes()
self.left = False
return next(self)
if len(self.rightBuckets) > 0:
self.currentNodes = self.rightBuckets.pop(0).getNodes()
if self.right_buckets:
self.current_nodes = self.right_buckets.pop(0).get_nodes()
self.left = True
return next(self)
@ -127,58 +133,60 @@ class RoutingTable:
def flush(self):
self.buckets = [KBucket(0, 2 ** 160, self.ksize)]
def splitBucket(self, index):
def split_bucket(self, index):
one, two = self.buckets[index].split()
self.buckets[index] = one
self.buckets.insert(index + 1, two)
def getLonelyBuckets(self):
def lonely_buckets(self):
"""
Get all of the buckets that haven't been updated in over
an hour.
"""
hrago = time.monotonic() - 3600
return [b for b in self.buckets if b.lastUpdated < hrago]
return [b for b in self.buckets if b.last_updated < hrago]
def removeContact(self, node):
index = self.getBucketFor(node)
self.buckets[index].removeNode(node)
def remove_contact(self, node):
index = self.get_bucket_for(node)
self.buckets[index].remove_node(node)
def isNewNode(self, node):
index = self.getBucketFor(node)
return self.buckets[index].isNewNode(node)
def is_new_node(self, node):
index = self.get_bucket_for(node)
return self.buckets[index].is_new_node(node)
def addContact(self, node):
index = self.getBucketFor(node)
def add_contact(self, node):
index = self.get_bucket_for(node)
bucket = self.buckets[index]
# this will succeed unless the bucket is full
if bucket.addNode(node):
if bucket.add_node(node):
return
# Per section 4.2 of paper, split if the bucket has the node
# in its range or if the depth is not congruent to 0 mod 5
if bucket.hasInRange(self.node) or bucket.depth() % 5 != 0:
self.splitBucket(index)
self.addContact(node)
if bucket.has_in_range(self.node) or bucket.depth() % 5 != 0:
self.split_bucket(index)
self.add_contact(node)
else:
asyncio.ensure_future(self.protocol.callPing(bucket.head()))
asyncio.ensure_future(self.protocol.call_ping(bucket.head()))
def getBucketFor(self, node):
def get_bucket_for(self, node):
"""
Get the index of the bucket that the given node would fall into.
"""
for index, bucket in enumerate(self.buckets):
if node.long_id < bucket.range[1]:
return index
# we should never be here, but make linter happy
return None
def findNeighbors(self, node, k=None, exclude=None):
def find_neighbors(self, node, k=None, exclude=None):
k = k or self.ksize
nodes = []
for neighbor in TableTraverser(self, node):
notexcluded = exclude is None or not neighbor.sameHomeAs(exclude)
notexcluded = exclude is None or not neighbor.same_home_as(exclude)
if neighbor.id != node.id and notexcluded:
heapq.heappush(nodes, (node.distanceTo(neighbor), neighbor))
heapq.heappush(nodes, (node.distance_to(neighbor), neighbor))
if len(nodes) == k:
break

libp2p/kademlia/rpc.proto Normal file
View File

@ -0,0 +1,78 @@
syntax = "proto3";

// Record represents a dht record that contains a value
// for a key value pair
message Record {
// The key that references this record
bytes key = 1;
// The actual value this record is storing
bytes value = 2;
// Note: These fields were removed from the Record message
// hash of the authors public key
//optional string author = 3;
// A PKI signature for the key+value+author
//optional bytes signature = 4;
// Time the record was received, set by receiver
string timeReceived = 5;
};
message Message {
enum MessageType {
PUT_VALUE = 0;
GET_VALUE = 1;
ADD_PROVIDER = 2;
GET_PROVIDERS = 3;
FIND_NODE = 4;
PING = 5;
}
enum ConnectionType {
// sender does not have a connection to peer, and no extra information (default)
NOT_CONNECTED = 0;
// sender has a live connection to peer
CONNECTED = 1;
// sender recently connected to peer
CAN_CONNECT = 2;
// sender recently tried to connect to peer repeatedly but failed to connect
// ("try" here is loose, but this should signal "made strong effort, failed")
CANNOT_CONNECT = 3;
}
message Peer {
// ID of a given peer.
bytes id = 1;
// multiaddrs for a given peer
repeated bytes addrs = 2;
// used to signal the sender's connection capabilities to the peer
ConnectionType connection = 3;
}
// defines what type of message it is.
MessageType type = 1;
// defines what coral cluster level this query/response belongs to.
// in case we want to implement coral's cluster rings in the future.
int32 clusterLevelRaw = 10; // NOT USED
// Used to specify the key associated with this message.
// PUT_VALUE, GET_VALUE, ADD_PROVIDER, GET_PROVIDERS
bytes key = 2;
// Used to return a value
// PUT_VALUE, GET_VALUE
Record record = 3;
// Used to return peers closer to a key in a query
// GET_VALUE, GET_PROVIDERS, FIND_NODE
repeated Peer closerPeers = 8;
// Used to return Providers
// GET_VALUE, ADD_PROVIDER, GET_PROVIDERS
repeated Peer providerPeers = 9;
}
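Nothing in the Python code of this PR references these messages yet; the grpcio-tools dependency added to setup.py below suggests bindings would be generated with protoc, along these lines (paths assumed):
from grpc_tools import protoc

protoc.main([
    'protoc',
    '-Ilibp2p/kademlia',
    '--python_out=libp2p/kademlia',
    'libp2p/kademlia/rpc.proto',
])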

View File

@ -2,44 +2,45 @@ import time
from itertools import takewhile
import operator
from collections import OrderedDict
from abc import abstractmethod, ABC
class IStorage:
class IStorage(ABC):
"""
Local storage for this node.
IStorage implementations of get must return the same type as put in by set
"""
@abstractmethod
def __setitem__(self, key, value):
"""
Set a key to the given value.
"""
raise NotImplementedError
@abstractmethod
def __getitem__(self, key):
"""
Get the given key. If item doesn't exist, raises C{KeyError}
"""
raise NotImplementedError
@abstractmethod
def get(self, key, default=None):
"""
Get given key. If not found, return default.
"""
raise NotImplementedError
def iteritemsOlderThan(self, secondsOld):
@abstractmethod
def iter_older_than(self, seconds_old):
"""
Return an iterator over (key, value) tuples for items older
than the given secondsOld.
than the given seconds_old.
"""
raise NotImplementedError
@abstractmethod
def __iter__(self):
"""
Get the iterator for this storage, should yield tuple of (key, value)
"""
raise NotImplementedError
class ForgetfulStorage(IStorage):
@ -57,7 +58,7 @@ class ForgetfulStorage(IStorage):
self.cull()
def cull(self):
for _, _ in self.iteritemsOlderThan(self.ttl):
for _, _ in self.iter_older_than(self.ttl):
self.data.popitem(last=False)
def get(self, key, default=None):
@ -70,27 +71,23 @@ class ForgetfulStorage(IStorage):
self.cull()
return self.data[key][1]
def __iter__(self):
self.cull()
return iter(self.data)
def __repr__(self):
self.cull()
return repr(self.data)
def iteritemsOlderThan(self, secondsOld):
minBirthday = time.monotonic() - secondsOld
zipped = self._tripleIterable()
matches = takewhile(lambda r: minBirthday >= r[1], zipped)
def iter_older_than(self, seconds_old):
min_birthday = time.monotonic() - seconds_old
zipped = self._triple_iter()
matches = takewhile(lambda r: min_birthday >= r[1], zipped)
return list(map(operator.itemgetter(0, 2), matches))
def _tripleIterable(self):
def _triple_iter(self):
ikeys = self.data.keys()
ibirthday = map(operator.itemgetter(0), self.data.values())
ivalues = map(operator.itemgetter(1), self.data.values())
return zip(ikeys, ibirthday, ivalues)
def items(self):
def __iter__(self):
self.cull()
ikeys = self.data.keys()
ivalues = map(operator.itemgetter(1), self.data.values())

View File

@ -6,16 +6,16 @@ import operator
import asyncio
async def gather_dict(d):
cors = list(d.values())
async def gather_dict(dic):
cors = list(dic.values())
results = await asyncio.gather(*cors)
return dict(zip(d.keys(), results))
return dict(zip(dic.keys(), results))
def digest(s):
if not isinstance(s, bytes):
s = str(s).encode('utf8')
return hashlib.sha1(s).digest()
def digest(string):
if not isinstance(string, bytes):
string = str(string).encode('utf8')
return hashlib.sha1(string).digest()
class OrderedSet(list):
@ -34,7 +34,7 @@ class OrderedSet(list):
self.append(thing)
def sharedPrefix(args):
def shared_prefix(args):
"""
Find the shared prefix between the strings.
@ -52,6 +52,6 @@ def sharedPrefix(args):
return args[0][:i]
def bytesToBitString(bites):
def bytes_to_bit_string(bites):
bits = [bin(bite)[2:].rjust(8, '0') for bite in bites]
return "".join(bits)

View File

@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
# pylint: disable=too-few-public-methods
class IContentRouting(ABC):
@abstractmethod
def provide(self, cid, announce=True):
"""
Provide adds the given cid to the content routing system. If announce is True,
it also announces it, otherwise it is just kept in the local
accounting of which objects are being provided.
"""
@abstractmethod
def find_provider_iter(self, cid, count):
"""
Search for peers who are able to provide a given key
returns an iterator of peer.PeerInfo
"""
class IPeerRouting(ABC):
@abstractmethod
def find_peer(self, peer_id):
"""
Find specific Peer
FindPeer searches for a peer with given peer_id, returns a peer.PeerInfo
with relevant addresses.
"""

View File

View File

@ -0,0 +1,21 @@
from libp2p.routing.interfaces import IContentRouting
class KadmeliaContentRouter(IContentRouting):
def provide(self, cid, announce=True):
"""
Provide adds the given cid to the content routing system. If announce is True,
it also announces it, otherwise it is just kept in the local
accounting of which objects are being provided.
"""
# the DHT finds the closest peers to `key` using the `FIND_NODE` RPC
# then sends an `ADD_PROVIDER` RPC with its own `PeerInfo` to each of these peers.
pass
def find_provider_iter(self, cid, count):
"""
Search for peers who are able to provide a given key
returns an iterator of peer.PeerInfo
"""
pass
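A hypothetical sketch of the flow that comment describes, written as the method might look once wired to a kademlia Server (as KadmeliaPeerRouter below is); everything here beyond find_neighbors, Node, and digest is an assumption, not part of this PR:
from libp2p.kademlia.node import Node
from libp2p.kademlia.utils import digest

async def provide(self, cid, announce=True):
    # assumes self.server was injected at construction time (hypothetical)
    if not announce:
        return
    key = Node(digest(cid))
    # FIND_NODE phase: locate the closest known peers to the key
    for peer in self.server.protocol.router.find_neighbors(key):
        # rpcudp would route this to the remote rpc_add_provider stub above
        await self.server.protocol.add_provider(
            (peer.ip, peer.port), self.server.node.id, cid)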

View File

@ -0,0 +1,31 @@
from libp2p.routing.interfaces import IPeerRouting
from libp2p.kademlia.utils import digest
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerdata import PeerData
class KadmeliaPeerRouter(IPeerRouting):
# pylint: disable=too-few-public-methods
def __init__(self, dht_server):
self.server = dht_server
async def find_peer(self, peer_id):
"""
Find specific Peer
FindPeer searches for a peer with given peer_id, returns a peer.PeerInfo
with relevant addresses.
"""
value = await self.server.get(peer_id)  # Server.get is a coroutine
return decode_peerinfo(value)
def decode_peerinfo(encoded):
if isinstance(encoded, bytes):
encoded = encoded.decode()
lines = encoded.splitlines()
peer_id = lines[0]
addrs = lines[1:]
peer_data = PeerData()
peer_data.add_addrs(addrs)
return PeerInfo(peer_id, peer_data)  # use the populated PeerData, not the raw lines
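decode_peerinfo implies the stored value was produced by a matching encoder; a hypothetical counterpart (not in this PR) that a peer would use before storing its record with server.set:
def encode_peerinfo(peer_id, addrs):
    # first line: the peer id; each following line: one multiaddr
    return '\n'.join([str(peer_id)] + [str(addr) for addr in addrs])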

View File

@ -22,6 +22,8 @@ setuptools.setup(
"base58",
"pymultihash",
"multiaddr",
"rpcudp",
"umsgpack",
"grpcio",
"grpcio-tools",
"lru-dict>=1.1.6"

View File

View File

@ -0,0 +1,78 @@
import pytest
from libp2p.kademlia.network import Server
@pytest.mark.asyncio
async def test_example():
node_a = Server()
await node_a.listen(5678)
node_b = Server()
await node_b.listen(5679)
# Bootstrap the node by connecting to other known nodes; here we use
# node_a's local address, but any number of (ip, port) pairs for known
# nodes can be given.
await node_b.bootstrap([("127.0.0.1", 5678)])
# set a value for the key "my-key" on the network
value = "my-value"
key = "my-key"
await node_b.set(key, value)
# get the value associated with "my-key" from the network
assert await node_b.get(key) == value
assert await node_a.get(key) == value
@pytest.mark.parametrize("nodes_nr", [(2**i) for i in range(2, 5)])
@pytest.mark.asyncio
async def test_multiple_nodes_bootstrap_set_get(nodes_nr):
node_bootstrap = Server()
await node_bootstrap.listen(3000 + nodes_nr * 2)
nodes = []
for i in range(nodes_nr):
node = Server()
addrs = [("127.0.0.1", 3000 + nodes_nr * 2)]
await node.listen(3001 + i + nodes_nr * 2)
await node.bootstrap(addrs)
nodes.append(node)
for i, node in enumerate(nodes):
# set a value for the key "my-key" on the network
value = "my awesome value %d" % i
key = "set from %d" % i
await node.set(key, value)
for i in range(nodes_nr):
for node in nodes:
value = "my awesome value %d" % i
key = "set from %d" % i
assert await node.get(key) == value
@pytest.mark.parametrize("nodes_nr", [(2**i) for i in range(2, 5)])
@pytest.mark.asyncio
async def test_multiple_nodes_set_bootstrap_get(nodes_nr):
node_bootstrap = Server()
await node_bootstrap.listen(2000 + nodes_nr * 2)
nodes = []
for i in range(nodes_nr):
node = Server()
addrs = [("127.0.0.1", 2000 + nodes_nr * 2)]
await node.listen(2001 + i + nodes_nr * 2)
await node.bootstrap(addrs)
value = "my awesome value %d" % i
key = "set from %d" % i
await node.set(key, value)
nodes.append(node)
for i in range(nodes_nr):
for node in nodes:
value = "my awesome value %d" % i
key = "set from %d" % i
assert await node.get(key) == value