#!/usr/bin/env python

import sys

# RBT here means red-black tree, not 仙踪林 (the "RBT" tea-house chain).
# https://en.wikipedia.org/wiki/Red%E2%80%93black_tree

class RBNode(object):
  """A node of a red-black tree.

  Each node holds a value, a colour flag and links to its left child, right
  child and parent. ``None`` children stand for the black NIL leaves of the
  textbook algorithm, so helpers such as ``left_black`` treat a missing child
  as black.
  """

  def __init__(self, val, is_red=False, left=None, right=None, parent=None):
    self.val = val
    self.is_red = is_red  # False means the node is black
    self.left = left
    self.right = right
    self.parent = parent

  # -- colour helpers ------------------------------------------------------

  def red(self):
    """Return True if this node is red."""
    return self.is_red

  def black(self):
    """Return True if this node is black."""
    return not self.is_red

  def left_black(self):
    """Return True if the left child is black or missing (a NIL leaf)."""
    return self.left is None or self.left.black()

  def right_black(self):
    """Return True if the right child is black or missing (a NIL leaf)."""
    return self.right is None or self.right.black()

  def set_red(self):
    self.is_red = True

  def set_black(self):
    self.is_red = False

  # -- relatives -----------------------------------------------------------

  @property
  def grandparent(self):
    """The parent's parent, or None if there is no such node."""
    if self.parent is None:
      return None
    return self.parent.parent

  @property
  def uncle(self):
    """The grandparent's other child, or None if there is no grandparent."""
    gp = self.grandparent
    if gp is None:
      return None
    return gp.right if self.parent is gp.left else gp.left

  @property
  def sibling(self):
    """The parent's other child, or None if this node has no parent."""
    p = self.parent
    if p is None:
      return None
    return p.left if self is p.right else p.right

  # -- rotations -----------------------------------------------------------

  def rotate_left(self):
    """Rotate left around this node and return the node that replaces it.

    The right child moves up into this node's position (and is returned);
    this node becomes its left child, and the right child's former left
    subtree becomes this node's right subtree. Parent links and the former
    parent's child link are updated. The caller must update the tree root
    if the returned node has no parent.
    """
    assert self.right is not None
    p = self.parent
    n = self.right
    self.right, n.left = n.left, self
    if self.right is not None:
      self.right.parent = self
    n.parent = p
    self.parent = n
    if p is not None:
      if p.left is self:
        p.left = n
      else:
        p.right = n
    return n

  def rotate_right(self):
    """Rotate right around this node and return the node that replaces it.

    Mirror image of ``rotate_left``: the left child moves up (and is
    returned), this node becomes its right child, and the left child's
    former right subtree becomes this node's left subtree. The caller must
    update the tree root if the returned node has no parent.
    """
    assert self.left is not None
    p = self.parent
    n = self.left
    self.left, n.right = n.right, self
    if self.left is not None:
      self.left.parent = self
    n.parent = p
    self.parent = n
    if p is not None:
      if p.left is self:
        p.left = n
      else:
        p.right = n
    return n

  # -- debugging -----------------------------------------------------------

  def __repr__(self):
    tail = '(R)' if self.is_red else '(B)'
    return str(self.val) + tail

  def check_me(self, lo=None, hi=None):
    """Assert the red-black and search-tree invariants for this subtree.

    Checks that a red node has no red child, that both subtrees have the
    same black height, and that every value lies within the bounds inherited
    from its ancestors (left subtree <= node <= right subtree). Comparing
    against the inherited bounds rather than only the direct parent catches
    a value that sits on the wrong side of a grandparent or higher ancestor.

    Args:
      lo: if not None, every value in this subtree must be >= lo.
      hi: if not None, every value in this subtree must be <= hi.

    Returns:
      The black height of this subtree: the number of black nodes on each
      path from this node down to a missing child, counting this node.

    Raises:
      AssertionError: if an invariant is violated.
    """
    if lo is not None:
      assert self.val >= lo
    if hi is not None:
      assert self.val <= hi
    if self.is_red:
      assert self.left_black() and self.right_black()
    bl = self.left.check_me(lo, self.val) if self.left is not None else 0
    br = self.right.check_me(self.val, hi) if self.right is not None else 0
    assert bl == br
    return bl + (0 if self.is_red else 1)

  def print_me(self):
    """Print this subtree breadth-first, one line per level.

    Missing children are shown as '*'. Also asserts that every child's
    ``parent`` link points back at its parent.
    """
    q = [self]
    level = 0
    while True:
      new_q = []
      cnt = 0  # real (non-None) nodes on the next level
      for x in q:
        if x is not None:
          if x.left is not None:
            cnt += 1
            assert x.left.parent is x
          if x.right is not None:
            cnt += 1
            assert x.right.parent is x
          new_q.append(x.left)
          new_q.append(x.right)
        else:
          # Keep placeholder slots so positions line up across levels.
          new_q.append(None)
          new_q.append(None)
      print(level, [(str(e) if e is not None else '*') for e in q])
      level += 1
      if cnt == 0:
        break
      q = new_q

class RBTree(object):
  """A set of mutually comparable values stored in a red-black tree.

  Duplicate values are rejected by ``insert``. ``check_me`` asserts the
  red-black and search-tree invariants and is meant for debugging.
  """

  def __init__(self):
    self._root = None  # RBNode, or None for an empty tree

  def contains(self, val):
    """Return True if ``val`` is stored in the tree."""
    node, _ = self._find_node(val)
    return node is not None

  def insert(self, val):
    """Insert ``val``; return False if it was already present, else True."""
    node, parent = self._find_node(val)
    if node is not None:
      return False
    if parent is None:
      # Empty tree: the new node becomes the (black) root.
      self._root = RBNode(val)
      return True
    node = RBNode(val, is_red=True, parent=parent)
    if val < parent.val:
      parent.left = node
    else:
      parent.right = node
    self._fix_insertions(node)
    return True

  def remove(self, val):
    """Remove ``val``; return False if it was not present, else True.

    A node with children is not unlinked directly. Its value is overwritten
    with that of its in-order successor (or its predecessor when it has no
    right subtree), and that node, which has at most one child, is
    unlinked instead.
    """
    node, _ = self._find_node(val)
    if node is None:
      return False
    if node.right is not None:
      victim = node.right
      while victim.left is not None:
        victim = victim.left
    elif node.left is not None:
      victim = node.left
      while victim.right is not None:
        victim = victim.right
    else:
      victim = node
    node.val = victim.val

    # Rebalance while victim is still linked in, then splice it out. The
    # parent is read afterwards because the fix-up may rotate around it.
    self._fix_removals(victim)

    child = victim.left if victim.left is not None else victim.right
    parent = victim.parent
    if child is not None:
      child.parent = parent
    if parent is None:
      # victim was the root: its only child (if any) takes its place.
      self._root = child
    elif parent.left is victim:
      parent.left = child
    else:
      parent.right = child
    # Drop stale links so the detached node does not pin the tree.
    victim.parent = victim.left = victim.right = None
    return True

  def _find_node(self, val):
    """Search for ``val``.

    Returns:
      ``(node, parent)`` where ``node`` holds ``val`` (or None if absent)
      and ``parent`` is its parent; when ``val`` is absent, ``parent`` is
      the node under which it would be inserted (None for an empty tree).
    """
    n = self._root
    p = None
    while n is not None:
      if val < n.val:
        p = n
        n = n.left
      elif val > n.val:
        p = n
        n = n.right
      else:
        return (n, p)
    return (None, p)

  def _fix_insertions(self, n):
    """Restore the red-black invariants after ``n`` was added as a red leaf."""
    while n.parent is not None and n.grandparent is not None:
      parent, grand = n.parent, n.grandparent
      if not (n.is_red and parent.is_red):
        break  # no red-red violation left
      uncle = n.uncle
      if uncle is not None and uncle.is_red:
        # Red uncle: move the blackness down from the grandparent and
        # continue from there, since it may now clash with its own parent.
        grand.is_red = True
        parent.is_red = False
        uncle.is_red = False
        n = grand
      elif parent is grand.left and n is parent.right:
        # Inner grandchild: rotate it outward; the next iteration then
        # handles the outer case with n and parent swapped.
        parent.rotate_left()
        n = parent
      elif parent is grand.right and n is parent.left:
        parent.rotate_right()
        n = parent
      else:
        # Outer grandchild: rotate the grandparent away from n and swap
        # colours so the parent becomes the black root of this subtree.
        parent.is_red = False
        grand.is_red = True
        if parent is grand.left:
          top = grand.rotate_right()
        else:
          top = grand.rotate_left()
        if top.parent is None:
          self._root = top
        n = top
    # The loop may have painted the root red; the root is always black.
    self._root.is_red = False

  def _fix_removals(self, n):
    """Rebalance before node ``n`` (which has at most one child) is unlinked.

    Called by ``remove`` while ``n`` is still in the tree. Removing a red
    node never breaks the invariants. Removing a black node with a red
    child is fixed by painting that child black. Otherwise every path
    through ``n`` is about to lose one black node, and the loop repairs
    that deficit.
    """
    if n.is_red:
      return

    # 1 - short circuit: the red child will replace n, so make it black.
    child = n.left if n.left is not None else n.right
    if child is not None and child.is_red:
      child.is_red = False
      return

    while n.parent is not None:
      parent = n.parent
      sib = n.sibling

      # 2 - red sibling: rotate it above the parent so that n ends up with
      # a black sibling. n keeps the same parent.
      if sib is not None and sib.is_red:
        parent.is_red = True
        sib.is_red = False
        if n is parent.left:
          top = parent.rotate_left()
        else:
          top = parent.rotate_right()
        if top.parent is None:
          self._root = top
        sib = n.sibling

      if sib is None or sib.is_red:
        break  # only reachable if the invariants were already broken

      # 3/4 - black sibling with two black children: paint it red.
      if sib.left_black() and sib.right_black():
        sib.set_red()
        if parent.black():
          # 3 - the parent's subtree is now one black short; move up.
          n = parent
          continue
        # 4 - a red parent absorbs the deficit.
        parent.set_black()
        break

      # 5 - the sibling's only red child is on the inner side: rotate it to
      # the outer side so that case 6 applies.
      if sib is parent.right and not sib.left_black() and sib.right_black():
        sib.left.set_black()
        sib.set_red()
        sib.rotate_right()
      elif sib is parent.left and not sib.right_black() and sib.left_black():
        sib.right.set_black()
        sib.set_red()
        sib.rotate_left()

      # 6 - the sibling's outer child is red: rotate the parent towards n,
      # which adds a black node on n's side.
      sib = n.sibling
      top = None
      if sib is parent.right and not sib.right_black():
        sib.is_red = parent.is_red
        sib.right.set_black()
        parent.set_black()
        top = parent.rotate_left()
      elif sib is parent.left and not sib.left_black():
        sib.is_red = parent.is_red
        sib.left.set_black()
        parent.set_black()
        top = parent.rotate_right()
      if top is not None and top.parent is None:
        self._root = top
      break

  def print_me(self):
    """Print the tree level by level ('*' marks an empty slot)."""
    if self._root is None:
      print(['*'])
    else:
      self._root.print_me()

  def check_me(self):
    """Assert that the tree satisfies the red-black invariants.

    Raises:
      AssertionError: if an invariant is violated.
    """
    if self._root is None:
      return
    assert self._root.parent is None
    assert not self._root.is_red
    self._root.check_me()

if __name__ == '__main__':
  import random

  # Smoke test: insert and then remove a shuffled range of values, checking
  # the invariants after every step. The shuffles are seeded so that a
  # failing run can be replayed by passing the printed seed as the first
  # command-line argument.
  seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randrange(2 ** 32)
  print('# seed', seed)
  rng = random.Random(seed)

  tree = RBTree()
  print(tree.contains(3))
  nums = list(range(1, 37))
  rng.shuffle(nums)
  for i in nums:
    print('# add ', i)
    assert tree.insert(i)
    tree.print_me()
    tree.check_me()
  # Every value is now present, and re-inserting any of them is rejected.
  assert all(tree.contains(i) for i in nums)
  assert not any(tree.insert(i) for i in nums)

  rng.shuffle(nums)
  for i in nums:
    print('# rm ', i)
    assert tree.remove(i)
    assert not tree.contains(i)
    tree.print_me()
    tree.check_me()
  # The tree is empty again, so further removals report failure.
  assert not any(tree.remove(i) for i in nums)