Skip to content

Instantly share code, notes, and snippets.

@steinelu
Last active July 15, 2023 17:59
Show Gist options
  • Save steinelu/caba60bfb22b9675590134137896c382 to your computer and use it in GitHub Desktop.
Save steinelu/caba60bfb22b9675590134137896c382 to your computer and use it in GitHub Desktop.
Segment Tree Python
import math
class SegmentTree:
    """Segment tree answering range queries with an associative binary function.

    Indexing conventions (kept from the original API):
      * ``query(l, r)`` takes 1-based, inclusive positions: it combines
        ``arr[l-1]`` through ``arr[r-1]``.
      * ``set(index, value)`` takes a 0-based index: it replaces ``arr[index]``.

    The tree lives in ``self.tree`` as an implicit binary heap. The root is at
    index 1, node ``pos`` has children ``2*pos`` and ``2*pos+1``, and index 0
    is unused. The leaf count is rounded up to a power of two. Leaves past the
    end of ``arr`` hold ``neutral_element``, so ``function`` must treat it as an
    identity (``function(x, neutral_element) == function(neutral_element, x) == x``).

    Args:
        arr: initial values. They are read only during construction; ``set``
            updates the tree but does not modify this sequence.
        function: associative binary function ``(left, right) -> combined``.
        neutral_element: identity element of ``function``.
    """

    def __init__(self, arr, function, neutral_element):
        self.m = len(arr)
        # Smallest power of two >= m (at least 1). Integer arithmetic avoids
        # float rounding in log2 for huge m and also handles an empty input.
        leaves = 1 << max(self.m - 1, 0).bit_length()
        self.n = 2 * leaves  # size of self.tree; leaves occupy [n // 2, n)
        self.arr = arr
        self.tree = [neutral_element] * self.n
        self.foo = function
        self.neutral_element = neutral_element
        self._initialize(1)

    def _initialize(self, pos):
        """Fill the subtree rooted at ``pos`` from ``self.arr``."""
        half = self.n // 2
        if pos >= half:  # leaf
            i = pos - half
            # Padding leaves beyond the input hold the neutral element.
            self.tree[pos] = self.arr[i] if i < self.m else self.neutral_element
            return
        self._initialize(2 * pos)
        self._initialize(2 * pos + 1)
        self.tree[pos] = self.foo(self.tree[2 * pos], self.tree[2 * pos + 1])

    def set(self, index, value):
        """Replace the element at 0-based ``index`` and recompute its ancestors.

        Raises:
            IndexError: if ``index`` is not in ``range(len(arr))``.
        """
        if not 0 <= index < self.m:
            raise IndexError(
                f"SegmentTree index {index} out of range for {self.m} elements"
            )
        pos = self.n // 2 + index  # leaf holding arr[index]
        self.tree[pos] = value
        pos //= 2
        while pos:  # walk up to (and including) the root at index 1
            self.tree[pos] = self.foo(self.tree[2 * pos], self.tree[2 * pos + 1])
            pos //= 2

    def query(self, l, r):
        """Combine the elements at 1-based positions ``l..r`` (inclusive).

        Positions outside ``1..len(arr)`` contribute ``neutral_element``.
        An empty range (``l > r``) returns ``neutral_element``.
        """
        def _query(pos, lo, hi):
            # [lo, hi] is the 1-based position range covered by node ``pos``.
            if l <= lo and hi <= r:
                return self.tree[pos]
            if hi < l or r < lo:
                return self.neutral_element
            mid = (lo + hi) // 2
            return self.foo(_query(2 * pos, lo, mid), _query(2 * pos + 1, mid + 1, hi))

        # The root covers every leaf, padding included, so the range split
        # matches the physical tree layout for any input length.
        return _query(1, 1, self.n // 2)

    def graphvis(self) -> str:
        """Return a Graphviz DOT ``graph`` of the tree; node ids are tree indices."""
        half = self.n // 2
        lines = ["graph{"]
        # Slot 0 is unused, so it is not emitted as a node.
        for i in range(1, self.n):
            # Escape backslashes and quotes so arbitrary values form valid DOT labels.
            label = str(self.tree[i]).replace("\\", "\\\\").replace('"', '\\"')
            lines.append(f'{i} [label="{label}"];')

        def edges(pos):
            # Pre-order walk emitting parent -- child edges for internal nodes.
            if pos >= half:  # leaves have no children
                return
            lines.append(f"{pos} -- {2 * pos};")
            lines.append(f"{pos} -- {2 * pos + 1};")
            edges(2 * pos)
            edges(2 * pos + 1)

        edges(1)
        lines.append("}")
        return "\n".join(lines)

    def __str__(self) -> str:
        """Pre-order dump with one node per line, indented by one "- " per depth level."""
        lines = []

        def walk(pos, depth):
            if pos >= self.n:
                return
            lines.append(depth * "- " + str(self.tree[pos]))
            walk(2 * pos, depth + 1)
            walk(2 * pos + 1, depth + 1)

        walk(1, 0)
        return "\n".join(lines)
def argmin(x, y):
    """Return whichever of two ``(index, value)`` pairs has the smaller value.

    On a tie the right-hand operand ``y`` is returned. Intended as the
    ``function`` argument of SegmentTree, with ``neutral_argmin`` as its
    neutral element.
    """
    return x if x[1] < y[1] else y


# Neutral element for argmin: its +inf value never compares smaller than a
# finite value, so combining with it returns the other operand.
neutral_argmin = (None, math.inf)
if __name__ == "__main__":
    # Demo: argmin over (index, value) pairs; query bounds are 1-based inclusive.
    raw = [4, 5, 7, 9, 4, 6, 2, 8, 9, 1, 5, 2, 7, 2, 4, 5]
    tree = SegmentTree(list(enumerate(raw)), argmin, neutral_argmin)
    for lo, hi in ((1, 3), (4, 16), (1, 16), (11, 16)):
        print(f"Q {lo} {hi} : {tree.query(lo, hi)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment