Skip to content

Instantly share code, notes, and snippets.

@EssamWisam
Created April 1, 2025 21:26
Show Gist options
  • Save EssamWisam/80cc9cef6ba9616cc8b94e45b34de79b to your computer and use it in GitHub Desktop.
Save EssamWisam/80cc9cef6ba9616cc8b94e45b34de79b to your computer and use it in GitHub Desktop.
Trees
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In-order traversal: [20, 30, 40, 50, 60, 70, 80]\n",
"All tests passed!\n"
]
}
],
"source": [
class Node:
    """A binary search tree node: a stored value plus optional left and right children."""

    def __init__(self, value):
        self.value = value
        # Child links start empty; None marks a missing subtree.
        self.left, self.right = None, None
"\n",
"\n",
class BinarySearchTree:
    """Unbalanced binary search tree supporting search, insert, delete and traversal.

    Values must be mutually comparable with ``<``. Inserting a value that is
    already present leaves the tree unchanged.
    """

    def __init__(self):
        self.root = None

    def search(self, value):
        """Search for a value in the BST.

        Args:
            value: The value to search for.

        Returns:
            tuple: ``(node, parent)`` if value is found, ``(None, potential_parent)``
            otherwise. ``potential_parent`` is the node under which ``value``
            would be attached (None for an empty tree). ``parent`` is None when
            the match is the root.
        """
        current = self.root
        parent = None
        while current is not None and current.value != value:
            parent = current
            current = current.left if value < current.value else current.right
        return current, parent

    def insert(self, value):
        """Insert a value into the BST; duplicates are ignored.

        Args:
            value: The value to insert.
        """
        # Case 1: empty tree.
        if self.root is None:
            self.root = Node(value)
            return

        node, parent = self.search(value)

        # Case 2: value already exists, do nothing.
        if node is not None:
            return

        # Case 3: attach as a new leaf under the node where the search fell off.
        if value < parent.value:
            parent.left = Node(value)
        else:
            parent.right = Node(value)

    def _replace_child(self, parent, old, new):
        """Put ``new`` where ``old`` hangs under ``parent`` (or at the root if parent is None)."""
        if parent is None:
            self.root = new
        elif parent.left is old:
            parent.left = new
        else:
            parent.right = new

    def delete(self, value):
        """Delete a value from the BST.

        Args:
            value: The value to delete.

        Returns:
            bool: True if deletion succeeded, False if value not found.
        """
        node, parent = self.search(value)
        if node is None:
            return False

        if node.left is None or node.right is None:
            # Zero or one child: splice the (possibly empty) child into node's place.
            child = node.left if node.left is not None else node.right
            self._replace_child(parent, node, child)
            return True

        # Two children: the in-order successor (leftmost node of the right
        # subtree) takes node's place. By construction it has no left child.
        successor_parent = node
        successor = node.right
        while successor.left is not None:
            successor_parent = successor
            successor = successor.left

        if successor_parent is not node:
            # Detach the successor by promoting its right subtree, then let it
            # adopt node's whole right subtree.
            successor_parent.left = successor.right
            successor.right = node.right
        successor.left = node.left
        self._replace_child(parent, node, successor)
        return True

    def inorder_traversal(self):
        """Perform an in-order traversal and return the values.

        Uses an explicit stack instead of recursion. This tree is not
        self-balancing, so inserting already-sorted values builds a chain whose
        depth equals the number of nodes. A recursive walk would then exceed
        Python's recursion limit.

        Returns:
            list: Values in sorted order.
        """
        result = []
        pending = []
        node = self.root
        while pending or node is not None:
            # Descend as far left as possible, remembering the path back up.
            while node is not None:
                pending.append(node)
                node = node.left
            node = pending.pop()
            result.append(node.value)
            node = node.right
        return result
"\n",
"\n",
"# Test the implementation\n",
def test_bst():
    """Exercise BinarySearchTree insert/search/delete, checking exact contents after each step."""
    bst = BinarySearchTree()

    # Test insertion
    values = [50, 30, 70, 20, 40, 60, 80]
    for val in values:
        bst.insert(val)

    # Test in-order traversal (should return sorted values)
    print("In-order traversal:", bst.inorder_traversal())
    assert bst.inorder_traversal() == sorted(values)

    # Re-inserting an existing value must not create a duplicate.
    bst.insert(40)
    assert bst.inorder_traversal() == sorted(values)

    # Test search
    found_node, _ = bst.search(40)
    assert found_node is not None and found_node.value == 40

    not_found_node, _ = bst.search(45)
    assert not_found_node is None

    # Test deletion. Check the full remaining contents, not just the absence of
    # the deleted value, so a delete that drops a whole subtree is caught.
    # Case 1: Delete a leaf node (20)
    assert bst.delete(20) is True
    assert bst.inorder_traversal() == [30, 40, 50, 60, 70, 80]

    # Case 2: Delete a node with one child (30, whose only child is now 40)
    assert bst.delete(30) is True
    assert bst.inorder_traversal() == [40, 50, 60, 70, 80]

    # Case 3: Delete a node with two children (50, the root)
    assert bst.delete(50) is True
    assert bst.inorder_traversal() == [40, 60, 70, 80]

    # Delete a non-existent value
    assert bst.delete(100) is False

    print("All tests passed!")
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" test_bst()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AVL in-order traversal after insertions: [20, 30, 40, 50, 60, 70, 80]\n",
"All AVL tests passed!\n"
]
}
],
"source": [
class AVLNode(Node):
    """BST node that also caches the height of the subtree rooted at it."""

    def __init__(self, value):
        super().__init__(value)
        # A brand-new node is a leaf; a leaf's subtree has height 1.
        self.height = 1
"\n",
"\n",
class AVLTree(BinarySearchTree):
    """Self-balancing binary search tree (AVL tree) built from AVLNode objects.

    Each node caches the height of its subtree. After every insertion or
    deletion, the heights along the affected path are recomputed. Rotations
    then keep every node's balance factor (left height minus right height)
    within [-1, 1]. Duplicate values are ignored. ``search`` and
    ``inorder_traversal`` come from BinarySearchTree.
    """

    def __init__(self):
        # Let the base class initialise its state (an empty root); the nodes
        # this class creates are AVLNode instances.
        super().__init__()

    def get_height(self, node):
        """Return the cached height of ``node``'s subtree (0 for an empty subtree)."""
        if not node:
            return 0
        return node.height

    def update_height(self, node):
        """Recompute ``node.height`` from its children's cached heights."""
        node.height = 1 + max(self.get_height(node.left), self.get_height(node.right))

    def get_balance(self, node):
        """Return ``node``'s balance factor: left height minus right height (0 if empty)."""
        if not node:
            return 0
        return self.get_height(node.left) - self.get_height(node.right)

    def rotate_right(self, y):
        """Rotate the subtree rooted at ``y`` to the right and return its new root."""
        # Right rotation:
        #         y                x
        #        / \              / \
        #       x   T3   ==>    T1   y
        #      / \                  / \
        #    T1   T2              T2   T3
        x = y.left
        T2 = x.right

        # Perform rotation
        x.right = y
        y.left = T2

        # y is now x's child, so its height must be refreshed first.
        self.update_height(y)
        self.update_height(x)

        return x

    def rotate_left(self, x):
        """Rotate the subtree rooted at ``x`` to the left and return its new root."""
        # Left rotation:
        #       x                    y
        #      / \                  / \
        #    T1   y      ==>       x   T3
        #        / \              / \
        #      T2   T3          T1   T2
        y = x.right
        T2 = y.left

        # Perform rotation
        y.left = x
        x.right = T2

        # x is now y's child, so its height must be refreshed first.
        self.update_height(x)
        self.update_height(y)

        return y

    def rebalance(self, node):
        """Rotate at ``node`` if its balance factor is outside [-1, 1].

        Called by ``_insert``/``_delete`` right after ``node.height`` has been
        updated. Returns the root of the (possibly rotated) subtree.
        """
        balance = self.get_balance(node)
        # Left heavy subtree
        if balance > 1:
            # Left-Left case. Using ">= 0" (not "> 0") also covers the
            # deletion-only situation where the left child is itself perfectly
            # balanced; a single rotation is enough there.
            if self.get_balance(node.left) >= 0:
                return self.rotate_right(node)
            # Left-Right case: left rotate the left child, then right rotate the node.
            node.left = self.rotate_left(node.left)
            return self.rotate_right(node)
        # Right heavy subtree
        if balance < -1:
            # Right-Right case (mirror of Left-Left, including the "<= 0" reasoning).
            if self.get_balance(node.right) <= 0:
                return self.rotate_left(node)
            # Right-Left case: right rotate the right child, then left rotate the node.
            node.right = self.rotate_right(node.right)
            return self.rotate_left(node)
        # Node is balanced
        return node

    def _insert(self, node, value):
        """Insert ``value`` below ``node`` and return the rebalanced subtree root."""
        # Base case: found the spot for the new node.
        if not node:
            return AVLNode(value)

        # BST insertion
        if value < node.value:
            node.left = self._insert(node.left, value)
        elif value > node.value:
            node.right = self._insert(node.right, value)
        else:
            # Duplicate values are not inserted.
            return node

        # Update height of this ancestor node, then rebalance if needed.
        self.update_height(node)
        return self.rebalance(node)

    def insert(self, value):
        """Insert value into AVL tree (duplicates are ignored)."""
        self.root = self._insert(self.root, value)

    def _min_value_node(self, node):
        """Return the node with the smallest value in ``node``'s subtree.

        This is the leftmost node; it has no left child but may have a right child.
        """
        current = node
        while current.left:
            current = current.left
        return current

    def _delete(self, node, value):
        """Delete ``value`` from ``node``'s subtree and return the rebalanced subtree root."""
        if not node:
            return node

        # Standard BST deletion.
        if value < node.value:
            node.left = self._delete(node.left, value)
        elif value > node.value:
            node.right = self._delete(node.right, value)
        else:
            if not node.left or not node.right:
                # Zero or one child: the child (or None) replaces this node.
                node = node.left if node.left else node.right
            else:
                # Two children: copy the in-order successor's value here, then
                # delete the successor (which has no left child) from the right subtree.
                successor = self._min_value_node(node.right)
                node.value = successor.value
                node.right = self._delete(node.right, successor.value)

        # The removed node was a leaf, so there is nothing left to fix up here.
        if not node:
            return node

        # Update the height of the current node, then rebalance if needed.
        self.update_height(node)
        return self.rebalance(node)

    def delete(self, value):
        """Delete a value from the AVL tree.

        Returns:
            bool: True if the value was found and removed, False otherwise.
        """
        # Use search inherited from BST to check if value exists
        if self.search(value)[0] is None:
            return False
        self.root = self._delete(self.root, value)
        return True
"\n",
def is_balanced(node, avl):
    """Return True if the subtree rooted at ``node`` is a valid AVL subtree.

    Heights are recomputed from the actual structure rather than read from the
    cached ``height`` attributes. Each node's cached height (as reported by
    ``avl.get_height``) must also equal the recomputed one. Checking only
    ``avl.get_balance`` would trust the cached heights, so a tree with stale
    heights could be reported as balanced when it is not.

    Args:
        node: Root of the subtree to check (None for an empty subtree).
        avl: Object providing ``get_height(node)``, normally the owning AVLTree.

    Returns:
        bool: True if every node's true balance factor is in [-1, 1] and its
        cached height is correct.
    """
    def _checked_height(n):
        # True height of n's subtree, or None as soon as a violation is found.
        if n is None:
            return 0
        left = _checked_height(n.left)
        if left is None:
            return None
        right = _checked_height(n.right)
        if right is None:
            return None
        height = 1 + max(left, right)
        if abs(left - right) > 1 or avl.get_height(n) != height:
            return None
        return height

    return _checked_height(node) is not None
"\n",
"# Test the AVL implementation with balance property checks.\n",
def test_avl():
    """Exercise AVLTree insert/search/delete, checking balance and exact contents after each step."""
    avl = AVLTree()
    values = [50, 30, 70, 20, 40, 60, 80]
    for val in values:
        avl.insert(val)
        # Check balance property after each insertion
        assert is_balanced(avl.root, avl), f"Tree became unbalanced after inserting {val}"

    print("AVL in-order traversal after insertions:", avl.inorder_traversal())
    assert avl.inorder_traversal() == sorted(values)

    # Re-inserting an existing value must not create a duplicate.
    avl.insert(40)
    assert avl.inorder_traversal() == sorted(values)
    assert is_balanced(avl.root, avl), "Tree became unbalanced after re-inserting 40"

    # Test search using inherited BST search
    found_node, _ = avl.search(40)
    assert found_node is not None and found_node.value == 40
    missing, _ = avl.search(45)
    assert missing is None

    # Delete a leaf (20), a node with one child (30) and a node with two
    # children (50). Check the exact remaining contents, not just the absence
    # of the deleted value, plus the balance property after each step.
    deletions = [
        (20, [30, 40, 50, 60, 70, 80]),
        (30, [40, 50, 60, 70, 80]),
        (50, [40, 60, 70, 80]),
    ]
    for doomed, expected in deletions:
        assert avl.delete(doomed) is True
        assert avl.inorder_traversal() == expected
        assert is_balanced(avl.root, avl), f"Tree became unbalanced after deleting {doomed}"

    # Attempt to delete non-existent value
    assert avl.delete(100) is False

    print("All AVL tests passed!")
"\n",
"if __name__ == \"__main__\":\n",
" test_avl()\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RB in-order traversal after insertions: [20, 30, 40, 50, 60, 70, 80]\n",
"All Red-Black Tree tests passed!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"<>:74: SyntaxWarning: invalid escape sequence '\\ '\n",
"<>:101: SyntaxWarning: invalid escape sequence '\\ '\n",
"<>:74: SyntaxWarning: invalid escape sequence '\\ '\n",
"<>:101: SyntaxWarning: invalid escape sequence '\\ '\n",
"/var/folders/_4/h_zdrstn11z5jggdm7rz92p00000gn/T/ipykernel_16061/124133109.py:74: SyntaxWarning: invalid escape sequence '\\ '\n",
" \"\"\"\n",
"/var/folders/_4/h_zdrstn11z5jggdm7rz92p00000gn/T/ipykernel_16061/124133109.py:101: SyntaxWarning: invalid escape sequence '\\ '\n",
" \"\"\"\n"
]
}
],
"source": [
class Node:
    """Plain binary-tree node holding a value and links to its two children."""

    def __init__(self, value):
        self.left = self.right = None  # no children yet
        self.value = value
"\n",
"\n",
class BinarySearchTree:
    """Basic Binary Search Tree implementation (search and in-order traversal).

    Both operations walk the tree iteratively, so they also work on degenerate
    (linked-list shaped) trees deeper than Python's recursion limit.
    """

    def __init__(self):
        self.root = None

    def inorder_traversal(self):
        """Return list of values from in-order traversal."""
        result = []
        self._inorder_traversal(self.root, result)
        return result

    def _inorder_traversal(self, node, result):
        """Append the values of ``node``'s subtree to ``result`` in sorted order."""
        stack = []
        while stack or node:
            # Go as far left as possible, remembering the path back up.
            while node:
                stack.append(node)
                node = node.left
            node = stack.pop()
            result.append(node.value)
            node = node.right

    def search(self, value):
        """Search for value, return (node, parent) tuple.

        ``node`` is None when the value is absent; ``parent`` is then the last
        node visited (None for an empty tree or when the root matches).
        """
        return self._search(self.root, value, None)

    def _search(self, node, value, parent):
        """Walk down from ``node`` (whose parent is ``parent``) looking for ``value``."""
        while node and node.value != value:
            parent = node
            node = node.left if value < node.value else node.right
        return node, parent
"\n",
"\n",
class RBNode(Node):
    """Red-black tree node: a BST node plus a color flag and a parent link."""

    # Colors are stored as booleans.
    RED = True
    BLACK = False

    def __init__(self, value):
        super().__init__(value)
        self.parent = None        # back-pointer used by rotations and fix-ups
        self.color = RBNode.RED   # freshly created nodes start out red
"\n",
"\n",
class RedBlackTree(BinarySearchTree):
    """Red-Black Tree implementation inheriting from BinarySearchTree.

    Empty child links point at one shared black sentinel, ``self.NIL``; an
    empty tree has ``root`` set to None. Inserting a value that is already
    present leaves the tree unchanged.
    """

    def __init__(self):
        super().__init__()  # root starts as None; real nodes are RBNode instances
        # Shared sentinel used for every empty child link. It is black, so
        # leaves count as black when the red-black properties are checked.
        self.NIL = RBNode(None)
        self.NIL.color = RBNode.BLACK
        self.NIL.left = None
        self.NIL.right = None

    def inorder_traversal(self):
        """Return list of values from in-order traversal."""
        result = []
        self._inorder_traversal(self.root, result)
        return result

    def _inorder_traversal(self, node, result):
        """Append the values of ``node``'s subtree to ``result`` in order, skipping NIL.

        Recursion depth is bounded by the tree height, which a red-black tree
        keeps logarithmic in the number of nodes.
        """
        if node and node is not self.NIL:
            self._inorder_traversal(node.left, result)
            result.append(node.value)
            self._inorder_traversal(node.right, result)

    def rotate_left(self, x):
        r"""Rotate the subtree rooted at ``x`` to the left, fixing parent links.

        Left rotation:
                x                  y
               / \                / \
             T1   y      ==>     x   T3
                 / \            / \
               T2   T3        T1   T2

        If ``x`` was the root, ``self.root`` becomes ``y``.
        """
        y = x.right
        x.right = y.left
        if y.left is not self.NIL:
            y.left.parent = x

        y.parent = x.parent
        if x.parent is None:
            self.root = y
        elif x is x.parent.left:
            x.parent.left = y
        else:
            x.parent.right = y

        y.left = x
        x.parent = y

    def rotate_right(self, y):
        r"""Rotate the subtree rooted at ``y`` to the right, fixing parent links.

        Right rotation:
                  y              x
                 / \            / \
                x   T3   ==>  T1   y
               / \                / \
             T1   T2            T2   T3

        If ``y`` was the root, ``self.root`` becomes ``x``.
        """
        x = y.left
        y.left = x.right
        if x.right is not self.NIL:
            x.right.parent = y

        x.parent = y.parent
        if y.parent is None:
            self.root = x
        elif y is y.parent.left:
            y.parent.left = x
        else:
            y.parent.right = x

        x.right = y
        y.parent = x

    def insert(self, value):
        """Insert ``value`` into the tree and restore the red-black properties.

        Duplicate values are ignored.
        """
        # Find the attachment point (y) for the new node.
        y = None
        x = self.root
        while x is not None and x is not self.NIL:
            if value == x.value:
                return  # Duplicate: leave the tree unchanged.
            y = x
            x = x.left if value < x.value else x.right

        node = RBNode(value)
        node.left = self.NIL
        node.right = self.NIL
        node.parent = y

        # If tree is empty, new node becomes the root, which is always black.
        if y is None:
            self.root = node
            node.color = RBNode.BLACK
            return

        if value < y.value:
            y.left = node
        else:
            y.right = node

        # A child of the (black) root cannot create a red-red violation.
        if y.parent is None:
            return

        self._fix_insert(node)

    def _fix_insert(self, k):
        """Restore the red-black properties after inserting the red node ``k``.

        The only possible violation is a red ``k`` under a red parent. A red
        uncle is handled by recoloring and moving up; a black uncle by one or
        two rotations.
        """
        while k is not self.root and k.parent and k.parent.color == RBNode.RED:
            # Parent is left child of grandparent
            if k.parent is k.parent.parent.left:
                uncle = k.parent.parent.right

                # Case 1: Uncle is red -> recolor and continue from the grandparent.
                if uncle and uncle.color == RBNode.RED:
                    k.parent.color = RBNode.BLACK
                    uncle.color = RBNode.BLACK
                    k.parent.parent.color = RBNode.RED
                    k = k.parent.parent
                else:
                    # Case 2: k is a right child -> rotate it into Case 3's shape.
                    if k is k.parent.right:
                        k = k.parent
                        self.rotate_left(k)

                    # Case 3: k is a left child -> recolor and rotate the grandparent.
                    if k.parent:
                        k.parent.color = RBNode.BLACK
                        if k.parent.parent:
                            k.parent.parent.color = RBNode.RED
                            self.rotate_right(k.parent.parent)
            # Parent is right child of grandparent (mirror image)
            else:
                uncle = k.parent.parent.left

                # Case 1: Uncle is red
                if uncle and uncle.color == RBNode.RED:
                    k.parent.color = RBNode.BLACK
                    uncle.color = RBNode.BLACK
                    k.parent.parent.color = RBNode.RED
                    k = k.parent.parent
                else:
                    # Case 2: k is a left child
                    if k is k.parent.left:
                        k = k.parent
                        self.rotate_right(k)

                    # Case 3: k is a right child
                    if k.parent:
                        k.parent.color = RBNode.BLACK
                        if k.parent.parent:
                            k.parent.parent.color = RBNode.RED
                            self.rotate_left(k.parent.parent)

        # Ensure root is black
        self.root.color = RBNode.BLACK

    def _transplant(self, u, v):
        """Replace subtree rooted at u with subtree rooted at v.

        ``v`` may be the NIL sentinel; its parent pointer is set anyway so that
        ``_fix_delete`` can walk up from it.
        """
        if u.parent is None:
            self.root = v
        elif u is u.parent.left:
            u.parent.left = v
        else:
            u.parent.right = v

        if v is not None:
            v.parent = u.parent

    def _min_value_node(self, node):
        """Get the node with minimum value in subtree rooted at node."""
        current = node
        while current.left is not self.NIL and current.left is not None:
            current = current.left
        return current

    def delete(self, value):
        """Delete a node with the given value.

        Returns:
            bool: True if the value was found and removed, False otherwise.
        """
        node, _ = self.search(value)
        if node is None or node is self.NIL:
            return False

        self._delete_node(node)
        if self.root is self.NIL:
            # Removing the last node transplants the sentinel into the root;
            # normalize back to the "empty tree" representation.
            self.root = None
        return True

    def _delete_node(self, z):
        """Unlink node ``z`` (textbook RB-DELETE scheme) and repair colors if needed.

        ``y`` is the node actually removed from its position (``z`` itself or its
        in-order successor); ``x`` is the node that moves into ``y``'s old spot.
        """
        y = z
        y_original_color = y.color

        if z.left is self.NIL or z.left is None:
            x = z.right if z.right is not None else self.NIL
            self._transplant(z, x)
        elif z.right is self.NIL or z.right is None:
            x = z.left if z.left is not None else self.NIL
            self._transplant(z, x)
        else:
            # Two children: the successor y (minimum of the right subtree) takes z's place.
            y = self._min_value_node(z.right)
            y_original_color = y.color
            x = y.right if y.right is not None else self.NIL

            if y.parent is z:
                x.parent = y  # needed even when x is NIL, for _fix_delete
            else:
                self._transplant(y, x)
                y.right = z.right
                y.right.parent = y

            self._transplant(z, y)
            y.left = z.left
            y.left.parent = y
            y.color = z.color

        # Removing a black node shortens some paths' black count; repair from x.
        if y_original_color == RBNode.BLACK:
            self._fix_delete(x)

    def _fix_delete(self, x):
        """Fix Red-Black Tree properties after deletion, starting at the "doubly black" node ``x``."""
        while x is not self.root and x and x.color == RBNode.BLACK:
            if x is x.parent.left:
                w = x.parent.right

                # Case 1: Sibling is red -> rotate so the sibling becomes black.
                if w.color == RBNode.RED:
                    w.color = RBNode.BLACK
                    x.parent.color = RBNode.RED
                    self.rotate_left(x.parent)
                    w = x.parent.right

                # Case 2: Both of sibling's children are black -> recolor and move up.
                if (w.left is None or w.left.color == RBNode.BLACK) and \
                        (w.right is None or w.right.color == RBNode.BLACK):
                    w.color = RBNode.RED
                    x = x.parent
                else:
                    # Case 3: Sibling's right child is black -> rotate into Case 4.
                    if w.right is None or w.right.color == RBNode.BLACK:
                        if w.left:
                            w.left.color = RBNode.BLACK
                        w.color = RBNode.RED
                        self.rotate_right(w)
                        w = x.parent.right

                    # Case 4: Sibling's right child is red -> final rotation.
                    w.color = x.parent.color
                    x.parent.color = RBNode.BLACK
                    if w.right:
                        w.right.color = RBNode.BLACK
                    self.rotate_left(x.parent)
                    x = self.root
            else:
                w = x.parent.left

                # Case 1: Sibling is red
                if w.color == RBNode.RED:
                    w.color = RBNode.BLACK
                    x.parent.color = RBNode.RED
                    self.rotate_right(x.parent)
                    w = x.parent.left

                # Case 2: Both of sibling's children are black
                if (w.left is None or w.left.color == RBNode.BLACK) and \
                        (w.right is None or w.right.color == RBNode.BLACK):
                    w.color = RBNode.RED
                    x = x.parent
                else:
                    # Case 3: Sibling's left child is black
                    if w.left is None or w.left.color == RBNode.BLACK:
                        if w.right:
                            w.right.color = RBNode.BLACK
                        w.color = RBNode.RED
                        self.rotate_left(w)
                        w = x.parent.left

                    # Case 4: Sibling's left child is red
                    w.color = x.parent.color
                    x.parent.color = RBNode.BLACK
                    if w.left:
                        w.left.color = RBNode.BLACK
                    self.rotate_right(x.parent)
                    x = self.root

        if x:
            x.color = RBNode.BLACK

    def search(self, value):
        """Search for value, return (node, parent) tuple.

        ``node`` is None when the value is absent (never the NIL sentinel), as in
        the base class. ``parent`` is the last real node visited (None for an
        empty tree or when the root matches).
        """
        if not self.root:
            return None, None

        node, parent = self._search(self.root, value, None)
        if node is self.NIL:
            node = None
        return node, parent

    def _search(self, node, value, parent):
        """Helper for search; returns the NIL sentinel (not None) on a miss."""
        if not node or node is self.NIL or node.value == value:
            return node, parent
        if value < node.value:
            return self._search(node.left, value, node)
        return self._search(node.right, value, node)
"\n",
"\n",
def is_valid_rb_tree(node, rb_tree, black_count=-1, path_black_count=0):
    """Check the red-black properties of the subtree rooted at ``node``.

    Checked properties:
      1. Every node is either red or black.
      2. The tree's root is black.
      3. Every leaf (NIL / None) is black.
      4. If a node is red, both its children are black.
      5. All paths from ``node`` down to a leaf have the same number of black nodes.

    Args:
        node: Subtree root to check (an RBNode, ``rb_tree.NIL`` or None).
        rb_tree: The RedBlackTree owning ``node`` (provides NIL and root).
        black_count: Expected number of black nodes on every path to a leaf,
            or -1 if no leaf has been reached yet (the first leaf then sets it).
        path_black_count: Black nodes already seen on the path above ``node``.

    Returns:
        tuple: ``(black_count, is_valid)``. When valid, ``black_count`` is the
        black-node count shared by every path, including ``path_black_count``;
        when invalid it is -1.
    """
    # Leaf (NIL or None, black by definition): compare this path's count with
    # the one established by the first leaf reached.
    if node is None or node is rb_tree.NIL:
        if black_count == -1:
            return path_black_count, True
        if black_count != path_black_count:
            print(f"Black-height mismatch: expected {black_count}, found {path_black_count}")
            return -1, False
        return black_count, True

    # Property 1: a node's color must be one of the two defined colors.
    if node.color not in (RBNode.RED, RBNode.BLACK):
        print(f"Node {node.value} has invalid color {node.color!r}")
        return -1, False

    # Property 2: the root of the tree must be black.
    if node is rb_tree.root and node.color != RBNode.BLACK:
        print(f"Root {node.value} is red")
        return -1, False

    # Property 4: red nodes have black children.
    if node.color == RBNode.RED:
        for child in (node.left, node.right):
            if child and child is not rb_tree.NIL and child.color == RBNode.RED:
                print(f"Red node {node.value} has red child")
                return -1, False

    # Count black nodes along path
    current_black = path_black_count + (1 if node.color == RBNode.BLACK else 0)

    # The left subtree establishes (or confirms) the expected count...
    left_count, left_valid = is_valid_rb_tree(node.left, rb_tree, black_count, current_black)
    if not left_valid:
        return -1, False

    # ...and every leaf of the right subtree must match it.
    _, right_valid = is_valid_rb_tree(node.right, rb_tree, left_count, current_black)
    if not right_valid:
        return -1, False

    return left_count, True
"\n",
"\n",
def test_rb_tree():
    """Exercise RedBlackTree insert/search/delete, checking RB properties and exact contents."""
    rb = RedBlackTree()
    values = [50, 30, 70, 20, 40, 60, 80]

    # Test insertion
    for val in values:
        rb.insert(val)
        # Check RB properties after each insertion
        _, is_valid = is_valid_rb_tree(rb.root, rb)
        assert is_valid, f"Tree became invalid after inserting {val}"
        assert rb.root.color == RBNode.BLACK, "Root should always be black"

    print("RB in-order traversal after insertions:", rb.inorder_traversal())
    assert rb.inorder_traversal() == sorted(values)

    # Test search
    found_node, _ = rb.search(40)
    assert found_node is not None and found_node.value == 40
    # A miss must not yield a real node (None, or the NIL sentinel, depending on version).
    missing, _ = rb.search(45)
    assert missing is None or missing is rb.NIL

    # Delete a leaf (20), a node with one child (30) and a node with two
    # children (50, the root). After each step check the exact remaining
    # contents (not just absence of the deleted value) and the RB properties.
    deletions = [
        (20, [30, 40, 50, 60, 70, 80]),
        (30, [40, 50, 60, 70, 80]),
        (50, [40, 60, 70, 80]),
    ]
    for doomed, expected in deletions:
        assert rb.delete(doomed) is True
        assert rb.inorder_traversal() == expected
        _, is_valid = is_valid_rb_tree(rb.root, rb)
        assert is_valid, f"Tree became invalid after deleting {doomed}"
        assert rb.root.color == RBNode.BLACK, "Root should always be black"

    # Attempt to delete non-existent value
    assert rb.delete(100) is False

    print("All Red-Black Tree tests passed!")
"\n",
"if __name__ == \"__main__\":\n",
" test_rb_tree()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "m1",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@EssamWisam
Copy link
Author

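The "invalid escape sequence" messages are only SyntaxWarnings. They come from the backslashes in the ASCII tree diagrams inside the rotation docstrings, so they can be ignored. Writing those docstrings as raw strings (r"""...""") silences them.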

@EssamWisam
Copy link
Author

The "query" operation is called `search` in the implementation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment