Skip to content

Instantly share code, notes, and snippets.

@petered
Created May 14, 2025 15:20
Show Gist options
  • Save petered/aa2b07ac8f38e638f63fae0ae838bc1f to your computer and use it in GitHub Desktop.
Save petered/aa2b07ac8f38e638f63fae0ae838bc1f to your computer and use it in GitHub Desktop.
import os
import struct
import time
import bisect
import multiprocessing as mp
from pathlib import Path
from typing import List, Optional, Iterator
import av
import numpy as np
from artemis.general.custom_types import BGRImageArray
from artemis.image_processing.decorders.idecorder import IDecorder
from image_annotation.file_utils import get_hash_for_file
# Tunables
_CHUNK = 8192 # bytes buffer for process writes
_REFRESH = 0.5 # seconds between on-disk index refreshes
_THRESHOLD = 100 # max frames beyond last keyframe before blocking
_PRINT_INTERVAL = 100 # frames between progress logs in builder
class AwaitingIndexingError(Exception):
    """Raised when a requested frame is too far beyond the frames indexed so far.

    The decoder raises this while the background index builder is still
    running and the requested frame number exceeds the number of indexed
    frames by more than the configured threshold. The message carries a
    progress percentage and an ETA; callers are expected to retry later.
    """
def _index_worker(video_path: str, out_path: str, print_interval: int = _PRINT_INTERVAL):
    """Decode every frame of ``video_path`` and write its PTS to ``out_path``.

    Meant to run in a background process. The output is a flat binary file of
    little-endian signed 64-bit integers, one per decoded frame, in the order
    the decoder yields them, so entry ``i`` is the PTS of frame ``i``.

    Every write is followed by an explicit flush so that another process
    polling the file sees new entries promptly instead of waiting for
    Python's file buffer to fill.

    Args:
        video_path: Path of the video to index.
        out_path: Path of the index file; any existing content is replaced.
        print_interval: Log progress (and push buffered entries to disk)
            every this many frames. Values <= 0 disable the periodic step.

    Raises:
        struct.error: If a decoded frame has no PTS (``frame.pts`` is None).
    """
    start = time.time()
    print(f"[index] building index for {os.path.basename(video_path)} -> {out_path}")
    buf = bytearray(_CHUNK)  # _CHUNK must be a multiple of 8 for the "buffer full" test below
    off = 0
    # Open the video first: if that fails, no empty index file is left behind
    # to be mistaken for a finished index. Mode 'wb' truncates a stale file.
    with av.open(video_path) as ct, open(out_path, 'wb') as fh:
        stream = ct.streams.video[0]
        # Frame-level multithreaded decoding; 0 presumably means "let the codec
        # pick the thread count" -- TODO confirm against the PyAV/FFmpeg docs.
        stream.codec_context.thread_count = 0
        stream.codec_context.thread_type = 'FRAME'
        for i, frame in enumerate(ct.decode(stream)):
            struct.pack_into('<q', buf, off, frame.pts)
            off += 8
            if off == _CHUNK:
                fh.write(buf)
                fh.flush()
                off = 0
            if print_interval > 0 and (i + 1) % print_interval == 0:
                elapsed = time.time() - start
                print(f"[index] ...{i+1} frames in {int(elapsed)}s", flush=True)
                # Publish the partially filled buffer so readers can progress.
                if off:
                    fh.write(buf[:off])
                    fh.flush()
                    off = 0
        # Whatever is left after the last full chunk / progress step.
        if off:
            fh.write(buf[:off])
    print("[index] finished", flush=True)
class ExactIndexingDecorder(IDecorder):
    """Frame-exact decoder with on-disk PTS index and lazy keyframe scanning.

    A per-video index file (one little-endian int64 PTS per frame) is kept in
    a cache directory, named after a hash of the video file. If it is missing,
    a background process builds it while this object is already usable:
    frames inside the indexed range are fetched by seeking to the preceding
    keyframe and decoding forward until the target PTS; frames slightly past
    the indexed range are reached by decoding forward from the last indexed
    frame; frames further away raise ``AwaitingIndexingError``.

    NOTE(review): an index file that already exists at construction time is
    trusted to be complete. A file left behind by a builder that was killed
    without ``destroy()`` being called (or one currently being written by
    another instance) would be loaded as if it were final.
    """

    def __init__(
        self,
        path: str,
        cache_dir: Optional[str] = None,
        threshold: int = _THRESHOLD,
        print_interval: int = _PRINT_INTERVAL,
        wait_for_index: bool = False,
    ):
        """Open the video and load, or start building, its PTS index.

        Args:
            path: Video file path (``~`` is expanded).
            cache_dir: Directory for index files; defaults to ``~/.frame_index``.
            threshold: How many frames past the indexed range may be requested
                while indexing is still running before
                ``AwaitingIndexingError`` is raised.
            print_interval: Progress-log interval handed to the index builder.
            wait_for_index: If True and the index has to be built, block until
                the builder process has finished.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            RuntimeError: If ``wait_for_index`` is set and the builder process
                exits with a non-zero code.
        """
        self._path = os.path.expanduser(path)
        if not os.path.exists(self._path):
            raise FileNotFoundError(path)

        # main container, used for seeking and decoding frames
        self._container = av.open(self._path)
        self._stream = self._container.streams.video[0]
        self._fps = float(self._stream.guessed_rate)
        self._n_frames = self._stream.frames

        # lazy keyframe scan over a second container, so demuxing does not
        # disturb the decode position of the main one
        self._keyframe_pts: List[int] = [0]
        self._kf_complete = False
        self._kf_scanned_pts = float('-inf')  # highest packet PTS seen by the scan
        self._kf_container = av.open(self._path)
        self._kf_stream = self._kf_container.streams.video[0]
        self._kf_iter = self._kf_container.demux(self._kf_stream)

        # on-disk PTS index
        cache_root = Path(cache_dir or '~/.frame_index').expanduser()
        cache_root.mkdir(exist_ok=True, parents=True)
        vid_hash = get_hash_for_file(self._path).hex()
        self._idx_file = cache_root / f"{vid_hash}.pts"
        self._index: List[int] = []
        self._index_final = False  # True once no further index entries are expected
        self._last_refresh = 0.0
        self._threshold = threshold
        self._print_interval = print_interval
        self._index_start = None
        self._builder: Optional[mp.Process] = None

        # state for fast sequential reads
        self._dec_iter: Optional[Iterator] = None
        self._last_idx: Optional[int] = None

        if self._idx_file.exists():
            self._load_full_index()
        else:
            self._index_start = time.time()
            self._builder = mp.Process(
                target=_index_worker,
                args=(self._path, str(self._idx_file), self._print_interval),
            )
            self._builder.daemon = True
            self._builder.start()
            if wait_for_index:
                self._builder.join()
                if self._builder.exitcode != 0:
                    # Do not leave a truncated file that a later run would
                    # mistake for a complete index.
                    self._discard_partial_index()
                    raise RuntimeError(f"Indexing failed for {self._path}")
                self._load_full_index()

    def _load_full_index(self):
        """Read the whole index file into memory and mark the index as final."""
        self._index = np.fromfile(self._idx_file, dtype='<i8').tolist()
        self._index_final = True
        # The container reported no frame count (0); fall back to the number
        # of indexed frames so that indexing by frame number works at all.
        if self._n_frames <= 0:
            self._n_frames = len(self._index)

    def _refresh_index(self):
        """Append PTS entries written to disk since the last read to ``self._index``."""
        start = len(self._index) * 8
        try:
            with open(self._idx_file, 'rb') as fh:
                fh.seek(start)
                data = fh.read()
        except FileNotFoundError:
            # The builder has not created the file yet.
            return
        # Ignore a trailing partial entry (a write caught half-way); it is
        # picked up on the next refresh once it is complete.
        usable = len(data) - len(data) % 8
        if usable:
            self._index.extend(np.frombuffer(data[:usable], dtype='<i8').tolist())
        self._last_refresh = time.time()

    def _discard_partial_index(self):
        """Best-effort removal of an index file known to be incomplete."""
        try:
            self._idx_file.unlink()
        except OSError:
            # Already gone, or still locked by another process: nothing more
            # can be done here and the failure must not mask the real error.
            pass

    def _sync_index(self, idx: int):
        """Bring the in-memory index up to date with the background builder.

        Raises:
            AwaitingIndexingError: If the builder is still running and ``idx``
                is more than ``threshold`` frames past the indexed range.
        """
        if self._builder is None or self._index_final:
            return
        if self._builder.is_alive():
            if len(self._index) < self._n_frames:
                if time.time() - (self._last_refresh or self._index_start) > _REFRESH:
                    self._refresh_index()
                last_known = len(self._index)
                if idx > last_known + self._threshold:
                    pct = min(100, 100 * last_known / self._n_frames)
                    elapsed = max(time.time() - self._index_start, 1e-6)
                    eta = int((self._n_frames - last_known) / (last_known / elapsed + 1e-6))
                    raise AwaitingIndexingError(f"Waiting for indexing... {pct:.0f}% complete, ETA {eta}s")
            return
        # The builder has exited: pick up whatever it wrote after the last
        # refresh, then stop polling.
        self._refresh_index()
        self._index_final = True
        if self._builder.exitcode != 0:
            # Keep what was read for this session, but do not leave a
            # truncated file for later runs.
            self._discard_partial_index()

    def _ensure_keyframes(self, target_pts: int):
        """Scan packets incrementally until we have keyframes through target_pts."""
        if self._kf_complete or target_pts <= self._kf_scanned_pts:
            return
        found_new = False
        for pkt in self._kf_iter:
            pts = pkt.pts
            if pts is None:
                continue
            if pts > self._kf_scanned_pts:
                self._kf_scanned_pts = pts
            if pkt.is_keyframe:
                self._keyframe_pts.append(pts)
                found_new = True
            if pts >= target_pts:
                break
        else:
            # iterator exhausted: every packet of the stream has been seen
            self._kf_complete = True
        if found_new:
            self._keyframe_pts = sorted(set(self._keyframe_pts))

    def _decode_indexed(self, idx: int):
        """Return the decoded frame ``idx``, which must lie inside the index.

        Seeks to the last known keyframe at or before the frame's PTS and
        decodes forward until a frame with that PTS (or later) appears. The
        decode iterator is kept for subsequent sequential reads.
        """
        self._dec_iter = None  # seeking invalidates any previous iterator
        target_pts = self._index[idx]
        self._ensure_keyframes(target_pts)
        # max(..., 0): never wrap around to the last keyframe when the target
        # precedes every known keyframe PTS.
        kpos = max(bisect.bisect_right(self._keyframe_pts, target_pts) - 1, 0)
        kf_pts = self._keyframe_pts[kpos]
        self._container.seek(kf_pts, stream=self._stream, any_frame=False)
        dec = self._container.decode(self._stream)
        for frm in dec:
            if frm.pts is not None and frm.pts >= target_pts:
                break
        else:
            raise StopIteration(f"Got stopiteration while requesting frame {idx} from {self._path}.")
        self._dec_iter = dec
        self._last_idx = idx
        return frm

    def _decode_beyond_index(self, idx: int):
        """Return the decoded frame ``idx`` when it is not (yet) in the index.

        Decodes forward from the last indexed frame, or from the start of the
        stream when nothing has been indexed yet.
        """
        known = len(self._index)
        if known:
            base = known - 1
            self._decode_indexed(base)  # leaves the decoder positioned just after frame `base`
            steps = idx - base  # >= 1 because idx >= known
        else:
            # assumes the stream's first frame is reachable by seeking to
            # PTS 0 -- TODO confirm for streams with a negative start time
            self._container.seek(0, stream=self._stream, any_frame=False)
            self._dec_iter = self._container.decode(self._stream)
            steps = idx + 1
        self._last_idx = None  # not trustworthy until the target is reached
        try:
            for _ in range(steps):
                frm = next(self._dec_iter)
        except StopIteration as err:
            self._dec_iter = None
            raise StopIteration(f"Got stopiteration while requesting frame {idx} from {self._path}.") from err
        self._last_idx = idx
        return frm

    def __len__(self):
        """Return the number of frames in the video."""
        return self._n_frames

    def __getitem__(self, idx: int) -> BGRImageArray:
        """Return frame ``idx`` as a BGR image array.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is out of range.
            AwaitingIndexingError: If the frame is too far past the part of
                the video indexed so far; retry later.
            StopIteration: If the decoder runs out of frames before ``idx``.
        """
        # normalize negative
        if idx < 0:
            idx += self._n_frames
        if not (0 <= idx < self._n_frames):
            raise IndexError(idx)

        # fast sequential: if we just decoded idx-1, grab next frame
        if self._dec_iter is not None and self._last_idx == idx - 1:
            try:
                frm = next(self._dec_iter)
            except StopIteration:
                self._dec_iter = None
            else:
                self._last_idx = idx
                return frm.to_ndarray(format='bgr24')

        # ensure in-memory index is up to date (may raise AwaitingIndexingError)
        self._sync_index(idx)

        if idx < len(self._index):
            frm = self._decode_indexed(idx)
        else:
            frm = self._decode_beyond_index(idx)
        return frm.to_ndarray(format='bgr24')

    def get_avg_fps(self) -> float:
        """Return the stream's guessed frame rate."""
        return self._fps

    def get_frame_timestamp(self, i: int) -> float:
        """Return the nominal time in seconds of frame ``i`` (``i / fps``)."""
        return i / self._fps

    def destroy(self):
        """Close both containers and stop the index builder if it is still running."""
        self._dec_iter = None
        self._container.close()
        try:
            self._kf_container.close()
        except Exception:
            # best effort: a failure to close the scan container must not
            # prevent the builder from being stopped below
            pass
        if self._builder and self._builder.is_alive():
            self._builder.terminate()
            self._builder.join(timeout=5)
            # The index was cut short; remove it so that the next run rebuilds
            # it instead of loading a truncated file as complete.
            self._discard_partial_index()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment