Last active
June 5, 2022 21:08
-
-
Save LiutongZhou/f67bf1a1546a996531e51ccb81898d5b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""UnionFind (Disjoint Sets)"""
# Typing helpers used in the annotations of the UnionFind class below.
from typing import Optional, Iterable, Hashable, Any
class UnionFind:
    """Disjoint-set (union-find) structure over hashable items.

    Sets are merged by size (the smaller set is attached under the root of
    the larger one) and ``find`` compresses the path it walks.
    """

    def __init__(
        self, initial_disjoint_items: Optional[Iterable[Hashable]] = None
    ):
        """Initialize a UnionFind of disjoint sets.

        Args:
            initial_disjoint_items: Items that each start in their own
                singleton set. Equal items collapse into a single entry.
                ``None`` (the default) creates an empty structure.
        """
        # Test against None instead of truthiness so that iterables whose
        # truth value is undefined or misleading are still accepted.
        items = () if initial_disjoint_items is None else initial_disjoint_items
        # Maps each element to its parent; a root is its own parent.
        self.parent = {u: u for u in items}
        # Maps each root to the size of its set; non-roots have no entry.
        self._size = dict.fromkeys(self.parent, 1)
        # Number of disjoint sets currently tracked.
        self.num_sets = len(self.parent)

    def add(self, u: Hashable) -> None:
        """Add an isolated item as a new disjoint set.

        Does nothing if ``u`` is already present.
        """
        if u in self.parent:
            return
        self.parent[u] = u
        self._size[u] = 1
        self.num_sets += 1

    def find(self, u: Hashable) -> Any:
        """Return the root of the set that element ``u`` belongs to.

        Every node visited on the way to the root is re-pointed directly at
        the root (path compression).

        Raises:
            AssertionError: If ``u`` has not been added. When assertions are
                disabled (``python -O``) the lookup raises ``KeyError``
                instead.
        """
        parent = self.parent
        assert u in parent, f"{u} has not been added yet"

        # Walk up until a node that is its own parent is reached.
        node = u
        while (root := parent[node]) != node:
            node = root

        # Path compression. The next hop is read *before* the current node is
        # re-pointed; doing both in one tuple assignment would rebind the
        # cursor first and write the root onto the wrong node.
        node = u
        while node != root:
            next_node = parent[node]
            parent[node] = root
            node = next_node
        return root

    def union(self, u: Hashable, v: Hashable) -> None:
        """Merge the sets containing ``u`` and ``v`` if they are distinct.

        The root of the smaller set is attached under the root of the larger
        one; on a tie, ``v``'s root is attached under ``u``'s root.

        Raises:
            AssertionError: If ``u`` or ``v`` has not been added (see
                ``find``).
        """
        root_u, root_v = self.find(u), self.find(v)
        if root_u == root_v:  # already in the same set
            return

        size = self._size
        if size[root_u] < size[root_v]:
            # Make root_u always denote the root that survives the merge.
            root_u, root_v = root_v, root_u
        self.parent[root_v] = root_u
        # The absorbed root stops being a root, so its size entry is removed.
        size[root_u] += size.pop(root_v)
        self.num_sets -= 1

    def is_connected(self, u: Hashable, v: Hashable) -> bool:
        """Return True if ``u`` and ``v`` are in the same set, else False."""
        return self.find(u) == self.find(v)

    def get_set_size(self, u: Hashable) -> int:
        """Return the size of the disjoint set that ``u`` belongs to."""
        return self._size[self.find(u)]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment