Last active
August 24, 2019 09:43
-
-
Save codefever/4442fd7130700c0b05fb1ca1ac2f3f4c to your computer and use it in GitHub Desktop.
RBT (red-black tree) in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import random
import sys
# RBT != 仙踪林 | |
# https://en.wikipedia.org/wiki/Red%E2%80%93black_tree | |
class RBNode(object):
    """A node of a red-black tree.

    Each node stores a value, a colour flag and links to its left child,
    right child and parent.  A missing child (``None``) is treated as a
    black leaf by the colour helpers below.
    """

    # Sentinel meaning "no ordering bound on this side" for the validator.
    _NO_BOUND = object()

    def __init__(self, val, is_red=False, left=None, right=None, parent=None):
        self.val = val
        self.is_red = is_red  # True = red, False = black (the default)
        self.left = left
        self.right = right
        self.parent = parent

    def red(self):
        """Return True if this node is red."""
        return self.is_red

    def black(self):
        """Return True if this node is black."""
        return not self.is_red

    def left_black(self):
        """Return True if the left child is black; a missing child counts as black."""
        return self.left is None or self.left.black()

    def right_black(self):
        """Return True if the right child is black; a missing child counts as black."""
        return self.right is None or self.right.black()

    def set_red(self):
        """Colour this node red."""
        self.is_red = True

    def set_black(self):
        """Colour this node black."""
        self.is_red = False

    @property
    def grandparent(self):
        """The parent's parent, or None if this node has no parent."""
        return self.parent.parent if self.parent is not None else None

    @property
    def uncle(self):
        """The parent's sibling, or None if there is no grandparent."""
        gp = self.grandparent
        if gp is None:
            return None
        return gp.right if self.parent is gp.left else gp.left

    @property
    def sibling(self):
        """The other child of this node's parent, or None if there is no parent."""
        p = self.parent
        if p is None:
            return None
        return p.left if self is p.right else p.right

    def rotate_left(self):
        """Rotate left around this node: its right child takes its place.

        The parent's child pointer (if there is a parent) is re-linked to the
        new subtree root, which is returned.  Any external reference to the
        tree root is NOT updated; callers must do that when the returned node
        has no parent.
        """
        assert self.right is not None
        p = self.parent
        n = self.right
        self.right, n.left = n.left, self
        n.parent = p
        self.parent = n
        if self.right is not None:
            self.right.parent = self
        if p is not None:
            if p.left is self:
                p.left = n
            else:
                p.right = n
        return n

    def rotate_right(self):
        """Rotate right around this node: its left child takes its place.

        Mirror image of rotate_left(); returns the new subtree root and does
        not update any external root reference.
        """
        assert self.left is not None
        p = self.parent
        n = self.left
        self.left, n.right = n.right, self
        n.parent = p
        self.parent = n
        if self.left is not None:
            self.left.parent = self
        if p is not None:
            if p.left is self:
                p.left = n
            else:
                p.right = n
        return n

    def __repr__(self):
        tail = '(R)' if self.is_red else '(B)'
        return str(self.val) + tail

    def check_me(self):
        """Assert the red-black invariants of the subtree rooted here.

        Checks that no red node has a red child, that every child's parent
        pointer refers back to its parent, that every value in a left subtree
        is <= the node's value and every value in a right subtree is >= it
        (across the whole subtree, not just immediate children), and that
        both sides have equal black height.

        Returns the number of black nodes on any downward path from this node
        (including it) to a missing child.
        """
        return self._check_subtree(self._NO_BOUND, self._NO_BOUND)

    def _check_subtree(self, lo, hi):
        """Recursive worker for check_me; every value here must lie in [lo, hi]."""
        if lo is not self._NO_BOUND:
            assert lo <= self.val
        if hi is not self._NO_BOUND:
            assert self.val <= hi
        if self.is_red:
            # A red node may not have a red child.
            assert self.left_black() and self.right_black()
        black_left = 0
        if self.left is not None:
            assert self.left.parent is self
            black_left = self.left._check_subtree(lo, self.val)
        black_right = 0
        if self.right is not None:
            assert self.right.parent is self
            black_right = self.right._check_subtree(self.val, hi)
        assert black_left == black_right
        return black_left + (0 if self.is_red else 1)

    def print_me(self):
        """Print the subtree level by level, one line per depth.

        Each line is ``level [entries]``, where an entry is ``val(R)`` /
        ``val(B)`` or ``'*'`` for an empty slot.  Also asserts that every
        child's parent pointer refers back to its parent.
        """
        queue = [self]
        level = 0
        while True:
            next_queue = []
            children = 0
            for x in queue:
                if x is None:
                    # Keep slot positions aligned for the next level.
                    next_queue.extend((None, None))
                    continue
                for child in (x.left, x.right):
                    if child is not None:
                        children += 1
                        assert child.parent is x
                    next_queue.append(child)
            print(level, [str(e) if e is not None else '*' for e in queue])
            level += 1
            if children == 0:
                break
            queue = next_queue
class RBTree(object):
    """A set of mutually comparable values stored in a red-black tree.

    insert() rejects duplicates; remove() deletes a value if present.
    """

    def __init__(self):
        self._root = None

    def contains(self, val):
        """Return True if val is stored in the tree."""
        node, _ = self._find_node(val)
        return node is not None

    def __contains__(self, val):
        """Support ``val in tree``; same as contains()."""
        return self.contains(val)

    def insert(self, val):
        """Insert val and return True, or return False if it is already present."""
        node, parent = self._find_node(val)
        if node is not None:
            return False
        if parent is None:
            # Empty tree: the new node becomes the (black) root.
            self._root = RBNode(val)
            return True
        node = RBNode(val, is_red=True, parent=parent)
        if val < parent.val:
            parent.left = node
        else:
            parent.right = node
        self._fix_insertions(node)
        return True

    def remove(self, val):
        """Remove val and return True, or return False if it is not present."""
        node, _ = self._find_node(val)
        if node is None:
            return False
        # Reduce to deleting a node with at most one child: copy the in-order
        # successor's value (or the predecessor's, if there is no right
        # subtree) into node and delete that other node instead.
        if node.right is not None:
            victim = node.right
            while victim.left is not None:
                victim = victim.left
            node.val = victim.val
        elif node.left is not None:
            victim = node.left
            while victim.right is not None:
                victim = victim.right
            node.val = victim.val
        else:
            victim = node
        # Rebalance while victim is still linked in, then splice it out.
        self._fix_removals(victim)
        child = victim.left if victim.left is not None else victim.right
        parent = victim.parent
        if child is not None:
            child.parent = parent
        if parent is None:
            # victim was the root; its child (if any) becomes the new root.
            self._root = child
        elif parent.left is victim:
            parent.left = child
        else:
            parent.right = child
        return True

    def _find_node(self, val):
        """Search for val.

        Returns ``(node, parent)`` when found, or ``(None, last)`` where
        ``last`` is the node under which val would be attached (None for an
        empty tree).
        """
        node = self._root
        parent = None
        while node is not None:
            if val < node.val:
                parent, node = node, node.left
            elif val > node.val:
                parent, node = node, node.right
            else:
                return (node, parent)
        return (None, parent)

    def _rotate(self, pivot, leftward):
        """Rotate around pivot (left if leftward, else right), keeping _root in sync.

        Returns the node that took pivot's place.
        """
        top = pivot.rotate_left() if leftward else pivot.rotate_right()
        if top.parent is None:
            self._root = top
        return top

    def _fix_insertions(self, node):
        """Restore the red-black invariants after node was attached as a red leaf."""
        while node.parent is not None and node.grandparent is not None:
            parent = node.parent
            grand = node.grandparent
            if not (node.is_red and parent.is_red):
                break
            uncle = node.uncle
            if uncle is not None and uncle.is_red:
                # Red uncle: move the red up to the grandparent and retry there.
                grand.is_red = True
                parent.is_red = False
                uncle.is_red = False
                node = grand
            elif parent is grand.left and node is parent.right:
                # Inner (triangle) case: rotate it into the outer case; the old
                # parent is now the lower of the two red nodes.
                self._rotate(parent, True)
                node = parent
            elif parent is grand.right and node is parent.left:
                self._rotate(parent, False)
                node = parent
            else:
                # Outer (line) case: recolour, then rotate the grandparent away
                # from the red pair.
                parent.is_red = False
                grand.is_red = True
                if parent is grand.left:
                    node = self._rotate(grand, False)
                else:
                    node = self._rotate(grand, True)
        self._root.is_red = False

    def _fix_removals(self, node):
        """Rebalance before node is unlinked by remove().

        node is still in the tree and has at most one child.  Case numbers
        follow the Wikipedia red-black tree article linked at the top of the
        file.
        """
        if node.is_red:
            return  # removing a red node never changes black heights
        # Case 1 short-circuit: a red child will take node's place; paint it black.
        child = node.left if node.left is not None else node.right
        if child is not None and child.is_red:
            child.is_red = False
            return
        while node.parent is not None:
            # Case 2: red sibling -> rotate it above the parent so node gets a
            # black sibling.
            sib = node.sibling
            if sib is not None and sib.is_red:
                node.parent.is_red = True
                sib.is_red = False
                self._rotate(node.parent, node is node.parent.left)
            # Cases 3/4: black sibling whose children are both black.
            sib = node.sibling
            if sib is not None and sib.black() and sib.left_black() and sib.right_black():
                sib.set_red()
                if node.parent.black():
                    node = node.parent  # case 3: the black deficit moves up
                    continue
                node.parent.set_black()  # case 4: red parent absorbs it
                break
            # Case 5: black sibling with red near child and black far child ->
            # rotate the sibling so the red child ends up on the far side.
            if sib is not None and sib.black():
                if (sib is node.parent.right and sib.left is not None
                        and sib.left.red() and sib.right_black()):
                    sib.left.set_black()
                    sib.set_red()
                    self._rotate(sib, False)
                elif (sib is node.parent.left and sib.right is not None
                        and sib.right.red() and sib.left_black()):
                    sib.right.set_black()
                    sib.set_red()
                    self._rotate(sib, True)
            # Case 6: black sibling with red far child -> rotate the parent.
            sib = node.sibling
            parent = node.parent
            if sib is parent.right and sib.right is not None and sib.right.red():
                sib.is_red = parent.is_red
                sib.right.set_black()
                parent.set_black()
                self._rotate(parent, True)
            elif sib is parent.left and sib.left is not None and sib.left.red():
                sib.is_red = parent.is_red
                sib.left.set_black()
                parent.set_black()
                self._rotate(parent, False)
            break

    def print_me(self):
        """Print the tree level by level (``['*']`` for an empty tree)."""
        if self._root is None:
            print(['*'])
        else:
            self._root.print_me()

    def check_me(self):
        """Assert that the whole tree satisfies the red-black invariants."""
        if self._root is None:
            return
        assert self._root.parent is None
        assert not self._root.is_red
        self._root.check_me()
if __name__ == '__main__':
    # Randomised self-test: insert, then remove, a shuffled run of integers,
    # validating the tree after every operation.  The seed is printed so a
    # failing run can be replayed by passing it as the first argument.
    seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randrange(2 ** 32)
    print('# seed ', seed)
    rng = random.Random(seed)
    tree = RBTree()
    print(tree.contains(3))
    nums = list(range(1, 37))
    rng.shuffle(nums)
    for i in nums:
        print('# add ', i)
        assert tree.insert(i)
        tree.print_me()
        tree.check_me()
    rng.shuffle(nums)
    for i in nums:
        print('# rm ', i)
        assert tree.remove(i)
        tree.print_me()
        tree.check_me()
    # Every value was removed, so nothing should remain.
    assert not any(tree.contains(i) for i in nums)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment