Created
May 14, 2025 15:20
-
-
Save petered/aa2b07ac8f38e638f63fae0ae838bc1f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import struct | |
import time | |
import bisect | |
import multiprocessing as mp | |
from pathlib import Path | |
from typing import List, Optional, Iterator | |
import av | |
import numpy as np | |
from artemis.general.custom_types import BGRImageArray | |
from artemis.image_processing.decorders.idecorder import IDecorder | |
from image_annotation.file_utils import get_hash_for_file | |
# Tunables | |
_CHUNK = 8192 # bytes buffer for process writes | |
_REFRESH = 0.5 # seconds between on-disk index refreshes | |
_THRESHOLD = 100 # max frames beyond last keyframe before blocking | |
_PRINT_INTERVAL = 100 # frames between progress logs in builder | |
class AwaitingIndexingError(Exception):
    """Signal that a requested frame lies too far past the indexed range.

    Raised by the decoder while the background index builder is still running
    and has not yet come close enough to the requested frame.
    """
def _index_worker(video_path: str, out_path: str, print_interval: int = _PRINT_INTERVAL):
    """Decode every frame of a video and write its PTS to ``out_path``.

    Meant to be run as the target of a background process. The output file is
    a flat sequence of little-endian signed 64-bit integers, one per decoded
    frame, in the order the decoder yields frames. Data is flushed to disk
    whenever the staging buffer fills and at every progress-log interval, so a
    reader polling the file can pick up the index incrementally.

    Args:
        video_path: Path of the video file to index.
        out_path: Path of the index file; any existing content is discarded.
        print_interval: Log progress (and flush pending PTS values) every this
            many frames. A value <= 0 disables the periodic log/flush.
    """
    start = time.time()
    print(f"[index] building index for {os.path.basename(video_path)} -> {out_path}")
    buf = bytearray(_CHUNK)
    off = 0
    # Opening with 'wb' truncates any stale index before the video is opened.
    with open(out_path, 'wb') as fh, av.open(video_path) as ct:
        stream = ct.streams.video[0]
        # NOTE(review): thread_count = 0 presumably lets the codec pick the
        # number of threads itself -- confirm against the PyAV/FFmpeg docs.
        stream.codec_context.thread_count = 0
        stream.codec_context.thread_type = 'FRAME'
        for i, frame in enumerate(ct.decode(stream)):
            struct.pack_into('<q', buf, off, frame.pts)
            off += 8
            if off == _CHUNK:
                fh.write(buf)
                # Without an explicit flush the bytes can sit in Python's file
                # buffer, invisible to the process reading the index.
                fh.flush()
                off = 0
            if print_interval > 0 and (i + 1) % print_interval == 0:
                elapsed = time.time() - start
                print(f"[index] ...{i+1} frames in {int(elapsed)}s", flush=True)
                if off:
                    # Publish the partially filled buffer so readers see progress.
                    fh.write(buf[:off])
                    fh.flush()
                    off = 0
        if off:
            fh.write(buf[:off])
    print("[index] finished", flush=True)
class ExactIndexingDecorder(IDecorder):
    """Frame-exact decoder with on-disk PTS index and lazy keyframe scanning.

    A per-video index file (one little-endian int64 PTS per frame) is kept in a
    cache directory, named after a hash of the video file. If it does not exist
    yet, a background process builds it while this object already serves
    frames: the index is re-read periodically, and requests that are too far
    ahead of the indexed range raise AwaitingIndexingError.

    Frames are addressed by integer index through ``decoder[idx]`` and returned
    as BGR arrays. ``len(decoder)`` is the frame count reported by the
    container's video stream.
    """

    def __init__(
        self,
        path: str,
        cache_dir: Optional[str] = None,
        threshold: int = _THRESHOLD,
        print_interval: int = _PRINT_INTERVAL,
        wait_for_index: bool = False,
    ):
        """Open the video and load, or start building, its PTS index.

        Args:
            path: Video file path ('~' is expanded).
            cache_dir: Directory for index files; defaults to '~/.frame_index'.
            threshold: How many frames beyond the indexed range may be
                requested while the builder is running before
                AwaitingIndexingError is raised.
            print_interval: Progress-log interval (in frames) for the builder.
            wait_for_index: If True and the index has to be built, block until
                the builder finishes.

        Raises:
            FileNotFoundError: If the video file does not exist.
            RuntimeError: If ``wait_for_index`` is set and the builder exits
                with a non-zero exit code.
        """
        self._path = os.path.expanduser(path)
        if not os.path.exists(self._path):
            raise FileNotFoundError(path)

        # Main container used for decoding.
        self._container = av.open(self._path)
        self._stream = self._container.streams.video[0]
        self._fps = float(self._stream.guessed_rate)
        self._n_frames = self._stream.frames

        # Second container whose packets are scanned lazily for keyframes.
        self._keyframe_pts: List[int] = [0]
        self._kf_complete = False
        self._kf_container = av.open(self._path)
        self._kf_stream = self._kf_container.streams.video[0]
        self._kf_iter = self._kf_container.demux(self._kf_stream)

        # State of the sequential fast path: the live decode iterator and the
        # index of the frame it produced last.
        self._dec_iter: Optional[Iterator] = None
        self._last_idx: Optional[int] = None

        # On-disk PTS index.
        cache_root = Path(cache_dir or '~/.frame_index').expanduser()
        cache_root.mkdir(exist_ok=True, parents=True)
        vid_hash = get_hash_for_file(self._path).hex()
        self._idx_file = cache_root / f"{vid_hash}.pts"
        self._index: List[int] = []
        self._index_complete = False  # True once the final index has been read
        self._last_refresh = 0.0
        self._threshold = threshold
        self._print_interval = print_interval
        self._index_start: Optional[float] = None
        self._builder: Optional[mp.Process] = None

        if self._idx_file.exists():
            self._load_full_index()
        else:
            self._index_start = time.time()
            self._builder = mp.Process(
                target=_index_worker,
                args=(self._path, str(self._idx_file), self._print_interval),
            )
            self._builder.daemon = True
            self._builder.start()
            if wait_for_index:
                self._builder.join()
                if self._builder.exitcode != 0:
                    raise RuntimeError(f"Indexing failed for {self._path}")
                self._load_full_index()

    def _load_full_index(self):
        """Read the whole index file and mark the in-memory index as final."""
        self._index = np.fromfile(self._idx_file, dtype='<i8').tolist()
        self._index_complete = True

    def _refresh_index(self):
        """Append PTS values written to disk since the last read to self._index."""
        if not self._idx_file.exists():
            return
        start = len(self._index) * 8
        with open(self._idx_file, 'rb') as fh:
            fh.seek(start)
            data = fh.read()
        # Only consume whole 8-byte records; a trailing partial record (writer
        # caught mid-write) is picked up by the next refresh.
        usable = len(data) - len(data) % 8
        if usable:
            self._index.extend(np.frombuffer(data[:usable], dtype='<i8').tolist())
        self._last_refresh = time.time()

    def _sync_index(self):
        """Bring self._index up to date with what the builder has written."""
        if self._index_complete or self._builder is None:
            return
        if self._builder.is_alive():
            if len(self._index) < self._n_frames:
                if time.time() - (self._last_refresh or self._index_start) > _REFRESH:
                    self._refresh_index()
        else:
            # The builder has exited: read whatever it wrote after our last
            # refresh, then stop polling.
            self._refresh_index()
            self._index_complete = True

    def _ensure_keyframes(self, target_pts: int):
        """Scan packets incrementally until we have keyframes through target_pts."""
        if self._kf_complete:
            return
        for pkt in self._kf_iter:
            if pkt.pts is None:
                continue
            if pkt.is_keyframe:
                self._keyframe_pts.append(pkt.pts)
            if pkt.pts >= target_pts:
                break
        else:
            # Iterator exhausted: every packet has been inspected.
            self._kf_complete = True
        self._keyframe_pts = sorted(set(self._keyframe_pts))

    def _decode_indexed(self, idx: int):
        """Seek to the keyframe preceding indexed frame ``idx`` and decode up to it.

        Requires ``idx < len(self._index)``. On success the decode iterator is
        left positioned just after ``idx`` for the sequential fast path.
        """
        target_pts = self._index[idx]
        self._ensure_keyframes(target_pts)
        # Clamp so a PTS below every known keyframe does not wrap to the last one.
        kpos = max(bisect.bisect_right(self._keyframe_pts, target_pts) - 1, 0)
        kf_pts = self._keyframe_pts[kpos]
        key_idx = bisect.bisect_left(self._index, kf_pts)
        # The old iterator is meaningless once the container has been seeked.
        self._dec_iter = None
        self._last_idx = None
        self._container.seek(kf_pts, stream=self._stream, any_frame=False)
        dec = self._container.decode(self._stream)
        frm = next(dec)
        for _ in range(idx - key_idx):
            frm = next(dec)
        self._dec_iter = dec
        self._last_idx = idx
        return frm

    def _decode_past_index(self, idx: int, known: int):
        """Decode frame ``idx`` when it lies beyond the ``known`` indexed frames.

        Starts from the last indexed frame (or from the beginning of the stream
        if nothing is indexed yet) and decodes forward frame by frame.
        """
        if known:
            frm = self._decode_indexed(known - 1)
            remaining = idx - (known - 1)
            dec = self._dec_iter
        else:
            self._container.seek(0, stream=self._stream, any_frame=False)
            dec = self._container.decode(self._stream)
            frm = None
            remaining = idx + 1
        # Not valid for the fast path until we actually land on ``idx``.
        self._dec_iter = None
        self._last_idx = None
        try:
            for _ in range(remaining):
                frm = next(dec)
        except StopIteration:
            raise StopIteration(f"Got stopiteration while requesting frame {idx} from {self._path}.")
        self._dec_iter = dec
        self._last_idx = idx
        return frm

    def __len__(self):
        """Return the frame count reported by the container's video stream."""
        return self._n_frames

    def __getitem__(self, idx: int) -> BGRImageArray:
        """Return frame ``idx`` as a BGR array.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is out of range.
            AwaitingIndexingError: If the index is still being built and
                ``idx`` is more than ``threshold`` frames past the indexed range.
            StopIteration: If the stream ends before ``idx`` is reached.
        """
        if idx < 0:
            idx += self._n_frames
        if not (0 <= idx < self._n_frames):
            raise IndexError(idx)

        # Fast sequential path: the previous call decoded idx-1.
        if self._dec_iter is not None and self._last_idx == idx - 1:
            try:
                frm = next(self._dec_iter)
            except StopIteration:
                self._dec_iter = None
                self._last_idx = None
            else:
                self._last_idx = idx
                return frm.to_ndarray(format='bgr24')

        self._sync_index()
        known = len(self._index)
        still_building = self._builder is not None and not self._index_complete
        if still_building and idx > known + self._threshold:
            pct = min(100, 100 * known / self._n_frames)
            elapsed = max(time.time() - self._index_start, 1e-6)
            eta = int((self._n_frames - known) / (known / elapsed + 1e-6))
            raise AwaitingIndexingError(f"Waiting for indexing... {pct:.0f}% complete, ETA {eta}s")

        if idx < known:
            frm = self._decode_indexed(idx)
        else:
            frm = self._decode_past_index(idx, known)
        return frm.to_ndarray(format='bgr24')

    def get_avg_fps(self) -> float:
        """Return the stream's guessed frame rate."""
        return self._fps

    def get_frame_timestamp(self, i: int) -> float:
        """Return the nominal timestamp of frame ``i``, computed as ``i / fps``."""
        return i / self._fps

    def destroy(self):
        """Close both containers and stop the index builder if it is running.

        If the builder had to be stopped before finishing, its partial index
        file is removed so a later run does not load it as a complete index.
        """
        self._dec_iter = None
        self._last_idx = None
        try:
            self._container.close()
        finally:
            try:
                self._kf_container.close()
            except Exception:
                # Best-effort: a failure here must not prevent builder cleanup.
                pass
            builder = self._builder
            if builder is not None and builder.is_alive():
                builder.terminate()
                builder.join()
                if builder.exitcode != 0:
                    try:
                        self._idx_file.unlink()
                    except OSError:
                        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment