algorithm-in-python/dataStructure/redBlackTree.py

363 lines
11 KiB
Python
Raw Normal View History

2018-10-02 21:24:06 +08:00
''' mbinary
#########################################################################
# File : redBlackTree.py
# Author: mbinary
# Mail: zhuheqin1@gmail.com
2019-01-31 12:09:46 +08:00
# Blog: https://mbinary.xyz
2018-10-02 21:24:06 +08:00
# Github: https://github.com/mbinary
# Created Time: 2018-07-14 16:15
# Description:
#########################################################################
'''
from functools import total_ordering
from random import randint, shuffle
2020-04-15 12:28:20 +08:00
@total_ordering
class node:
    '''A red-black tree node holding ``val`` plus left/right/parent links.

    Nodes compare by ``val``: ``<`` and ``==`` are defined here and
    ``functools.total_ordering`` derives the rest.  Comparing with anything
    that is not a node (including None) is delegated to Python's default
    handling instead of raising AttributeError.
    '''

    # __eq__ is defined and ``val`` may be reassigned after construction,
    # so nodes are deliberately unhashable.
    __hash__ = None

    def __init__(self, val, left=None, right=None, isBlack=False):
        self.val = val
        self.parent = None
        self.isBlack = isBlack
        self.left = self.right = None
        # go through setChild so that supplied children get their parent link
        self.setChild(left, True)
        self.setChild(right, False)

    def __lt__(self, nd):
        if not isinstance(nd, node):
            return NotImplemented
        return self.val < nd.val

    def __eq__(self, nd):
        if not isinstance(nd, node):
            # e.g. `nd == None` ends up False via the default identity fallback
            return NotImplemented
        return self.val == nd.val

    def setChild(self, nd, isLeft):
        '''Attach nd as the left (isLeft True) or right child; nd may be None.'''
        if isLeft:
            self.left = nd
        else:
            self.right = nd
        if nd is not None:
            nd.parent = self

    def getChild(self, isLeft):
        '''Return the left child if isLeft is true, else the right child.'''
        return self.left if isLeft else self.right

    def __bool__(self):
        # a node whose val is None counts as "empty"
        return self.val is not None

    def __str__(self):
        color = 'B' if self.isBlack else 'R'
        return f'{color}-{self.val}'

    def __repr__(self):
        return f'node({self.val},isBlack={self.isBlack})'
class redBlackTree:
    '''A red-black tree of ``node`` objects ordered by ``node.val``.

    ``insert`` and ``delete`` keep the usual invariants: the root is black,
    a red node never has a red child, and every root-to-leaf path contains
    the same number of black nodes (None children count as black leaves).
    '''

    def __init__(self, unique=False):
        '''If unique is True, inserting a val that is already present is
        ignored; otherwise equal vals may coexist in the tree.'''
        self.root = None
        self.unique = unique

    @staticmethod
    def checkBlack(nd):
        '''Return True if nd is black; None (a nil leaf) counts as black.'''
        return nd is None or nd.isBlack

    @staticmethod
    def setBlack(nd, isBlack):
        '''Colour nd black when isBlack is truthy or None, red otherwise.
        A None nd is ignored.'''
        if nd is not None:
            nd.isBlack = isBlack is None or bool(isBlack)

    def setRoot(self, nd):
        '''Make nd (possibly None) the root, detaching it from any parent.'''
        if nd is not None:
            nd.parent = None
        self.root = nd

    def find(self, val):
        '''Return a node whose val equals val, or None if there is none.'''
        nd = self.root
        while nd is not None:
            if nd.val == val:
                return nd
            # descend left when the current val is larger than the target
            nd = nd.getChild(nd.val > val)
        return None

    def getSuccessor(self, nd):
        '''Return the in-order successor of nd, or None if nd is the last
        node (or nd itself is None).'''
        if nd is None:
            return None
        if nd.right is not None:
            nd = nd.right
            while nd.left is not None:
                nd = nd.left
            return nd
        # climb while nd is a right child; the first ancestor entered from
        # its left subtree is the successor (None past the root)
        while nd.parent is not None and nd.parent.right is nd:
            nd = nd.parent
        return nd.parent

    def rotate(self, prt, chd):
        '''Rotate chd (a child of prt) up into prt's place; prt becomes
        chd's child.  A left child gives a right rotation and vice versa.'''
        if self.root is prt:
            self.setRoot(chd)
        else:
            grand = prt.parent
            grand.setChild(chd, grand.left is prt)
        isLeftChd = prt.left is chd
        # chd's inner subtree moves across to prt
        prt.setChild(chd.getChild(not isLeftChd), isLeftChd)
        chd.setChild(prt, not isLeftChd)

    def insert(self, nd):
        '''Insert node nd (recoloured red) and rebalance.

        With unique=True nothing happens if an equal val already exists;
        otherwise equal vals are placed in the right subtree.
        '''
        nd.isBlack = False
        if self.root is None:
            self.setRoot(nd)
            self.root.isBlack = True
            return
        parent = self.root
        while True:
            if self.unique and parent == nd:
                return None
            isLeft = parent > nd
            chd = parent.getChild(isLeft)
            if chd is None:
                parent.setChild(nd, isLeft)
                break
            parent = chd
        self.fixUpInsert(parent, nd)

    def fixUpInsert(self, parent, nd):
        '''Restore the invariants after red nd was attached under parent;
        the only possible violation is a red parent with a red child.'''
        while not self.checkBlack(parent):
            # a red parent is never the root, so grand exists
            grand = parent.parent
            isLeftPrt = grand.left is parent
            uncle = grand.getChild(not isLeftPrt)
            if not self.checkBlack(uncle):
                # case 1: uncle is red -> recolour and continue from grand
                self.setBlack(grand, False)
                self.setBlack(grand.left, True)
                self.setBlack(grand.right, True)
                nd = grand
                parent = nd.parent
            else:
                # case 2: uncle is black (including a nil leaf)
                isLeftNode = parent.left is nd
                if isLeftNode != isLeftPrt:
                    # case 2.1: left-right / right-left shape; rotate parent
                    # so that nd, parent and grand lie on one line
                    self.rotate(parent, nd)
                    nd, parent = parent, nd
                # case 2.2: left-left / right-right shape
                self.setBlack(grand, False)
                self.setBlack(parent, True)
                self.rotate(grand, parent)
        self.setBlack(self.root, True)

    def copyNode(self, src, des):
        '''When deleting a node with two children, copy its successor's data
        into its position.  Links (left, right, parent) and colour are not
        copied.'''
        des.val = src.val

    def delete(self, val):
        '''Delete one node whose val equals val (val may also be a node, whose
        val is then used).  Does nothing if no such node exists.'''
        if isinstance(val, node):
            val = val.val
        nd = self.find(val)
        if nd is not None:
            self._delete(nd)

    def _delete(self, nd):
        '''Unlink nd's value from the tree and rebalance.'''
        # y is the node physically removed: nd itself if it has at most one
        # child, otherwise its in-order successor (which has no left child)
        if nd.left is not None and nd.right is not None:
            y = self.getSuccessor(nd)
        else:
            y = nd
        py = y.parent
        x = y.left if y.left is not None else y.right
        if py is None:
            self.setRoot(x)
        else:
            py.setChild(x, py.left is y)
        # identity, not value equality: equal vals may coexist
        if y is not nd:
            self.copyNode(y, nd)
        if self.checkBlack(y):
            # a black node left its path: x carries an extra black
            self.fixUpDel(py, x)
        # drop stale links so the removed node can be reinserted safely
        y.left = y.right = y.parent = None

    def fixUpDel(self, prt, chd):
        '''Rebalance after a black node under prt was replaced by chd
        (possibly None), which carries an extra black.'''
        while chd is not self.root and self.checkBlack(chd):
            isLeft = prt.left is chd
            brother = prt.getChild(not isLeft)
            if not self.checkBlack(brother):
                # case 1: brother is red -> rotate it up so chd gets a black
                # brother; cases 2-4 apply on the next iteration
                self.setBlack(prt, False)
                self.setBlack(brother, True)
                self.rotate(prt, brother)
                continue
            # brother is black; colours of its near and far children
            nearBlack = self.checkBlack(brother.getChild(isLeft))
            farBlack = self.checkBlack(brother.getChild(not isLeft))
            if nearBlack and farBlack:
                # case 2: move the extra black up to prt
                self.setBlack(brother, False)
                chd = prt
                prt = chd.parent
            else:
                if farBlack:
                    # case 3: near child red, far child black -> rotate the
                    # near child (nephew) above brother to reach case 4
                    nephew = brother.getChild(isLeft)
                    self.setBlack(nephew, True)
                    self.setBlack(brother, False)
                    self.rotate(brother, nephew)
                    brother = nephew
                # case 4: far child red -> rotate brother above prt; the
                # extra black is absorbed and the loop ends
                brother.isBlack = prt.isBlack
                self.setBlack(prt, True)
                self.setBlack(brother.getChild(not isLeft), True)
                self.rotate(prt, brother)
                chd = self.root
        self.setBlack(chd, True)

    def sort(self, reverse=False):
        ''' return a generator of sorted data'''
        def inOrder(root):
            if root is None:
                return
            if reverse:
                first, last = root.right, root.left
            else:
                first, last = root.left, root.right
            yield from inOrder(first)
            yield root
            yield from inOrder(last)
        yield from inOrder(self.root)

    def display(self):
        '''Return a multi-line drawing of the tree, one row per level.

        Every slot of the complete binary tree of the same height is drawn:
        nodes as str(node) (colour-val), absent children as "None".
        '''
        def getHeight(nd):
            if nd is None:
                return 0
            return max(getHeight(nd.left), getHeight(nd.right)) + 1

        def levelVisit(root):
            # one list of labels per level; level i has 2**i slots
            level = []
            current = [root]
            for _ in range(getHeight(root)):
                level.append([str(nd) for nd in current])
                nxt = []
                for nd in current:
                    if nd is None:
                        nxt += [None, None]
                    else:
                        nxt += [nd.left, nd.right]
                current = nxt
            return level

        def addBlank(lines):
            # all cells share one width (widest label + 1) so long labels
            # cannot push later columns out of alignment
            width = 1 + max((len(s) for line in lines for s in line), default=0)
            sep = ' ' * width
            n = len(lines)
            for i, oneline in enumerate(lines):
                k = 2 ** (n - i) - 1
                new = [sep * ((k - 1) // 2)]
                for s in oneline:
                    new.append(s.ljust(width))
                    new.append(sep * k)
                lines[i] = new
            return lines

        li = [''.join(line) for line in addBlank(levelVisit(self.root))]
        length = 10 if li == [] else max(len(i) for i in li) // 2
        begin = '\n' + 'red-black-tree'.rjust(length + 14, '-') + '-' * length
        end = '-' * (length * 2 + 14) + '\n'
        return '\n'.join([begin, *li, end])

    def __str__(self):
        return self.display()
def genNum(n=10, low=0, high=100):
    '''Return a list of n distinct random integers from [low, high].

    Values are kept in the order they were drawn.  Raises ValueError when
    the range holds fewer than n distinct integers, instead of looping
    forever.  A non-positive n gives an empty list.
    '''
    if n > high - low + 1:
        raise ValueError(f'cannot draw {n} distinct integers from [{low}, {high}]')
    nums = []
    seen = set()  # O(1) membership instead of scanning nums
    while len(nums) < n:
        d = randint(low, high)
        if d not in seen:
            seen.add(d)
            nums.append(d)
    return nums
def buildTree(n=10, nums=None, visitor=None):
    '''Build a redBlackTree by inserting nums one at a time.

    If nums is None or empty, n random values from genNum are used.  After
    each insertion, visitor(tree, val) is called when given; otherwise the
    tree is printed.  Returns (tree, nums).
    '''
    if not nums:
        nums = genNum(n)
    rbtree = redBlackTree()
    print(f'build a red-black tree using {nums}')
    for i in nums:
        rbtree.insert(node(i))
        # a visitor takes over per-step reporting, so the tree is not
        # printed twice for the same step
        if visitor:
            visitor(rbtree, i)
        else:
            print(rbtree)
    return rbtree, nums
def testInsert(nums=None):
    '''Demo: build a tree while reporting every insertion, then print the
    nodes in ascending order with their 1-based rank.'''
    def visitor(tree, val):
        print('inserting', val)
        print(tree)

    tree, _ = buildTree(visitor=visitor, nums=nums)
    banner = '-' * 5 + 'in-order visit' + '-' * 5
    print(banner)
    for rank, nd in enumerate(tree.sort(), start=1):
        print(f'{rank}: {nd}')
def testSuc(nums=None):
    '''Demo: print each node of a freshly built tree next to its in-order
    successor.'''
    tree = buildTree(nums=nums)[0]
    for nd in tree.sort():
        succ = tree.getSuccessor(nd)
        print(f'{nd}\'s suc is {succ}')
def testDelete(nums=None):
    '''Demo: build a tree, then remove its values smallest-first, printing
    the tree after each removal.'''
    tree, values = buildTree(nums=nums)
    print(tree)
    for val in sorted(values):
        print(f'deleting {val}')
        tree.delete(val)
        print(tree)
if __name__ == '__main__':
    # Fixed sample for reproducible runs with the commented-out demos below.
    # Another handy sample: [45, 30, 64, 36, 95, 38, 76, 34, 50, 1]
    lst = [0, 3, 5, 6, 26, 25, 8, 19, 15, 16, 17]
    # testSuc(lst)
    # testInsert(lst)
    testDelete()