Skip to content

Instantly share code, notes, and snippets.

@JosephCatrambone
Created July 11, 2022 23:51

Revisions

  1. JosephCatrambone created this gist Jul 11, 2022.
    47 changes: 47 additions & 0 deletions union_find.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,47 @@
    class UnionFind:
    def __init__(self):
    self.parents = dict()

    def add(self, node_id:int):
    self.parents[node_id] = node_id

    def union(self, node_a:int, node_b:int):
    min_parent = min(self.parents[node_a], self.parents[node_b])
    self.parents[node_a] = min_parent
    self.parents[node_b] = min_parent

    def peek(self, node_id:int) -> int:
    """
    Traverse the list all the way to the top.
    Returns None if the given node_id is not in the set.
    Does not mutate the parent set.
    Use find instead.
    """
    parent = node_id
    while True:
    if parent not in self.parents:
    return None
    # If this is a loop, return self.
    if parent == self.parents[parent]:
    return parent
    parent = self.parents[parent]

    def find(self, node_id:int) -> int:
    """
    Traverse the parent tree and return the ultimate parent index of the item.
    Example: 1 -> 2 -> 3, 6 -> 7, 3 -> 7, 8 -> 9
    find(2) == 7, find(3) == 7, find(6) == 7 find(8) == 9
    Mutates and parent tree as it is traversed.
    """
    ultimate_parent = self.peek(node_id)
    if ultimate_parent is None:
    return None

    parent = node_id
    while True:
    if parent == self.parents[parent]:
    return parent # Loop!
    next_step = self.parents[parent]
    self.parents[parent] = ultimate_parent
    parent = next_step