Last active
February 10, 2023 15:55
-
-
Save LiutongZhou/850b4c160cb55c2c80f427d855c13078 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Data structures built on OrderedDict: OrderedDefaultDict and MinMaxCounter"""
from collections import Counter, OrderedDict
from typing import Any, Hashable, Optional, Tuple, List
# hypothesis is only used by the property-based test at the bottom of this file
from hypothesis import given, strategies as st
# Public API of this module
__all__ = ["OrderedDefaultDict", "MinMaxCounter"]
class OrderedDefaultDict(OrderedDict):
    """An OrderedDict that builds missing values with a factory, like defaultdict.

    The constructor mirrors ``collections.defaultdict``: an optional leading
    ``default_factory`` (a callable or ``None``) followed by the usual
    ``OrderedDict`` arguments. The factory may also be passed as the
    ``default_factory`` keyword argument.

    Examples
    --------
    >>> d = OrderedDefaultDict(list)
    >>> d["a"].append(1)
    >>> d.head()
    ('a', [1])
    """

    def __init__(self, *args, **kwargs):
        """Init an ordered default dict

        Parameters
        ----------
        *args
            Optionally a leading ``default_factory`` (callable or ``None``),
            then at most one mapping or iterable of ``(key, value)`` pairs.
        **kwargs
            Initial items. When no factory was given positionally, a
            ``default_factory`` keyword is consumed as the factory; a
            non-callable value for it is ignored.
        """
        self.default_factory = None
        if args and (args[0] is None or callable(args[0])):
            # defaultdict-style call: OrderedDefaultDict(factory, data, **kw)
            self.default_factory, args = args[0], args[1:]
        elif callable(default_factory := kwargs.pop("default_factory", None)):
            self.default_factory = default_factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key: Hashable) -> Any:
        """Insert ``default_factory()`` under a missing key and return it

        dict.__getitem__ calls this hook on subclasses, so ``self[key]``
        needs only one lookup when the key is present.

        Raises
        ------
        KeyError
            If ``default_factory`` is None.
        """
        if self.default_factory is None:
            raise KeyError(key)
        self[key] = value = self.default_factory()
        return value

    def copy(self) -> "OrderedDefaultDict":
        """Return a shallow copy that keeps ``default_factory`` and key order

        The inherited OrderedDict.copy rebuilds the object without
        constructor arguments, which would silently drop the factory.
        """
        return type(self)(self.default_factory, self)

    def _peek(self, last=True) -> Optional[Tuple[Hashable, Any]]:
        """Return tail if last == True else return head, or None when empty

        Time: O(1). Reads through an items view instead of popping and
        re-inserting, so peeking never mutates the dict and is safe while
        iterating over it.
        """
        if not self:
            return None
        items = self.items()
        return next(reversed(items) if last else iter(items))

    def head(self) -> Optional[Tuple[Hashable, Any]]:
        """Return the head of the OrderedDict as (key, value)"""
        return self._peek(last=False)

    def tail(self) -> Optional[Tuple[Hashable, Any]]:
        """Return the tail of the OrderedDict as (key, value)"""
        return self._peek(last=True)
class MinMaxCounter:
    """A Counter optimized for querying min / max keys by count frequency

    Time: O(1) for increment, decrement, min, max

    Notes
    -----
    ``counter`` maps item -> count (always >= 1). ``count_to_items`` is an
    OrderedDefaultDict(set) that maps count frequency -> the non-empty set of
    items with that count. Its first and last keys are always the min and
    max frequency.

    Every update moves one item by exactly one, so the occupied frequencies
    are also kept in a fully sorted doubly linked list (``_lower`` /
    ``_higher``). A newly occupied frequency is always spliced in right next
    to the frequency the item came from. When the min or max bucket empties,
    its neighbour is therefore the new min / max, available in O(1).
    """

    def __init__(self):
        """Init an empty MinMaxCounter"""
        self.counter = Counter()
        self.count_to_items = OrderedDefaultDict(set)
        # Sorted doubly linked list over occupied frequencies; None marks an end
        self._lower = {}  # freq -> next smaller occupied freq, or None
        self._higher = {}  # freq -> next larger occupied freq, or None
        self._min_freq = None
        self._max_freq = None

    def _link(self, freq, lower, higher):
        """Create an empty bucket for freq, spliced between neighbours lower and higher"""
        self.count_to_items[freq] = set()
        self._lower[freq] = lower
        self._higher[freq] = higher
        if lower is None:
            self._min_freq = freq
        else:
            self._higher[lower] = freq
        if higher is None:
            self._max_freq = freq
        else:
            self._lower[higher] = freq

    def _discard(self, item, freq):
        """Remove item from the freq bucket; unlink and delete the bucket if it empties"""
        bucket = self.count_to_items[freq]
        bucket.discard(item)
        if bucket:
            return
        del self.count_to_items[freq]
        lower = self._lower.pop(freq)
        higher = self._higher.pop(freq)
        if lower is None:
            self._min_freq = higher
        else:
            self._higher[lower] = higher
        if higher is None:
            self._max_freq = lower
        else:
            self._lower[higher] = lower

    def _sync_ends(self):
        """Move the min bucket to the front and the max bucket to the back of count_to_items"""
        if self._min_freq is not None:
            self.count_to_items.move_to_end(self._max_freq)
            self.count_to_items.move_to_end(self._min_freq, last=False)

    def increment(self, item: Hashable):
        """Increment the count for item"""
        freq_old = self.counter[item]  # Counter yields 0 for a new item
        freq_new = freq_old + 1
        if freq_new not in self.count_to_items:
            if freq_old:
                # freq_new belongs directly above freq_old
                self._link(freq_new, freq_old, self._higher[freq_old])
            else:
                # a new item has count 1, below every occupied frequency
                self._link(freq_new, None, self._min_freq)
        self.count_to_items[freq_new].add(item)
        self.counter[item] = freq_new
        if freq_old:
            self._discard(item, freq_old)
        self._sync_ends()

    def decrement(self, item: Hashable):
        """Decrement the count for item, dropping it once its count reaches 0

        Raises
        ------
        KeyError
            If item is not currently counted.
        """
        freq_old = self.counter[item]
        if freq_old <= 0:
            raise KeyError(item)
        freq_new = freq_old - 1
        if freq_new:
            if freq_new not in self.count_to_items:
                # freq_new belongs directly below freq_old
                self._link(freq_new, self._lower[freq_old], freq_old)
            self.count_to_items[freq_new].add(item)
            self.counter[item] = freq_new
        else:
            del self.counter[item]
        self._discard(item, freq_old)
        self._sync_ends()

    def max(self) -> Optional[Tuple[int, set]]:
        """Return the (max_count, item_set), or None if nothing is counted"""
        if self._max_freq is None:
            return None
        return self._max_freq, self.count_to_items[self._max_freq]

    def min(self) -> Optional[Tuple[int, set]]:
        """Return the (min_count, item_set), or None if nothing is counted"""
        if self._min_freq is None:
            return None
        return self._min_freq, self.count_to_items[self._min_freq]
def _assert_min_max_match(reference: Counter, minmax_counter: "MinMaxCounter") -> None:
    """Assert that minmax_counter reports the same min / max as a plain Counter

    Compares both the extreme counts and the full item sets at those counts.
    """
    max_entry = minmax_counter.max()
    min_entry = minmax_counter.min()
    if not reference:
        assert max_entry is None
        assert min_entry is None
        return
    top = max(reference.values())
    bottom = min(reference.values())
    assert max_entry == (top, {key for key, value in reference.items() if value == top})
    assert min_entry == (bottom, {key for key, value in reference.items() if value == bottom})


def _replay_and_check(str_list: List[str]) -> None:
    """Increment every item, then decrement every item, checking min / max after each step"""
    reference = Counter()
    minmax_counter = MinMaxCounter()
    for str_ in str_list:
        minmax_counter.increment(str_)
        reference[str_] += 1
        _assert_min_max_match(reference, minmax_counter)
    for str_ in str_list:
        minmax_counter.decrement(str_)
        reference[str_] -= 1
        if not reference[str_]:
            del reference[str_]
        _assert_min_max_match(reference, minmax_counter)


@given(st.lists(st.text()))
def test_minmaxcounter(str_list: List[str]):
    """Test MinMaxCounter: it must agree with a plain Counter after every operation"""
    _replay_and_check(str_list)


def test_minmaxcounter_regression():
    """Regression: when the min bucket empties, the next-smallest count must become the min

    After ten "c", five "b" and then a, a2, a, a2, the counts are
    {c: 10, b: 5, a: 2, a2: 2}, so the min must be 2, not 5.
    """
    _replay_and_check(["c"] * 10 + ["b"] * 5 + ["a", "a2", "a", "a2"])
if __name__ == '__main__':
    # Run the property-based test when this file is executed as a script
    test_minmaxcounter()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment