Skip to content

Instantly share code, notes, and snippets.

@xiabingquan
Created January 21, 2025 02:40
Show Gist options
  • Save xiabingquan/e385deb63bfdc2ab3dbbe7efe1bbf9cf to your computer and use it in GitHub Desktop.
Save xiabingquan/e385deb63bfdc2ab3dbbe7efe1bbf9cf to your computer and use it in GitHub Desktop.
A minimal (not fully tested) example of a Trie Tree
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
class TrieTree(object):
    """Char-level Trie Tree.

    Every instance is one node. The node created with ``parent=None`` is the
    root; it carries no character of its own and words hang below it.
    """

    def __init__(self, is_end: bool, val: str = '', parent: Optional[TrieTree] = None):
        """
        A Trie Tree
        Args:
            is_end: a flag to indicate whether the current node is an end of some string.
                This attribute is particularly useful when this tree contains a word A and another
                word with prefix A (e.g. "app" and "apple") during searching.
            val: the value of current node.
            parent: the parent of current node.
        """
        self.is_end = is_end
        self.val = val
        self.parent = parent
        self.children: List[TrieTree] = []

    def __repr__(self, layer_idx: int = 0):
        """Render the subtree, one node per line, indented two spaces per level.

        Each line is the node value followed by ``1``/``0`` for ``is_end``.
        """
        lines = [' ' * 2 * layer_idx + f"{self.val}{int(self.is_end)}"]
        lines.extend(child.__repr__(layer_idx + 1) for child in self.children)
        return '\n'.join(lines)

    @property
    def is_root(self) -> bool:
        """Whether current node is the root node"""
        return self.parent is None

    @property
    def is_leaf(self) -> bool:
        """Whether current node is a leaf node"""
        return len(self.children) == 0

    @property
    def is_dummy(self):
        """Whether this node was discarded by a deletion (a leaf whose value is None)"""
        return self.is_leaf and self.val is None

    def mark_as_dummy(self):
        """Mark current node as a dummy node: drop its value and its (dummy) children"""
        assert all(child.is_dummy for child in self.children), self.children
        self.val = None
        self.children.clear()

    def __getitem__(self, item):
        """Return the child (or slice of children) at position ``item``."""
        return self.children[item]

    def _check_word(self, word: List[str]) -> bool:
        """Return True if ``word`` is a non-empty list of strings."""
        # Test the type first so that a non-list argument yields False
        # instead of a TypeError from len().
        return isinstance(word, list) and len(word) > 0 and all(isinstance(w, str) for w in word)

    def _find_child(self, val: str) -> Optional[TrieTree]:
        """Return the first child whose value equals ``val``, or None."""
        return next((child for child in self.children if child.val == val), None)

    def insert(self, word: List[str]) -> None:
        """
        Insert a word into the Trie Tree below the current node
        Args:
            word: e.g. ['a', 'p', 'p', 'l', 'e']
        Returns:
        """
        assert self._check_word(word)
        node = self
        for char in word:
            child = node._find_child(char)
            if child is None:  # no common prefix any more, grow a new branch
                child = TrieTree(is_end=False, val=char, parent=node)
                node.children.append(child)
            node = child
        # The node of the last char ends the word, whether it is new or already existed
        node.is_end = True

    def search(self, word: List[str], force_end: bool = True) -> Optional[TrieTree]:
        """
        Search whether the given word is in the Trie Tree
        When called on the root, the word is matched starting from the root's children.
        When called on a non-root node, the first char must equal the value of that node.
        The input list is never modified.
        Args:
            word: e.g. ['a', 'p', 'p', 'l', 'e']
            force_end: whether to force the returned node to be a end node. Useful for searching prefixes
        Returns: the end node of the input word, will be None if the word is not found.
        """
        assert self._check_word(word)
        remaining = word
        if not self.is_root:
            if self.val != word[0]:  # search failed
                return None
            remaining = word[1:]  # the first letter is matched by the current node itself
        node = self
        for char in remaining:
            node = node._find_child(char)
            if node is None:  # no common prefix or no children, searching fails
                return None
        return node if (node.is_end or not force_end) else None

    def __recursive_delete(self, word: List[str]):
        """
        Prune the current node, then its ancestors, bottom to up, as long as they became useless.
        A node is useless when it has no children left and is not the end of any word.
        The root node is never removed.
        Args:
            word: the chars on the path from the root to the current node. Only used for a
                sanity check; it is not modified.
        Returns: always True
        """
        assert isinstance(word, list) and all(isinstance(w, str) for w in word)
        # Drop dummy children that may have been left behind (build a new list, never
        # remove from a list while iterating over it)
        self.children = [c for c in self.children if not c.is_dummy]
        if self.is_root or self.is_end or not self.is_leaf:
            # Still needed: it is the root, ends another word, or leads to other words
            return True
        assert not word or word[-1] == self.val
        parent = self.parent
        parent.children.remove(self)
        self.mark_as_dummy()  # anyone still holding this node can see it was discarded
        return parent.__recursive_delete(word[:-1])

    def delete(self, word: List[str]) -> bool:
        """
        Delete a word from the Trie Tree. Other words sharing a prefix with it are kept.
        Args:
            word: e.g. ['a', 'p', 'p', 'l', 'e']. The list is not modified.
        Returns: False if the word is not in the tree, otherwise True (also for an empty word).
        """
        if len(word) == 0:
            return True
        assert self._check_word(word)
        end_node = self.search(word, True)
        if end_node is None:  # fail to find the corresponding end node
            return False
        # Unmark first: the word may be a prefix of longer words, in which case the node
        # must stay but must no longer count as the end of a word
        end_node.is_end = False
        return end_node.__recursive_delete(list(word))

    def gather_postfixs(self) -> List[List[str]]:
        """
        Gather every word ending in the subtree of the current node.
        Returns: a list of char lists, each starting with the value of the current node.
            If the current node itself ends a word and has children, ``[self.val]`` is
            included as the last entry.
        """
        # Base case: current node is a leaf node
        if self.is_leaf:
            return [[self.val]]
        # Recurse in: append postfixes of children to self.val
        postfixes = [[self.val] + tail for child in self.children for tail in child.gather_postfixs()]
        # An inner node that ends a word (e.g. "app" next to "apple") reports itself
        if self.is_end:
            postfixes.append([self.val])
        return postfixes
class PrefixMatcher(object):
    """Look up stored words by prefix, backed by a trie built from ``words``."""

    def __init__(self, words: List[str]):
        """Build the underlying trie from the given words."""
        self.trie = build_trietree(words)

    def __repr__(self):
        """Show the trie between two separator lines."""
        return f'trie: \n-----\n{self.trie}\n-----'

    def match(self, prefix: str) -> Optional[List[str]]:
        """
        Collect the words reported below the node reached by ``prefix``.
        Args:
            prefix: the leading characters to look up.
        Returns: None if no node matches the prefix, otherwise the completed words.
        """
        hit = self.trie.search(list(prefix), force_end=False)
        if hit is None:
            return None
        # Every gathered postfix starts with the last char of the prefix,
        # so that char is cut from the stem before gluing them together.
        stem = prefix[:-1]
        return [stem + ''.join(tail) for tail in hit.gather_postfixs()]
def build_trietree(words: List[str]) -> TrieTree:
    """
    Create a trie and insert every given word, char by char.
    Args:
        words: the strings to store.
    Returns: the root node of the new trie.
    """
    trie_root = TrieTree(is_end=False)
    for text in words:
        chars = list(text)
        trie_root.insert(chars)
    return trie_root
def test():
    """Smoke test: print search results for stored and missing words, then search a prefix and delete a word."""
    stored = 'aaa aac bsd rofnd rod aaadd aaade'.split(' ')
    node = build_trietree(stored)

    # Words that were inserted
    for word in stored:
        print(f"haved: {word=} -> {node.search(list(word))}")

    # Words that were never inserted
    absent = "aa aad aaaddd rod1".split(' ')
    for word in absent:
        print(f"missing: {word=} -> {node.search(list(word))}")

    # Prefix lookup and deletion; the results are not inspected here
    node.search(list('aa'), force_end=False)
    node.delete(list('aaadd'))
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()
@xiabingquan
Copy link
Author

Some improvements suggested by AI chatbots, mainly focused on replacing recursion with for loops (not tested):

def search(self, word: List[str], force_end: bool = True) -> Optional[TrieTree]:
    """
    Walk down from the current node, following one child per char of ``word``.
    Args:
        word: e.g. ['a', 'p', 'p', 'l', 'e']
        force_end: whether the node reached must be the end of a word.
    Returns: the node reached, or None if a char has no matching child or the
        node is not an end node while ``force_end`` is set.
    """
    node = self
    for char in word:
        for candidate in node.children:
            if candidate.val == char:
                node = candidate
                break
        else:
            # No child carries this char
            return None
    if force_end and not node.is_end:
        return None
    return node


def search(self, word: List[str], force_end: bool = True) -> Optional[TrieTree]:
    """
    Follow ``word`` char by char through the children, starting at the current node.
    Args:
        word: e.g. ['a', 'p', 'p', 'l', 'e']; checked with ``self._check_word``.
        force_end: whether the node reached must be the end of a word.
    Returns: the node reached, or None if the path does not exist or the node
        is not an end node while ``force_end`` is set.
    """
    assert self._check_word(word)
    current = self
    for char in word:
        matching = [kid for kid in current.children if kid.val == char]
        if not matching:
            return None
        current = matching[0]
    if force_end and not current.is_end:
        return None
    return current


def __recursive_delete(self, word: List[str]):
    """
    Delete bottom to up: turn childless nodes into dummies and climb to the parent.
    While chars remain and the node is a leaf (or has only dummy children), the last
    char is popped from ``word`` (in place), the node is marked as dummy and the parent
    continues. Otherwise the dummy children are dropped and the climb stops.
    NOTE(review): ``is_end`` is never consulted or changed here.
    Args:
        word: the chars still to be consumed; modified in place.
    Returns: True, or whatever the parent's call returns.
    """
    if word and (self.is_leaf or all(c.is_dummy for c in self.children)):
        word.pop()
        self.mark_as_dummy()
        if self.parent:
            return self.parent.__recursive_delete(word)
        return True

    # Either the word is used up or this node still has occupied children:
    # keep the node and only discard dummy children.
    self.children = [c for c in self.children if not c.is_dummy]
    return True


def __recursive_delete(self, word: List[str]):
    """
    Delete ``word`` top down, starting at the current node (normally the root).
    The node reached by the last char is unmarked as a word end. On the way back up,
    a child is detached only if it has no children left and does not end another word.
    Args:
        word: the chars still to be matched below the current node; not modified.
    Returns: True if the word was stored and has been deleted, otherwise False.
    """
    assert isinstance(word, list) and all(isinstance(w, str) for w in word)
    if not word:
        if not self.is_end:
            return False
        # Unmark, otherwise a word that is a prefix of a longer word could never be deleted
        self.is_end = False
        return True
    # Slice instead of pop(0) so that the caller's list stays intact
    letter, rest = word[0], word[1:]
    for child in self.children:
        if child.val == letter and child.__recursive_delete(rest):
            # Keep a childless node that still ends a shorter word (e.g. "app" vs "apple")
            if child.is_leaf and not child.is_end:
                self.children.remove(child)
            return True
    return False


def gather_postfixs(self) -> List[List[str]]:
    """
    Gather the words ending below the current node.
    Returns: a list of char lists, each starting with the value of the current node.
        A leaf returns just its own value; a root without children returns an empty list.
    """
    # Base case: a leaf reports itself, an empty root has nothing to report
    if self.is_leaf:
        return [[self.val]] if not self.is_root else []
    postfixes = []
    for child in self.children:
        for suffix in child.gather_postfixs():
            postfixes.append([self.val] + suffix)
        # A leaf child has already reported itself in the loop above; only an inner
        # child that ends a word (e.g. "app" next to "apple") needs an extra entry.
        if child.is_end and not child.is_leaf:
            postfixes.append([self.val] + [child.val])
    return postfixes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment