#!/usr/bin/env python
"""A red-black tree over mutually comparable values, with a self-checking demo.

Running the module as a script inserts and then removes a shuffled range of
integers, printing the tree and validating its invariants after every step.
"""

import sys

# "RBT" here means red-black tree (not 仙踪林, an unrelated brand sharing the acronym).
# https://en.wikipedia.org/wiki/Red%E2%80%93black_tree


class RBNode:
    """A single node of a red-black tree.

    A node stores its value, its colour (``is_red``; the node is black when it
    is False) and links to its left child, right child and parent. A missing
    child is ``None`` and is treated as a black leaf.
    """

    def __init__(self, val, is_red=False, left=None, right=None, parent=None):
        self.val = val
        self.is_red = is_red
        self.left = left
        self.right = right
        self.parent = parent

    def red(self):
        """Return True if this node is red."""
        return self.is_red

    def black(self):
        """Return True if this node is black."""
        return not self.is_red

    def left_black(self):
        """Return True if the left child is black (a missing child counts as black)."""
        return self.left is None or self.left.black()

    def right_black(self):
        """Return True if the right child is black (a missing child counts as black)."""
        return self.right is None or self.right.black()

    def set_red(self):
        """Colour this node red."""
        self.is_red = True

    def set_black(self):
        """Colour this node black."""
        self.is_red = False

    @property
    def grandparent(self):
        """The parent's parent, or None if this node has no parent."""
        return self.parent.parent if self.parent is not None else None

    @property
    def uncle(self):
        """The parent's sibling, or None if there is no grandparent."""
        gp = self.grandparent
        if gp is None:
            return None
        return gp.right if self.parent is gp.left else gp.left

    @property
    def sibling(self):
        """The other child of this node's parent, or None if there is no parent."""
        p = self.parent
        if p is None:
            return None
        return p.left if self is p.right else p.right

    def rotate_left(self):
        """Rotate left around this node and return the node that replaces it.

        The right child ``n`` takes this node's place, this node becomes
        ``n``'s left child, and ``n``'s former left subtree becomes this
        node's right subtree. Parent links and the old parent's child pointer
        are updated. The owning tree's root pointer is NOT updated: callers
        must do that when the returned node has no parent.
        """
        p = self.parent
        assert self.right is not None
        n = self.right
        self.right = n.left
        n.left = self
        n.parent = p
        self.parent = n
        if self.right is not None:
            self.right.parent = self
        if p is not None:
            if p.left is self:
                p.left = n
            else:
                p.right = n
        return n

    def rotate_right(self):
        """Rotate right around this node and return the node that replaces it.

        Mirror image of :meth:`rotate_left`. The left child takes this node's
        place, and the owning tree's root pointer is likewise left to the
        caller.
        """
        p = self.parent
        assert self.left is not None
        n = self.left
        self.left = n.right
        n.right = self
        n.parent = p
        self.parent = n
        if self.left is not None:
            self.left.parent = self
        if p is not None:
            if p.left is self:
                p.left = n
            else:
                p.right = n
        return n

    def __repr__(self):
        tail = '(R)' if self.is_red else '(B)'
        return str(self.val) + tail

    def check_me(self):
        """Assert the red-black invariants for the subtree rooted here.

        Checks the following:
        - a red node has no red child;
        - each child's ``parent`` link points back to this node;
        - ordering holds against the direct children only;
        - the left and right subtrees have equal black height.

        Returns the black height of this subtree: this node counts 1 if it
        is black, and missing children count 0.
        """
        if self.is_red:
            assert self.left is None or not self.left.is_red
            assert self.right is None or not self.right.is_red
        left_height = 0
        if self.left is not None:
            assert self.left.parent is self
            assert self.left.val <= self.val
            left_height = self.left.check_me()
        right_height = 0
        if self.right is not None:
            assert self.right.parent is self
            assert self.right.val >= self.val
            right_height = self.right.check_me()
        assert left_height == right_height
        return left_height + (0 if self.is_red else 1)

    def print_me(self):
        """Print the subtree level by level, one ``(level, [nodes])`` line per level.

        Missing children are shown as ``'*'`` and keep their slots, so each
        level lists twice as many entries as the one above. Printing stops
        after the first level whose nodes have no children. Each child's
        ``parent`` link is asserted along the way.
        """
        level_nodes = [self]
        level = 0
        while True:
            next_nodes = []
            has_children = False
            for x in level_nodes:
                if x is None:
                    next_nodes.extend((None, None))
                    continue
                for ch in (x.left, x.right):
                    if ch is not None:
                        has_children = True
                        assert ch.parent is x
                    next_nodes.append(ch)
            print(level, [str(e) if e is not None else '*' for e in level_nodes])
            if not has_children:
                break
            level_nodes = next_nodes
            level += 1


class RBTree:
    """A set of distinct, mutually comparable values stored in a red-black tree."""

    def __init__(self):
        self._root = None

    def contains(self, val):
        """Return True if ``val`` is stored in the tree."""
        n, _ = self._find_node(val)
        return n is not None

    def insert(self, val):
        """Insert ``val``.

        Returns True if it was added, or False if it was already present
        (duplicates are rejected).
        """
        n, p = self._find_node(val)
        if n is not None:
            return False
        if p is None:
            # Empty tree: the new node becomes the (black) root.
            self._root = RBNode(val)
            return True
        n = RBNode(val, is_red=True, parent=p)
        if val < p.val:
            p.left = n
        else:
            p.right = n
        self._fix_insertions(n)
        return True

    def remove(self, val):
        """Remove ``val``.

        Returns True if it was present and removed, or False otherwise.
        """
        n, _ = self._find_node(val)
        if n is None:
            return False
        # Choose the node ``t`` to unlink physically. It is the in-order
        # successor if ``n`` has a right subtree, else the in-order predecessor
        # if it has a left subtree, else ``n`` itself. ``t``'s value is copied
        # into ``n``, and ``t`` has at most one child.
        if n.right is not None:
            t = n.right
            while t.left is not None:
                t = t.left
            n.val = t.val
        elif n.left is not None:
            t = n.left
            while t.right is not None:
                t = t.right
            n.val = t.val
        else:
            t = n
        # Rebalance while ``t`` is still linked into the tree.
        self._fix_removals(t)
        # Splice ``t`` out, replacing it with its only child (if any).
        child = t.left if t.left is not None else t.right
        if child is not None:
            child.parent = t.parent
        if t.parent is None:
            self._root = child
        elif t.parent.left is t:
            t.parent.left = child
        else:
            t.parent.right = child
        return True

    def _find_node(self, val):
        """Search for ``val`` and return a ``(node, parent)`` pair.

        If ``val`` is found, ``node`` is its node and ``parent`` is that
        node's parent (None for the root). Otherwise ``node`` is None and
        ``parent`` is the last node visited, i.e. where ``val`` would be
        attached. ``parent`` is None when the tree is empty.
        """
        n = self._root
        p = None
        while n is not None:
            if val < n.val:
                p, n = n, n.left
            elif val > n.val:
                p, n = n, n.right
            else:
                return (n, p)
        return (None, p)

    def _fix_insertions(self, n):
        """Restore the red-black invariants after attaching the red leaf ``n``."""
        while n.parent is not None and n.grandparent is not None:
            if not (n.is_red and n.parent.is_red):
                break
            uncle = n.uncle
            if uncle is not None and uncle.is_red:
                # Red uncle: move the grandparent's blackness down one level,
                # then continue checking from the grandparent.
                n.grandparent.is_red = True
                n.parent.is_red = False
                uncle.is_red = False
                n = n.grandparent
                continue
            parent = n.parent
            gp = n.grandparent
            if parent is gp.left and n is parent.right:
                # Inner child: rotate it to the outside, then continue from the
                # old parent, which is now the outer red child.
                parent.rotate_left()
                n = parent
            elif parent is gp.right and n is parent.left:
                parent.rotate_right()
                n = parent
            else:
                # Outer child: rotate the grandparent away from ``n`` and swap
                # colours, so the old parent becomes a black subtree root.
                parent.is_red = False
                gp.is_red = True
                if parent is gp.left:
                    n = gp.rotate_right()
                else:
                    n = gp.rotate_left()
                if n.parent is None:
                    self._root = n
        # The root is always black.
        self._root.is_red = False

    def _fix_removals(self, n):
        """Rebalance before the caller unlinks ``n``, a node with at most one child.

        Must be called while ``n`` is still attached. This method only
        recolours and rotates; splicing ``n`` out is left to the caller.
        """
        if n.is_red:
            # Removing a red node does not change any black height.
            return
        # Removing a black node.
        # Case 1 (short circuit): a red child takes n's place and turns black.
        child = n.left if n.left is not None else n.right
        if child is not None and child.is_red:
            child.is_red = False
            return
        # Otherwise the path through ``n`` is one black short; walk up until
        # the deficit is absorbed.
        while n.parent is not None:
            # Case 2: red sibling. Rotate it above the parent so n's new
            # sibling is black.
            s = n.sibling
            if s is not None and s.is_red:
                n.parent.is_red = True
                s.is_red = False
                if n is n.parent.left:
                    top = n.parent.rotate_left()
                else:
                    top = n.parent.rotate_right()
                if top.parent is None:
                    self._root = top
            # Cases 3/4: black sibling with two black children. Redden it.
            s = n.sibling
            if s is not None and s.black() and s.left_black() and s.right_black():
                s.set_red()
                if n.parent.black():
                    # Case 3: the parent is now short; retry one level up.
                    n = n.parent
                    continue
                # Case 4: blackening the red parent restores the balance.
                n.parent.set_black()
                break
            # Case 5: black sibling whose near child is red and far child black.
            # Rotate the sibling so the far child of the new sibling is red.
            if s is not None and s.black():
                if s is n.parent.right and s.left is not None and s.left.red() and s.right_black():
                    s.left.set_black()
                    s.set_red()
                    s.rotate_right()
                elif s is n.parent.left and s.right is not None and s.right.red() and s.left_black():
                    s.right.set_black()
                    s.set_red()
                    s.rotate_left()
            # Case 6: black sibling with a red far child. The sibling takes the
            # parent's colour, and the parent and far child become black. Then
            # rotate the parent toward ``n``.
            s = n.sibling
            if s is n.parent.right and (s.right is not None and s.right.red()):
                s.is_red = n.parent.is_red
                s.right.set_black()
                n.parent.set_black()
                top = n.parent.rotate_left()
                if top.parent is None:
                    self._root = top
            elif s is n.parent.left and (s.left is not None and s.left.red()):
                s.is_red = n.parent.is_red
                s.left.set_black()
                n.parent.set_black()
                top = n.parent.rotate_right()
                if top.parent is None:
                    self._root = top
            break

    def print_me(self):
        """Print the tree level by level (``['*']`` when empty)."""
        if self._root is None:
            print(['*'])
        else:
            self._root.print_me()

    def check_me(self):
        """Assert the red-black invariants for the whole tree (no-op when empty)."""
        if self._root is None:
            return
        assert not self._root.is_red
        assert self._root.parent is None
        self._root.check_me()


if __name__ == '__main__':
    import random

    tree = RBTree()
    print(tree.contains(3))
    nums = list(range(1, 37))
    random.shuffle(nums)
    for i in nums:
        print('# add ', i)
        assert tree.insert(i)
        tree.print_me()
        tree.check_me()
    random.shuffle(nums)
    for i in nums:
        print('# rm ', i)
        assert tree.remove(i)
        tree.print_me()
        tree.check_me()