update kademlia lib

Christophe de Carvalho Pereira Martins 2019-01-15 18:41:41 +01:00 committed by Christophe de Carvalho
parent 57077cd3b4
commit c24a279d2d
8 changed files with 233 additions and 235 deletions

File: __init__.py

@@ -2,4 +2,4 @@
 Kademlia is a Python implementation of the Kademlia protocol which
 utilizes the asyncio library.
 """
-__version__ = "1.1"
+__version__ = "2.0"

File: crawling.py

@@ -1,16 +1,19 @@
 from collections import Counter
 import logging
-from .kademlia.node import Node, NodeHeap
-from .kademlia.utils import gather_dict
+from .node import Node, NodeHeap
+from .utils import gather_dict
-log = logging.getLogger(__name__)
+log = logging.getLogger(__name__)  # pylint: disable=invalid-name
+# pylint: disable=too-few-public-methods
 class SpiderCrawl:
     """
     Crawl the network and look for given 160-bit keys.
     """
     def __init__(self, protocol, node, peers, ksize, alpha):
         """
         Create a new C{SpiderCrawl}er.
@@ -29,7 +32,7 @@ class SpiderCrawl:
         self.alpha = alpha
         self.node = node
         self.nearest = NodeHeap(self.node, self.ksize)
-        self.lastIDsCrawled = []
+        self.last_ids_crawled = []
         log.info("creating spider with peers: %s", peers)
         self.nearest.push(peers)
@@ -38,7 +41,7 @@ class SpiderCrawl:
         Get either a value or list of nodes.
         Args:
-            rpcmethod: The protocol's callfindValue or callFindNode.
+            rpcmethod: The protocol's callfindValue or call_find_node.
         The process:
           1. calls find_* to current ALPHA nearest not already queried nodes,
@@ -51,18 +54,18 @@
         """
         log.info("crawling network with nearest: %s", str(tuple(self.nearest)))
         count = self.alpha
-        if self.nearest.getIDs() == self.lastIDsCrawled:
+        if self.nearest.get_ids() == self.last_ids_crawled:
             count = len(self.nearest)
-        self.lastIDsCrawled = self.nearest.getIDs()
-        ds = {}
-        for peer in self.nearest.getUncontacted()[:count]:
-            ds[peer.id] = rpcmethod(peer, self.node)
-            self.nearest.markContacted(peer)
-        found = await gather_dict(ds)
-        return await self._nodesFound(found)
-    async def _nodesFound(self, responses):
+        self.last_ids_crawled = self.nearest.get_ids()
+        dicts = {}
+        for peer in self.nearest.get_uncontacted()[:count]:
+            dicts[peer.id] = rpcmethod(peer, self.node)
+            self.nearest.mark_contacted(peer)
+        found = await gather_dict(dicts)
+        return await self._nodes_found(found)
+    async def _nodes_found(self, responses):
         raise NotImplementedError
@@ -71,55 +74,55 @@ class ValueSpiderCrawl(SpiderCrawl):
         SpiderCrawl.__init__(self, protocol, node, peers, ksize, alpha)
         # keep track of the single nearest node without value - per
         # section 2.3 so we can set the key there if found
-        self.nearestWithoutValue = NodeHeap(self.node, 1)
+        self.nearest_without_value = NodeHeap(self.node, 1)
     async def find(self):
         """
         Find either the closest nodes or the value requested.
         """
-        return await self._find(self.protocol.callFindValue)
-    async def _nodesFound(self, responses):
+        return await self._find(self.protocol.call_find_value)
+    async def _nodes_found(self, responses):
         """
         Handle the result of an iteration in _find.
         """
         toremove = []
-        foundValues = []
+        found_values = []
         for peerid, response in responses.items():
             response = RPCFindResponse(response)
             if not response.happened():
                 toremove.append(peerid)
-            elif response.hasValue():
-                foundValues.append(response.getValue())
+            elif response.has_value():
+                found_values.append(response.get_value())
             else:
-                peer = self.nearest.getNodeById(peerid)
-                self.nearestWithoutValue.push(peer)
-                self.nearest.push(response.getNodeList())
+                peer = self.nearest.get_node(peerid)
+                self.nearest_without_value.push(peer)
+                self.nearest.push(response.get_node_list())
         self.nearest.remove(toremove)
-        if len(foundValues) > 0:
-            return await self._handleFoundValues(foundValues)
-        if self.nearest.allBeenContacted():
+        if found_values:
+            return await self._handle_found_values(found_values)
+        if self.nearest.have_contacted_all():
             # not found!
             return None
         return await self.find()
-    async def _handleFoundValues(self, values):
+    async def _handle_found_values(self, values):
         """
         We got some values!  Exciting.  But let's make sure
         they're all the same or freak out a little bit.  Also,
         make sure we tell the nearest node that *didn't* have
         the value to store it.
         """
-        valueCounts = Counter(values)
-        if len(valueCounts) != 1:
+        value_counts = Counter(values)
+        if len(value_counts) != 1:
             log.warning("Got multiple values for key %i: %s",
                         self.node.long_id, str(values))
-        value = valueCounts.most_common(1)[0][0]
-        peerToSaveTo = self.nearestWithoutValue.popleft()
-        if peerToSaveTo is not None:
-            await self.protocol.callStore(peerToSaveTo, self.node.id, value)
+        value = value_counts.most_common(1)[0][0]
+        peer = self.nearest_without_value.popleft()
+        if peer:
+            await self.protocol.call_store(peer, self.node.id, value)
         return value
@@ -128,9 +131,9 @@ class NodeSpiderCrawl(SpiderCrawl):
         """
         Find the closest nodes.
         """
-        return await self._find(self.protocol.callFindNode)
-    async def _nodesFound(self, responses):
+        return await self._find(self.protocol.call_find_node)
+    async def _nodes_found(self, responses):
         """
         Handle the result of an iteration in _find.
         """
@@ -140,10 +143,10 @@ class NodeSpiderCrawl(SpiderCrawl):
             if not response.happened():
                 toremove.append(peerid)
             else:
-                self.nearest.push(response.getNodeList())
+                self.nearest.push(response.get_node_list())
         self.nearest.remove(toremove)
-        if self.nearest.allBeenContacted():
+        if self.nearest.have_contacted_all():
             return list(self.nearest)
         return await self.find()
@@ -166,13 +169,13 @@ class RPCFindResponse:
         """
         return self.response[0]
-    def hasValue(self):
+    def has_value(self):
         return isinstance(self.response[1], dict)
-    def getValue(self):
+    def get_value(self):
         return self.response[1]['value']
-    def getNodeList(self):
+    def get_node_list(self):
         """
         Get the node list in the response.  If there's no value, this should
         be set.

File: network.py

@@ -6,16 +6,17 @@ import pickle
 import asyncio
 import logging
-from .kademlia.protocol import KademliaProtocol
-from .kademlia.utils import digest
-from .kademlia.storage import ForgetfulStorage
-from .kademlia.node import Node
-from .kademlia.crawling import ValueSpiderCrawl
-from .kademlia.crawling import NodeSpiderCrawl
+from .protocol import KademliaProtocol
+from .utils import digest
+from .storage import ForgetfulStorage
+from .node import Node
+from .crawling import ValueSpiderCrawl
+from .crawling import NodeSpiderCrawl
-log = logging.getLogger(__name__)
+log = logging.getLogger(__name__)  # pylint: disable=invalid-name
+# pylint: disable=too-many-instance-attributes
 class Server:
     """
     High level view of a node instance.  This is the object that should be
@@ -57,7 +58,7 @@ class Server:
     def _create_protocol(self):
         return self.protocol_class(self.node, self.storage, self.ksize)
-    def listen(self, port, interface='0.0.0.0'):
+    async def listen(self, port, interface='0.0.0.0'):
         """
         Start listening on the given port.
@@ -68,7 +69,7 @@ class Server:
                                                local_addr=(interface, port))
         log.info("Node %i listening on %s:%i",
                  self.node.long_id, interface, port)
-        self.transport, self.protocol = loop.run_until_complete(listen)
+        self.transport, self.protocol = await listen
         # finally, schedule refreshing table
         self.refresh_table()
@@ -83,22 +84,22 @@ class Server:
        Refresh buckets that haven't had any lookups in the last hour
        (per section 2.3 of the paper).
        """
-        ds = []
-        for node_id in self.protocol.getRefreshIDs():
+        results = []
+        for node_id in self.protocol.get_refresh_ids():
             node = Node(node_id)
-            nearest = self.protocol.router.findNeighbors(node, self.alpha)
+            nearest = self.protocol.router.find_neighbors(node, self.alpha)
             spider = NodeSpiderCrawl(self.protocol, node, nearest,
                                      self.ksize, self.alpha)
-            ds.append(spider.find())
+            results.append(spider.find())
         # do our crawling
-        await asyncio.gather(*ds)
+        await asyncio.gather(*results)
         # now republish keys older than one hour
-        for dkey, value in self.storage.iteritemsOlderThan(3600):
+        for dkey, value in self.storage.iter_older_than(3600):
             await self.set_digest(dkey, value)
-    def bootstrappableNeighbors(self):
+    def bootstrappable_neighbors(self):
         """
         Get a :class:`list` of (ip, port) :class:`tuple` pairs suitable for
         use as an argument to the bootstrap method.
@@ -108,7 +109,7 @@ class Server:
         storing them if this server is going down for a while.  When it comes
         back up, the list of nodes can be used to bootstrap.
         """
-        neighbors = self.protocol.router.findNeighbors(self.node)
+        neighbors = self.protocol.router.find_neighbors(self.node)
         return [tuple(n)[-2:] for n in neighbors]
     async def bootstrap(self, addrs):
@@ -145,8 +146,8 @@ class Server:
         if self.storage.get(dkey) is not None:
             return self.storage.get(dkey)
         node = Node(dkey)
-        nearest = self.protocol.router.findNeighbors(node)
-        if len(nearest) == 0:
+        nearest = self.protocol.router.find_neighbors(node)
+        if not nearest:
             log.warning("There are no known neighbors to get key %s", key)
             return None
         spider = ValueSpiderCrawl(self.protocol, node, nearest,
@@ -172,8 +173,8 @@ class Server:
         """
         node = Node(dkey)
-        nearest = self.protocol.router.findNeighbors(node)
-        if len(nearest) == 0:
+        nearest = self.protocol.router.find_neighbors(node)
+        if not nearest:
             log.warning("There are no known neighbors to set key %s",
                         dkey.hex())
             return False
@@ -184,14 +185,14 @@ class Server:
         log.info("setting '%s' on %s", dkey.hex(), list(map(str, nodes)))
         # if this node is close too, then store here as well
-        biggest = max([n.distanceTo(node) for n in nodes])
-        if self.node.distanceTo(node) < biggest:
+        biggest = max([n.distance_to(node) for n in nodes])
+        if self.node.distance_to(node) < biggest:
             self.storage[dkey] = value
-        ds = [self.protocol.callStore(n, dkey, value) for n in nodes]
+        results = [self.protocol.call_store(n, dkey, value) for n in nodes]
         # return true only if at least one store call succeeded
-        return any(await asyncio.gather(*ds))
+        return any(await asyncio.gather(*results))
-    def saveState(self, fname):
+    def save_state(self, fname):
         """
         Save the state of this node (the alpha/ksize/id/immediate neighbors)
         to a cache file with the given fname.
@@ -201,29 +202,29 @@ class Server:
             'ksize': self.ksize,
             'alpha': self.alpha,
             'id': self.node.id,
-            'neighbors': self.bootstrappableNeighbors()
+            'neighbors': self.bootstrappable_neighbors()
         }
-        if len(data['neighbors']) == 0:
+        if not data['neighbors']:
             log.warning("No known neighbors, so not writing to cache.")
             return
-        with open(fname, 'wb') as f:
-            pickle.dump(data, f)
+        with open(fname, 'wb') as file:
+            pickle.dump(data, file)
     @classmethod
-    def loadState(self, fname):
+    def load_state(cls, fname):
         """
         Load the state of this node (the alpha/ksize/id/immediate neighbors)
         from a cache file with the given fname.
         """
         log.info("Loading state from %s", fname)
-        with open(fname, 'rb') as f:
-            data = pickle.load(f)
-        s = Server(data['ksize'], data['alpha'], data['id'])
-        if len(data['neighbors']) > 0:
-            s.bootstrap(data['neighbors'])
-        return s
-    def saveStateRegularly(self, fname, frequency=600):
+        with open(fname, 'rb') as file:
+            data = pickle.load(file)
+        svr = Server(data['ksize'], data['alpha'], data['id'])
+        if data['neighbors']:
+            svr.bootstrap(data['neighbors'])
+        return svr
+    def save_state_regularly(self, fname, frequency=600):
         """
         Save the state of node with a given regularity to the given
         filename.
@@ -233,10 +234,10 @@ class Server:
             frequency: Frequency in seconds that the state should be saved.
                        By default, 10 minutes.
         """
-        self.saveState(fname)
+        self.save_state(fname)
         loop = asyncio.get_event_loop()
         self.save_state_loop = loop.call_later(frequency,
-                                               self.saveStateRegularly,
+                                               self.save_state_regularly,
                                                fname,
                                                frequency)
@@ -246,13 +247,11 @@ def check_dht_value_type(value):
     Checks to see if the type of the value is a valid type for
     placing in the dht.
     """
-    typeset = set(
-        [
-            int,
-            float,
-            bool,
-            str,
-            bytes,
-        ]
-    )
-    return type(value) in typeset
+    typeset = [
+        int,
+        float,
+        bool,
+        str,
+        bytes
+    ]
+    return type(value) in typeset  # pylint: disable=unidiomatic-typecheck
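
Note that on top of the snake_case renames, listen() above is now a coroutine: callers that previously invoked it synchronously must await it from a running event loop. A minimal usage sketch against the 2.0 API (the kademlia import path, port number, and bootstrap address are illustrative assumptions, not part of this diff):

    import asyncio
    from kademlia.network import Server

    async def main():
        server = Server()
        await server.listen(8468)          # 2.0: listen() must be awaited
        # address of an already-running node; placeholder values
        await server.bootstrap([("1.2.3.4", 8468)])
        await server.set("my-key", "my-value")
        print(await server.get("my-key"))

    asyncio.get_event_loop().run_until_complete(main())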

File: node.py

@@ -4,15 +4,15 @@ import heapq
 class Node:
     def __init__(self, node_id, ip=None, port=None):
-        self.id = node_id
-        self.ip = ip
+        self.id = node_id  # pylint: disable=invalid-name
+        self.ip = ip  # pylint: disable=invalid-name
         self.port = port
         self.long_id = int(node_id.hex(), 16)
-    def sameHomeAs(self, node):
+    def same_home_as(self, node):
         return self.ip == node.ip and self.port == node.port
-    def distanceTo(self, node):
+    def distance_to(self, node):
         """
         Get the distance between this node and another.
         """
@@ -47,7 +47,7 @@ class NodeHeap:
         self.contacted = set()
         self.maxsize = maxsize
-    def remove(self, peerIDs):
+    def remove(self, peers):
         """
         Remove a list of peer ids from this heap.  Note that while this
         heap retains a constant visible size (based on the iterator), it's
@@ -55,34 +55,32 @@
         removal of nodes may not change the visible size as previously added
         nodes suddenly become visible.
         """
-        peerIDs = set(peerIDs)
-        if len(peerIDs) == 0:
+        peers = set(peers)
+        if not peers:
             return
         nheap = []
         for distance, node in self.heap:
-            if node.id not in peerIDs:
+            if node.id not in peers:
                 heapq.heappush(nheap, (distance, node))
         self.heap = nheap
-    def getNodeById(self, node_id):
+    def get_node(self, node_id):
         for _, node in self.heap:
             if node.id == node_id:
                 return node
         return None
-    def allBeenContacted(self):
-        return len(self.getUncontacted()) == 0
-    def getIDs(self):
+    def have_contacted_all(self):
+        return len(self.get_uncontacted()) == 0
+    def get_ids(self):
         return [n.id for n in self]
-    def markContacted(self, node):
+    def mark_contacted(self, node):
         self.contacted.add(node.id)
     def popleft(self):
-        if len(self) > 0:
-            return heapq.heappop(self.heap)[1]
-        return None
+        return heapq.heappop(self.heap)[1] if self else None
     def push(self, nodes):
         """
@@ -95,7 +93,7 @@ class NodeHeap:
         for node in nodes:
             if node not in self:
-                distance = self.node.distanceTo(node)
+                distance = self.node.distance_to(node)
                 heapq.heappush(self.heap, (distance, node))
     def __len__(self):
@@ -106,10 +104,10 @@
         return iter(map(itemgetter(1), nodes))
     def __contains__(self, node):
-        for _, n in self.heap:
-            if node.id == n.id:
+        for _, other in self.heap:
+            if node.id == other.id:
                 return True
         return False
-    def getUncontacted(self):
+    def get_uncontacted(self):
         return [n for n in self if n.id not in self.contacted]
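
Since the NodeHeap surface is renamed wholesale here, a quick sketch of the new names (a hypothetical snippet; the random 20-byte ids stand in for real node ids):

    import os

    me = Node(os.urandom(20))
    heap = NodeHeap(me, 5)       # keeps the 5 nodes nearest to `me` visible
    heap.push([Node(os.urandom(20)) for _ in range(10)])

    for peer in heap.get_uncontacted()[:3]:    # was getUncontacted()
        heap.mark_contacted(peer)              # was markContacted()

    assert not heap.have_contacted_all()       # was allBeenContacted()
    nearest = heap.popleft()                   # nearest node, or None when empty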

File: protocol.py

@@ -4,41 +4,41 @@ import logging
 from rpcudp.protocol import RPCProtocol
-from .kademlia.node import Node
-from .kademlia.routing import RoutingTable
-from .kademlia.utils import digest
+from .node import Node
+from .routing import RoutingTable
+from .utils import digest
-log = logging.getLogger(__name__)
+log = logging.getLogger(__name__)  # pylint: disable=invalid-name
 class KademliaProtocol(RPCProtocol):
-    def __init__(self, sourceNode, storage, ksize):
+    def __init__(self, source_node, storage, ksize):
         RPCProtocol.__init__(self)
-        self.router = RoutingTable(self, ksize, sourceNode)
+        self.router = RoutingTable(self, ksize, source_node)
         self.storage = storage
-        self.sourceNode = sourceNode
-    def getRefreshIDs(self):
+        self.source_node = source_node
+    def get_refresh_ids(self):
         """
         Get ids to search for to keep old buckets up to date.
         """
         ids = []
-        for bucket in self.router.getLonelyBuckets():
+        for bucket in self.router.lonely_buckets():
             rid = random.randint(*bucket.range).to_bytes(20, byteorder='big')
             ids.append(rid)
         return ids
-    def rpc_stun(self, sender):
+    def rpc_stun(self, sender):  # pylint: disable=no-self-use
         return sender
     def rpc_ping(self, sender, nodeid):
         source = Node(nodeid, sender[0], sender[1])
-        self.welcomeIfNewNode(source)
-        return self.sourceNode.id
+        self.welcome_if_new(source)
+        return self.source_node.id
     def rpc_store(self, sender, nodeid, key, value):
         source = Node(nodeid, sender[0], sender[1])
-        self.welcomeIfNewNode(source)
+        self.welcome_if_new(source)
         log.debug("got a store request from %s, storing '%s'='%s'",
                   sender, key.hex(), value)
         self.storage[key] = value
@@ -48,42 +48,42 @@ class KademliaProtocol(RPCProtocol):
         log.info("finding neighbors of %i in local table",
                  int(nodeid.hex(), 16))
         source = Node(nodeid, sender[0], sender[1])
-        self.welcomeIfNewNode(source)
+        self.welcome_if_new(source)
         node = Node(key)
-        neighbors = self.router.findNeighbors(node, exclude=source)
+        neighbors = self.router.find_neighbors(node, exclude=source)
         return list(map(tuple, neighbors))
     def rpc_find_value(self, sender, nodeid, key):
         source = Node(nodeid, sender[0], sender[1])
-        self.welcomeIfNewNode(source)
+        self.welcome_if_new(source)
         value = self.storage.get(key, None)
         if value is None:
             return self.rpc_find_node(sender, nodeid, key)
         return {'value': value}
-    async def callFindNode(self, nodeToAsk, nodeToFind):
-        address = (nodeToAsk.ip, nodeToAsk.port)
-        result = await self.find_node(address, self.sourceNode.id,
-                                      nodeToFind.id)
-        return self.handleCallResponse(result, nodeToAsk)
+    async def call_find_node(self, node_to_ask, node_to_find):
+        address = (node_to_ask.ip, node_to_ask.port)
+        result = await self.find_node(address, self.source_node.id,
+                                      node_to_find.id)
+        return self.handle_call_response(result, node_to_ask)
-    async def callFindValue(self, nodeToAsk, nodeToFind):
-        address = (nodeToAsk.ip, nodeToAsk.port)
-        result = await self.find_value(address, self.sourceNode.id,
-                                       nodeToFind.id)
-        return self.handleCallResponse(result, nodeToAsk)
+    async def call_find_value(self, node_to_ask, node_to_find):
+        address = (node_to_ask.ip, node_to_ask.port)
+        result = await self.find_value(address, self.source_node.id,
+                                       node_to_find.id)
+        return self.handle_call_response(result, node_to_ask)
-    async def callPing(self, nodeToAsk):
-        address = (nodeToAsk.ip, nodeToAsk.port)
-        result = await self.ping(address, self.sourceNode.id)
-        return self.handleCallResponse(result, nodeToAsk)
+    async def call_ping(self, node_to_ask):
+        address = (node_to_ask.ip, node_to_ask.port)
+        result = await self.ping(address, self.source_node.id)
+        return self.handle_call_response(result, node_to_ask)
-    async def callStore(self, nodeToAsk, key, value):
-        address = (nodeToAsk.ip, nodeToAsk.port)
-        result = await self.store(address, self.sourceNode.id, key, value)
-        return self.handleCallResponse(result, nodeToAsk)
+    async def call_store(self, node_to_ask, key, value):
+        address = (node_to_ask.ip, node_to_ask.port)
+        result = await self.store(address, self.source_node.id, key, value)
+        return self.handle_call_response(result, node_to_ask)
-    def welcomeIfNewNode(self, node):
+    def welcome_if_new(self, node):
         """
         Given a new node, send it all the keys/values it should be storing,
         then add it to the routing table.
@@ -97,32 +97,32 @@ class KademliaProtocol(RPCProtocol):
        is closer than the closest in that list, then store the key/value
        on the new node (per section 2.5 of the paper)
        """
-        if not self.router.isNewNode(node):
+        if not self.router.is_new_node(node):
             return
         log.info("never seen %s before, adding to router", node)
-        for key, value in self.storage.items():
+        for key, value in self.storage:
             keynode = Node(digest(key))
-            neighbors = self.router.findNeighbors(keynode)
-            if len(neighbors) > 0:
-                last = neighbors[-1].distanceTo(keynode)
-                newNodeClose = node.distanceTo(keynode) < last
-                first = neighbors[0].distanceTo(keynode)
-                thisNodeClosest = self.sourceNode.distanceTo(keynode) < first
-            if len(neighbors) == 0 or (newNodeClose and thisNodeClosest):
-                asyncio.ensure_future(self.callStore(node, key, value))
-        self.router.addContact(node)
+            neighbors = self.router.find_neighbors(keynode)
+            if neighbors:
+                last = neighbors[-1].distance_to(keynode)
+                new_node_close = node.distance_to(keynode) < last
+                first = neighbors[0].distance_to(keynode)
+                this_closest = self.source_node.distance_to(keynode) < first
+            if not neighbors or (new_node_close and this_closest):
+                asyncio.ensure_future(self.call_store(node, key, value))
+        self.router.add_contact(node)
-    def handleCallResponse(self, result, node):
+    def handle_call_response(self, result, node):
         """
         If we get a response, add the node to the routing table.  If
         we get no response, make sure it's removed from the routing table.
         """
         if not result[0]:
             log.warning("no response from %s, removing from router", node)
-            self.router.removeContact(node)
+            self.router.remove_contact(node)
             return result
         log.info("got successful response from %s", node)
-        self.welcomeIfNewNode(node)
+        self.welcome_if_new(node)
         return result

File: routing.py

@@ -4,22 +4,21 @@ import operator
 import asyncio
 from collections import OrderedDict
-from .kademlia.utils import OrderedSet, sharedPrefix, bytesToBitString
+from .utils import OrderedSet, shared_prefix, bytes_to_bit_string
 class KBucket:
     def __init__(self, rangeLower, rangeUpper, ksize):
         self.range = (rangeLower, rangeUpper)
         self.nodes = OrderedDict()
-        self.replacementNodes = OrderedSet()
-        self.touchLastUpdated()
+        self.replacement_nodes = OrderedSet()
+        self.touch_last_updated()
         self.ksize = ksize
-    def touchLastUpdated(self):
-        self.lastUpdated = time.monotonic()
-    def getNodes(self):
+    def touch_last_updated(self):
+        self.last_updated = time.monotonic()
+    def get_nodes(self):
         return list(self.nodes.values())
     def split(self):
@@ -31,23 +30,23 @@ class KBucket:
             bucket.nodes[node.id] = node
         return (one, two)
-    def removeNode(self, node):
+    def remove_node(self, node):
         if node.id not in self.nodes:
             return
         # delete node, and see if we can add a replacement
         del self.nodes[node.id]
-        if len(self.replacementNodes) > 0:
-            newnode = self.replacementNodes.pop()
+        if self.replacement_nodes:
+            newnode = self.replacement_nodes.pop()
             self.nodes[newnode.id] = newnode
-    def hasInRange(self, node):
+    def has_in_range(self, node):
         return self.range[0] <= node.long_id <= self.range[1]
-    def isNewNode(self, node):
+    def is_new_node(self, node):
         return node.id not in self.nodes
-    def addNode(self, node):
+    def add_node(self, node):
         """
         Add a C{Node} to the C{KBucket}.  Return True if successful,
         False if the bucket is full.
@@ -61,14 +60,14 @@ class KBucket:
         elif len(self) < self.ksize:
             self.nodes[node.id] = node
         else:
-            self.replacementNodes.push(node)
+            self.replacement_nodes.push(node)
             return False
         return True
     def depth(self):
         vals = self.nodes.values()
-        sp = sharedPrefix([bytesToBitString(n.id) for n in vals])
-        return len(sp)
+        sprefix = shared_prefix([bytes_to_bit_string(n.id) for n in vals])
+        return len(sprefix)
     def head(self):
         return list(self.nodes.values())[0]
@@ -82,11 +81,11 @@
 class TableTraverser:
     def __init__(self, table, startNode):
-        index = table.getBucketFor(startNode)
-        table.buckets[index].touchLastUpdated()
-        self.currentNodes = table.buckets[index].getNodes()
-        self.leftBuckets = table.buckets[:index]
-        self.rightBuckets = table.buckets[(index + 1):]
+        index = table.get_bucket_for(startNode)
+        table.buckets[index].touch_last_updated()
+        self.current_nodes = table.buckets[index].get_nodes()
+        self.left_buckets = table.buckets[:index]
+        self.right_buckets = table.buckets[(index + 1):]
         self.left = True
     def __iter__(self):
@@ -96,16 +95,16 @@ class TableTraverser:
         """
         Pop an item from the left subtree, then right, then left, etc.
         """
-        if len(self.currentNodes) > 0:
-            return self.currentNodes.pop()
-        if self.left and len(self.leftBuckets) > 0:
-            self.currentNodes = self.leftBuckets.pop().getNodes()
+        if self.current_nodes:
+            return self.current_nodes.pop()
+        if self.left and self.left_buckets:
+            self.current_nodes = self.left_buckets.pop().get_nodes()
             self.left = False
             return next(self)
-        if len(self.rightBuckets) > 0:
-            self.currentNodes = self.rightBuckets.pop(0).getNodes()
+        if self.right_buckets:
+            self.current_nodes = self.right_buckets.pop(0).get_nodes()
             self.left = True
             return next(self)
@@ -127,58 +126,60 @@ class RoutingTable:
     def flush(self):
         self.buckets = [KBucket(0, 2 ** 160, self.ksize)]
-    def splitBucket(self, index):
+    def split_bucket(self, index):
         one, two = self.buckets[index].split()
         self.buckets[index] = one
         self.buckets.insert(index + 1, two)
-    def getLonelyBuckets(self):
+    def lonely_buckets(self):
         """
         Get all of the buckets that haven't been updated in over
         an hour.
         """
         hrago = time.monotonic() - 3600
-        return [b for b in self.buckets if b.lastUpdated < hrago]
+        return [b for b in self.buckets if b.last_updated < hrago]
-    def removeContact(self, node):
-        index = self.getBucketFor(node)
-        self.buckets[index].removeNode(node)
+    def remove_contact(self, node):
+        index = self.get_bucket_for(node)
+        self.buckets[index].remove_node(node)
-    def isNewNode(self, node):
-        index = self.getBucketFor(node)
-        return self.buckets[index].isNewNode(node)
+    def is_new_node(self, node):
+        index = self.get_bucket_for(node)
+        return self.buckets[index].is_new_node(node)
-    def addContact(self, node):
-        index = self.getBucketFor(node)
+    def add_contact(self, node):
+        index = self.get_bucket_for(node)
         bucket = self.buckets[index]
         # this will succeed unless the bucket is full
-        if bucket.addNode(node):
+        if bucket.add_node(node):
             return
         # Per section 4.2 of paper, split if the bucket has the node
         # in its range or if the depth is not congruent to 0 mod 5
-        if bucket.hasInRange(self.node) or bucket.depth() % 5 != 0:
-            self.splitBucket(index)
-            self.addContact(node)
+        if bucket.has_in_range(self.node) or bucket.depth() % 5 != 0:
+            self.split_bucket(index)
+            self.add_contact(node)
         else:
-            asyncio.ensure_future(self.protocol.callPing(bucket.head()))
+            asyncio.ensure_future(self.protocol.call_ping(bucket.head()))
-    def getBucketFor(self, node):
+    def get_bucket_for(self, node):
         """
         Get the index of the bucket that the given node would fall into.
        """
         for index, bucket in enumerate(self.buckets):
             if node.long_id < bucket.range[1]:
                 return index
+        # we should never be here, but make linter happy
+        return None
-    def findNeighbors(self, node, k=None, exclude=None):
+    def find_neighbors(self, node, k=None, exclude=None):
         k = k or self.ksize
         nodes = []
         for neighbor in TableTraverser(self, node):
-            notexcluded = exclude is None or not neighbor.sameHomeAs(exclude)
+            notexcluded = exclude is None or not neighbor.same_home_as(exclude)
             if neighbor.id != node.id and notexcluded:
-                heapq.heappush(nodes, (node.distanceTo(neighbor), neighbor))
+                heapq.heappush(nodes, (node.distance_to(neighbor), neighbor))
             if len(nodes) == k:
                 break

File: storage.py

@@ -2,44 +2,45 @@ import time
 from itertools import takewhile
 import operator
 from collections import OrderedDict
+from abc import abstractmethod, ABC
-class IStorage:
+class IStorage(ABC):
     """
     Local storage for this node.
     IStorage implementations of get must return the same type as put in by set
     """
+    @abstractmethod
     def __setitem__(self, key, value):
         """
         Set a key to the given value.
         """
-        raise NotImplementedError
+    @abstractmethod
     def __getitem__(self, key):
         """
         Get the given key.  If item doesn't exist, raises C{KeyError}
         """
-        raise NotImplementedError
+    @abstractmethod
     def get(self, key, default=None):
         """
         Get given key.  If not found, return default.
         """
-        raise NotImplementedError
-    def iteritemsOlderThan(self, secondsOld):
+    @abstractmethod
+    def iter_older_than(self, seconds_old):
         """
         Return the an iterator over (key, value) tuples for items older
         than the given secondsOld.
         """
-        raise NotImplementedError
+    @abstractmethod
     def __iter__(self):
         """
         Get the iterator for this storage, should yield tuple of (key, value)
         """
-        raise NotImplementedError
 class ForgetfulStorage(IStorage):
@@ -57,7 +58,7 @@ class ForgetfulStorage(IStorage):
         self.cull()
     def cull(self):
-        for _, _ in self.iteritemsOlderThan(self.ttl):
+        for _, _ in self.iter_older_than(self.ttl):
             self.data.popitem(last=False)
     def get(self, key, default=None):
@@ -70,27 +71,23 @@ class ForgetfulStorage(IStorage):
         self.cull()
         return self.data[key][1]
-    def __iter__(self):
-        self.cull()
-        return iter(self.data)
     def __repr__(self):
         self.cull()
         return repr(self.data)
-    def iteritemsOlderThan(self, secondsOld):
-        minBirthday = time.monotonic() - secondsOld
-        zipped = self._tripleIterable()
-        matches = takewhile(lambda r: minBirthday >= r[1], zipped)
+    def iter_older_than(self, seconds_old):
+        min_birthday = time.monotonic() - seconds_old
+        zipped = self._triple_iter()
+        matches = takewhile(lambda r: min_birthday >= r[1], zipped)
         return list(map(operator.itemgetter(0, 2), matches))
-    def _tripleIterable(self):
+    def _triple_iter(self):
         ikeys = self.data.keys()
         ibirthday = map(operator.itemgetter(0), self.data.values())
         ivalues = map(operator.itemgetter(1), self.data.values())
         return zip(ikeys, ibirthday, ivalues)
-    def items(self):
+    def __iter__(self):
         self.cull()
         ikeys = self.data.keys()
         ivalues = map(operator.itemgetter(1), self.data.values())
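
With IStorage now an abstract base class, a custom back end that misses one of the five methods fails at instantiation time rather than at first use. A minimal sketch of a conforming implementation (DictStorage is a hypothetical example, not part of the library):

    import time
    from collections import OrderedDict

    class DictStorage(IStorage):
        def __init__(self):
            self.data = OrderedDict()      # key -> (birthday, value)

        def __setitem__(self, key, value):
            self.data[key] = (time.monotonic(), value)

        def __getitem__(self, key):
            return self.data[key][1]       # raises KeyError if missing

        def get(self, key, default=None):
            return self.data[key][1] if key in self.data else default

        def iter_older_than(self, seconds_old):
            min_birthday = time.monotonic() - seconds_old
            return [(key, value) for key, (born, value)
                    in self.data.items() if born <= min_birthday]

        def __iter__(self):
            # yield (key, value) pairs, as welcome_if_new in protocol.py expects
            return iter([(key, value) for key, (_, value) in self.data.items()])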

File: utils.py

@@ -6,16 +6,16 @@ import operator
 import asyncio
-async def gather_dict(d):
-    cors = list(d.values())
+async def gather_dict(dic):
+    cors = list(dic.values())
     results = await asyncio.gather(*cors)
-    return dict(zip(d.keys(), results))
+    return dict(zip(dic.keys(), results))
-def digest(s):
-    if not isinstance(s, bytes):
-        s = str(s).encode('utf8')
-    return hashlib.sha1(s).digest()
+def digest(string):
+    if not isinstance(string, bytes):
+        string = str(string).encode('utf8')
+    return hashlib.sha1(string).digest()
 class OrderedSet(list):
@@ -34,7 +34,7 @@ class OrderedSet(list):
         self.append(thing)
-def sharedPrefix(args):
+def shared_prefix(args):
     """
     Find the shared prefix between the strings.
@@ -52,6 +52,6 @@ def sharedPrefix(args):
     return args[0][:i]
-def bytesToBitString(bites):
+def bytes_to_bit_string(bites):
     bits = [bin(bite)[2:].rjust(8, '0') for bite in bites]
     return "".join(bits)