Last active July 15, 2023 17:59
-
-
Save steinelu/caba60bfb22b9675590134137896c382 to your computer and use it in GitHub Desktop.
Segment Tree Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
class SegmentTree:
    """Segment tree over a fixed-length sequence with point updates and range queries.

    Positions are 1-based: ``query(l, r)`` folds elements ``l..r`` inclusive and
    ``set(index, value)`` replaces element ``index``. ``function`` must be an
    associative binary operation and ``neutral_element`` its identity, because
    padding leaves and sub-ranges outside a query are folded in as
    ``neutral_element``.

    Storage is a perfect binary tree in a flat list ``self.tree`` of length
    ``self.n``: node ``p`` has children ``2p`` and ``2p + 1``, slot 0 is unused,
    and the leaves occupy indices ``self.n // 2 .. self.n - 1``. When
    ``len(arr)`` is not a power of two, the surplus leaves hold
    ``neutral_element``.
    """

    def __init__(self, arr, function, neutral_element):
        """Build the tree from ``arr``.

        Args:
            arr: Initial elements. Stored as ``self.arr``; it is only read here
                and is NOT updated by :meth:`set`.
            function: Associative combiner ``(left, right) -> combined``.
            neutral_element: Identity of ``function``; fills padding leaves and
                is returned for sub-ranges disjoint from a query.
        """
        self.m = len(arr)
        # Leaf count = smallest power of two >= m (at least 1). Integer
        # bit_length is exact (no float rounding from log2) and, unlike the
        # log-based formula, does not fail for m == 0.
        leaves = 1 if self.m <= 1 else 1 << (self.m - 1).bit_length()
        self.n = 2 * leaves
        self.arr = arr
        self.tree = [neutral_element] * self.n
        self.foo = function
        self.neutral_element = neutral_element
        self._initialize(1)

    def _initialize(self, pos):
        """Recursively fill the subtree rooted at ``pos`` from ``self.arr``."""
        half = self.n // 2
        if pos >= half:  # leaf: slot ``pos`` corresponds to arr[pos - half]
            i = pos - half
            # Leaves past the end of arr are padding and hold the identity.
            self.tree[pos] = self.arr[i] if i < self.m else self.neutral_element
            return
        self._initialize(2 * pos)
        self._initialize(2 * pos + 1)
        self.tree[pos] = self.foo(self.tree[2 * pos], self.tree[2 * pos + 1])

    def set(self, index, value):
        """Replace the element at 1-based ``index`` with ``value``.

        Only the leaf and its ancestors are recomputed. ``self.arr`` is not
        modified.

        Raises:
            IndexError: if ``index`` is outside ``1..len(arr)``.
        """
        if not 1 <= index <= self.m:
            raise IndexError(f"index {index} out of range 1..{self.m}")

        def _set(pos, l, r):
            if l == r:
                self.tree[pos] = value
                return
            mid = (l + r) // 2
            # Left child covers [l, mid], right child covers [mid + 1, r].
            if index <= mid:
                _set(2 * pos, l, mid)
            else:
                _set(2 * pos + 1, mid + 1, r)
            self.tree[pos] = self.foo(self.tree[2 * pos], self.tree[2 * pos + 1])

        # The root spans every leaf, including padding, so ranges match the layout.
        _set(1, 1, self.n // 2)

    def query(self, l, r):
        """Return the fold of ``function`` over elements ``l..r`` (1-based, inclusive).

        Positions outside ``1..len(arr)`` contribute ``neutral_element``; an
        empty range (``l > r``) yields ``neutral_element``.
        """
        def _query(pos, l_, r_):
            if l <= l_ and r_ <= r:  # node range lies fully inside the query
                return self.tree[pos]
            if r_ < l or r < l_:  # node range is disjoint from the query
                return self.neutral_element
            mid = (l_ + r_) // 2
            return self.foo(_query(2 * pos, l_, mid), _query(2 * pos + 1, mid + 1, r_))

        return _query(1, 1, self.n // 2)

    def graphvis(self) -> str:
        """Render the tree as a Graphviz DOT undirected graph.

        Nodes are named by their index in ``self.tree`` (1 is the root; the
        unused slot 0 is omitted) and labelled with ``str`` of the stored
        value, with double quotes escaped so the output stays valid DOT.
        Edges are listed in pre-order.
        """
        lines = ["graph{"]
        for i in range(1, self.n):
            label = str(self.tree[i]).replace('"', '\\"')
            lines.append(f'{i} [label="{label}"];')

        def add_edges(pos):
            if pos >= self.n // 2:  # leaves have no children
                return
            lines.append(f"{pos} -- {2 * pos};")
            lines.append(f"{pos} -- {2 * pos + 1};")
            add_edges(2 * pos)
            add_edges(2 * pos + 1)

        add_edges(1)
        lines.append("}")
        return "\n".join(lines)

    def __str__(self) -> str:
        """Return a pre-order dump, one node per line, prefixed by ``"- "`` per depth level."""
        rows = []

        def walk(pos, depth):
            if pos >= self.n:
                return
            rows.append(depth * "- " + str(self.tree[pos]))
            walk(2 * pos, depth + 1)
            walk(2 * pos + 1, depth + 1)

        walk(1, 0)
        return "\n".join(rows)
def argmin(x, y):
    """Combine two ``(index, value)`` pairs, keeping the one with the smaller value.

    Ties go to ``x`` (the left operand), so a segment-tree query returns the
    leftmost minimum, matching ``min()``. A pair whose index is ``None`` (such
    as ``neutral_argmin``) is treated as the identity and always yields to the
    other operand, so a real element whose value is ``math.inf`` is never
    replaced by it.
    """
    if x[0] is None:
        return y
    if y[0] is None:
        return x
    return x if x[1] <= y[1] else y


# Identity element for ``argmin``: no index and an infinitely large value.
neutral_argmin = (None, math.inf)
if __name__ == "__main__":
    # Tag every value with its 0-based position so the argmin fold can
    # report where the minimum sits, not only what it is.
    values = [4, 5, 7, 9, 4, 6, 2, 8, 9, 1, 5, 2, 7, 2, 4, 5]
    tree = SegmentTree(list(enumerate(values)), argmin, neutral_argmin)
    # Query bounds are 1-based and inclusive.
    for lo, hi in ((1, 3), (4, 16), (1, 16), (11, 16)):
        print(f"Q {lo} {hi} : {tree.query(lo, hi)}")
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment