Created
January 21, 2025 02:40
-
-
Save xiabingquan/e385deb63bfdc2ab3dbbe7efe1bbf9cf to your computer and use it in GitHub Desktop.
A minimal (not fully tested) example of a trie (prefix tree)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
class TrieTree(object):
    """Character-level trie (prefix tree).

    Each node stores a single character in ``val``. The path from the root to a
    node spells a prefix, and ``is_end`` marks nodes at which an inserted word
    ends. The root node has ``val == ''`` and ``parent is None``.
    Traversals in insert/search/delete use loops rather than recursion.
    """

    def __init__(self, is_end: bool, val: str = '', parent: Optional[TrieTree] = None):
        """
        A Trie Tree

        Args:
            is_end: whether the current node is the end of some inserted word.
                Needed when the tree contains a word A and another word with
                prefix A (e.g. "app" and "apple"), so A can still be found.
            val: the character stored in the current node ('' for the root).
            parent: the parent of the current node (None for the root).
        """
        self.is_end = is_end
        self.val = val
        self.parent = parent
        self.children: List[TrieTree] = []

    def __repr__(self, layer_idx: int = 0) -> str:
        """Render the subtree one node per line as ``<val><0|1 is_end>``,
        indented by two spaces per depth level."""
        c = ' ' * 2 * layer_idx + f"{self.val}{int(self.is_end)}"
        if self.children:
            c += '\n'
            c += '\n'.join(child.__repr__(layer_idx + 1) for child in self.children)
        return c

    @property
    def is_root(self) -> bool:
        """Whether the current node is the root node (it has no parent)."""
        return self.parent is None

    @property
    def is_leaf(self) -> bool:
        """Whether the current node has no children."""
        return len(self.children) == 0

    @property
    def is_dummy(self) -> bool:
        """Whether the node is a childless placeholder whose ``val`` was cleared
        by ``mark_as_dummy``."""
        return self.is_leaf and self.val is None

    def mark_as_dummy(self) -> None:
        """Turn the current node into a dummy: clear its value and drop its
        (already dummy) children. Kept for backward compatibility; ``delete``
        prunes nodes directly and does not use it."""
        assert all(child.is_dummy for child in self.children), self.children
        self.val = None
        self.children.clear()

    def __getitem__(self, item):
        """Return the child at index ``item`` (children are kept in insertion order)."""
        return self.children[item]

    def _check_word(self, word: List[str]) -> bool:
        """Whether ``word`` is a non-empty list of strings.

        The type check runs first so that non-sized inputs (e.g. None) fail the
        check instead of raising TypeError from ``len``.
        """
        return isinstance(word, list) and len(word) > 0 and all(isinstance(w, str) for w in word)

    def _find_child(self, val: str) -> Optional[TrieTree]:
        """Return the direct child holding ``val``, or None if there is none."""
        return next((child for child in self.children if child.val == val), None)

    def insert(self, word: List[str]) -> None:
        """
        Insert a word below the current node.

        Args:
            word: e.g. ['a', 'p', 'p', 'l', 'e']; ``word[0]`` becomes (or reuses)
                a child of the current node. Inserting an existing word is a no-op
                apart from (re)marking its last node as an end node.
        """
        assert self._check_word(word)
        node = self
        for ch in word:
            child = node._find_child(ch)
            if child is None:
                # New intermediate nodes do not end a word unless marked below.
                child = TrieTree(is_end=False, val=ch, parent=node)
                node.children.append(child)
            node = child
        node.is_end = True

    def search(self, word: List[str], force_end: bool = True) -> Optional[TrieTree]:
        """
        Search whether the given word is in the Trie Tree.

        When called on the root, ``word[0]`` is matched against the root's
        children. When called on an inner node, ``word[0]`` must equal that
        node's own ``val``. The input list is not modified.

        Args:
            word: e.g. ['a', 'p', 'p', 'l', 'e']
            force_end: whether the returned node must be an end node. Pass False
                to look up prefixes.
        Returns:
            The node of the last character of ``word``, or None if it is not found.
        """
        assert self._check_word(word)
        node = self
        remaining = word
        if not self.is_root:
            if self.val != word[0]:
                return None
            remaining = word[1:]
        for ch in remaining:
            node = node._find_child(ch)
            if node is None:
                return None
        if force_end and not node.is_end:
            return None
        return node

    def delete(self, word: List[str]) -> bool:
        """
        Remove a previously inserted word.

        Only the word itself is removed: other words sharing its prefix, and
        words that are prefixes of it, stay searchable.

        Args:
            word: e.g. ['a', 'p', 'p', 'l', 'e']. An empty list is a no-op.
        Returns:
            True if the word was removed (or ``word`` is empty), False if the
            word was not found.
        """
        if len(word) == 0:
            return True
        assert self._check_word(word)
        node = self.search(word, True)
        if node is None:
            return False
        node.is_end = False
        # Prune bottom-up: a childless node that ends no word is not needed by any
        # stored word. Stop at the first node that is still needed, and never
        # prune the node this method was called on.
        while node is not self and node.is_leaf and not node.is_end:
            parent = node.parent
            parent.children.remove(node)
            node = parent
        return True

    def gather_postfixs(self) -> List[List[str]]:
        """
        Collect every stored word ending at or below the current node.

        Each result starts with the current node's own ``val`` followed by the
        characters down to an end node. For example, the node for "ap" in a tree
        holding "ap" and "apple" yields [['p'], ['p', 'p', 'l', 'e']].
        Results are in pre-order (a word comes before its extensions).
        """
        postfixes = [[self.val]] if self.is_end else []
        for child in self.children:
            postfixes.extend([self.val] + p for p in child.gather_postfixs())
        return postfixes
class PrefixMatcher(object):
    """Look up all stored words that start with a given prefix, backed by a TrieTree."""

    def __init__(self, words: List[str]):
        """Build the underlying trie from ``words`` (each word must be non-empty)."""
        self.trie = build_trietree(words)

    def __repr__(self):
        return f'trie: \n-----\n{self.trie}\n-----'

    def match(self, prefix: str) -> Optional[List[str]]:
        """
        Return every stored word that starts with ``prefix``.

        An empty prefix matches every stored word.

        Returns:
            The matching words, or None if no stored word starts with ``prefix``.
        """
        if not prefix:
            # The root's value is '', so joining its postfixes yields whole words.
            # Stored words are never empty, so any '' entry is not a real word.
            everything = [''.join(p) for p in self.trie.gather_postfixs()]
            everything = [w for w in everything if w]
            return everything or None
        node = self.trie.search(list(prefix), force_end=False)
        if node is None:
            return None
        # Each postfix starts with the matched node's own char, i.e. prefix[-1].
        head = prefix[:-1]
        return [head + ''.join(p) for p in node.gather_postfixs()]
def build_trietree(words: List[str]) -> TrieTree:
    """Create a fresh trie holding every string in ``words``.

    Each string is split into single characters before insertion, so each must
    be non-empty (``TrieTree.insert`` rejects empty input).
    """
    trie = TrieTree(is_end=False)  # root node: empty value, no parent
    for text in words:
        trie.insert([*text])
    return trie
def test():
    """Smoke-test the trie and the prefix matcher.

    Builds a small trie, looks words and a prefix up, deletes one word and runs a
    prefix match. Results are printed for inspection, and the key expectations
    are asserted so that a regression fails loudly.
    """
    words = 'aaa aac bsd rofnd rod aaadd aaade'.split(' ')
    node = build_trietree(words)
    for word in words:
        found = node.search(list(word))
        print(f"present: {word=} -> {found}")
        assert found is not None, word
    missing_words = "aa aad aaaddd rod1".split(' ')
    for word in missing_words:
        found = node.search(list(word))
        print(f"missing: {word=} -> {found}")
        assert found is None, word
    # 'aa' is only a prefix of stored words, so it is found only when no end flag is required.
    prefix_node = node.search(list('aa'), force_end=False)
    print(f"prefix: word='aa' -> {prefix_node}")
    assert prefix_node is not None
    deleted = node.delete(list('aaadd'))
    print(f"deleted 'aaadd': {deleted}")
    assert deleted
    assert node.search(list('aaadd')) is None
    # A sibling word sharing the prefix 'aaad' must survive the deletion.
    assert node.search(list('aaade')) is not None
    print(f"match: prefix='aa' -> {PrefixMatcher(words).match('aa')}")


if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some improvements suggested by AI chatbots, mainly replacing recursion with for loops (not tested)