# algorithm-in-python/dataStructure/splayTree.py
''' mbinary
#########################################################################
# File : splayTree.py
# Author: mbinary
# Mail: zhuheqin1@gmail.com
# Blog: https://mbinary.xyz
# Github: https://github.com/mbinary
# Created Time: 2018-05-19 23:06
# Description:
#########################################################################
'''
from collections import deque
from collections.abc import Iterable
# use isinstance(obj,Iterable) to judge if an obj is iterable

class node:
    """A binary-tree node holding a key, its occurrence count and its links.

    Attributes (all plain, publicly assigned):
        val: the stored key.
        freq: occurrence count -- 1 for a real key, 0 when the node is built
            with ``val=None`` (an empty placeholder).
        left, right, parent: neighbouring nodes, or ``None``.
    """

    def __init__(self, val=None, left=None, right=None, parent=None):
        self.val = val
        # Only None means "no key": a falsy key such as 0 is still one occurrence.
        self.freq = 0 if val is None else 1
        self.left = left
        self.right = right
        self.parent = parent

    def getChild(self, s=0):
        """Return the node reached by following the path ``s`` from this node.

        ``s`` is one direction (0 = left, anything else = right) or an
        iterable of directions.  Returns ``None`` once the path runs off the
        tree; an empty path returns ``self``.
        """
        if isinstance(s, int):
            s = [s]
        last = self
        for i in s:
            if last is None:
                return None
            last = last.left if i == 0 else last.right
        return last

    def setChild(self, child, s=0):
        """Attach ``child`` at the path ``s`` below this node.

        ``s`` is one direction (0 = left, anything else = right) or an
        iterable of directions whose last element names the slot to fill;
        every earlier step must already exist.  The caller's path object is
        not modified.  ``child.parent`` is NOT updated -- callers maintain
        parent links themselves.

        Raises:
            ValueError: if ``s`` is an empty path.
            AttributeError: if an intermediate node on the path is missing.
        """
        if isinstance(s, Iterable):
            path = list(s)  # copy so the caller's list is left untouched
            if not path:
                raise ValueError('empty child path')
            *steps, last = path
            target = self
            # Walk to the node that owns the slot, then set that single slot.
            for step in steps:
                target = target.left if step == 0 else target.right
            target.setChild(child, last)
        elif s:
            self.right = child
        else:
            self.left = child
class splayTree:
    """A splay tree (self-adjusting binary search tree) with per-key counts.

    Each distinct key lives in one ``node``; inserting a key that is already
    present increments that node's ``freq`` instead of adding a node.
    ``find`` splays the node it finds to the root.

    NOTE(review): ``insert`` and ``remove`` do not splay the nodes they
    touch, so only ``find`` performs the self-adjusting restructuring.
    """

    def __init__(self, s=()):
        """Build a tree from the keys in the iterable ``s`` (may be empty)."""
        self.root = None
        # Keys are inserted largest-first.
        for k in sorted(s, reverse=True):
            self.insert(k)

    def insert(self, k):
        """Insert key ``k``, or increment its count if it is already present."""
        if self.root is None:
            self.root = node(k)
            # Set explicitly so a falsy key such as 0 still counts once,
            # whatever default node() picks.
            self.root.freq = 1
        else:
            self._insert(self.root, k)

    def _insert(self, root, k):
        """Insert ``k`` into the non-empty subtree rooted at ``root``.

        Iterative, so a long degenerate chain cannot hit the recursion limit.
        """
        cur = root
        while True:
            if cur.val == k:
                cur.freq += 1
                return
            side = 1 if cur.val < k else 0
            child = cur.right if side else cur.left
            if child is None:
                new = node(k)
                new.freq = 1
                self._link(cur, new, side)
                return
            cur = child

    @staticmethod
    def _link(parent, child, side):
        """Make ``child`` the ``side`` child of ``parent`` (0 = left, 1 = right).

        Also points ``child.parent`` back at ``parent`` so parent links never
        go stale; ``child`` may be None.
        """
        if side == 0:
            parent.left = child
        else:
            parent.right = child
        if child is not None:
            child.parent = parent

    def _rotateUp(self, x):
        """Single rotation lifting ``x`` above its parent (which must exist).

        Preserves the in-order key order and keeps every parent link
        consistent, including the subtree that changes owner.  Updates
        ``self.root`` when ``x``'s parent was the root.
        """
        p = x.parent
        g = p.parent
        if p.left is x:
            # x's right subtree becomes p's new left subtree.
            self._link(p, x.right, 0)
            self._link(x, p, 1)
        else:
            # x's left subtree becomes p's new right subtree.
            self._link(p, x.left, 1)
            self._link(x, p, 0)
        if g is None:
            self.root = x
            x.parent = None
        elif g.left is p:
            self._link(g, x, 0)
        else:
            self._link(g, x, 1)

    def _zigzagRotate(self, i, j, root, parent, grand):
        """Zig-zag step: ``root`` is ``grand``'s ``i``-then-``j`` grandchild, i != j.

        Afterwards ``root`` sits in ``grand``'s old position with ``parent``
        and ``grand`` as its two children.  ``i``/``j`` are accepted for
        interface compatibility; directions are read from the links.
        """
        self._rotateUp(root)  # lift root above parent ...
        self._rotateUp(root)  # ... then above grand

    def _lineRotate(self, i, root, parent, grand):
        """Zig-zig step: ``root`` is ``grand``'s ``i``-then-``i`` grandchild.

        Rotating ``parent`` over ``grand`` first, then ``root`` over
        ``parent``, leaves ``root`` on top with ``parent`` and then ``grand``
        hanging on the side opposite ``i``.
        """
        self._rotateUp(parent)
        self._rotateUp(root)

    def _rotate(self, root):
        """Splay ``root`` (a node of this tree) upward until it is ``self.root``."""
        while root.parent is not None:
            parent = root.parent
            grand = parent.parent
            if grand is None:
                # zig: parent is the tree root, one rotation finishes the splay
                self._rotateUp(root)
                continue
            i = 0 if grand.left is parent else 1
            j = 0 if parent.left is root else 1
            if i == j:
                self._lineRotate(i, root, parent, grand)
            else:
                self._zigzagRotate(i, j, root, parent, grand)

    def _find(self, root, k):
        """Return the count of ``k`` in the subtree at ``root`` (0 if absent).

        A node holding ``k`` is splayed to the root of the whole tree.
        """
        cur = root
        while cur is not None:
            if cur.val > k:
                cur = cur.left
            elif cur.val < k:
                cur = cur.right
            else:
                self._rotate(cur)
                return cur.freq
        return 0

    def _maxmin(self, root, i=0):
        """Return the extreme node of the subtree at ``root``, or None if empty.

        ``i == 0`` follows left links to the minimum; any other ``i`` follows
        right links to the maximum.  The same direction is used at every level.
        """
        if root is None:
            return None
        nxt = root.left if i == 0 else root.right
        while nxt is not None:
            root = nxt
            nxt = root.left if i == 0 else root.right
        return root

    def Max(self):
        """Return the node holding the largest key, or None if the tree is empty."""
        return self._maxmin(self.root, 1)

    def Min(self):
        """Return the node holding the smallest key, or None if the tree is empty."""
        return self._maxmin(self.root, 0)

    def remove(self, k):
        """Remove key ``k`` from the tree (its node, whatever its count).

        Raises:
            ValueError: if ``k`` is not in the tree.
        """
        if not self.find(k):
            raise ValueError(k)
        # find() splayed k's node to the root; join its two subtrees.
        target = self.root
        left, right = target.left, target.right
        target.left = target.right = None
        if left is None:
            self.root = right
            if right is not None:
                right.parent = None
        else:
            left.parent = None
            self.root = left
            # Every key in ``right`` exceeds every key in ``left``, and the
            # maximum of ``left`` has no right child, so ``right`` hangs there.
            biggest = self.Max()
            biggest.right = right
            if right is not None:
                right.parent = biggest

    def find(self, k):
        """Return how many times ``k`` was inserted (0 if absent); splays on a hit."""
        return self._find(self.root, k)

    def levelTraverse(self):
        """Return all nodes in breadth-first (level) order; [] for an empty tree."""
        if self.root is None:
            return []
        q = deque([self.root])
        rst = []
        while q:
            cur = q.popleft()
            rst.append(cur)
            if cur.left is not None:
                q.append(cur.left)
            if cur.right is not None:
                q.append(cur.right)
        return rst

    def display(self):
        """Print the keys in level order on a single line."""
        for n in self.levelTraverse():
            print(n.val, end=' ')
        print('')
if __name__ == '__main__':
    # Small demo: build a tree, splay a key via find(), then remove a key.
    a = splayTree()
    keys = [5, 1, 4, 3, 2, 7, 8, 2]
    for k in keys:
        a.insert(k)
    # Derive the label from the data so it cannot drift from what was inserted.
    print('initial:' + ','.join(map(str, keys)))
    a.display()
    tmp = a.find(2)
    print("after find(2):%d" % tmp)
    a.display()
    print("remove(4)")
    a.remove(4)
    a.display()