Skip to content

Instantly share code, notes, and snippets.

@crypdick
Created July 23, 2025 00:43
Show Gist options
  • Save crypdick/f91b906d2c66dee4ff0a7a0d6128de8d to your computer and use it in GitHub Desktop.
Save crypdick/f91b906d2c66dee4ff0a7a0d6128de8d to your computer and use it in GitHub Desktop.
Actor that integrates GPU capacity over time across a Ray cluster.
import threading
import time
import ray
@ray.remote(num_cpus=0)
class GPUHoursTracker:
    """Actor that integrates GPU capacity over time across a Ray cluster.

    A background thread periodically queries ``ray.cluster_resources()``. At
    every tick, the GPU count observed at that moment is multiplied by the time
    elapsed since the previous sample and added to an accumulator; i.e. the
    count is treated as constant over the interval that just ended.

    A tracker is single-use: call ``start()`` once and ``stop()`` once.
    Repeated calls to either are harmless and return the original timestamps.
    """

    def __init__(self, polling_interval_s: float = 5.0):
        """Create the tracker without starting to poll.

        Note: don't poll too often, since it is expensive to call
        ``ray.cluster_resources()``.

        Args:
            polling_interval_s: Seconds to wait between two samples.
        """
        self._interval = polling_interval_s
        self._total_gpu_seconds: float = 0.0
        # Guards `_total_gpu_seconds`, which is written by the polling thread
        # and read by actor method calls.
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._start_t = None  # timestamp returned by the first start()
        self._stop_t = None  # timestamp returned by the first effective stop()
        self._last_sample_t = -1  # will be set by start()
        self._thread = threading.Thread(target=self._poll_loop, daemon=True)

    def start(self) -> float:
        """Start the background polling thread.

        This separate method is needed because Ray returns the actor handle
        before __init__ finishes execution. This method guarantees that the
        actor finishes initialization and starts polling before the caller
        resumes execution.

        Returns:
            The timestamp (``time.time()``) at which integration begins. Use
            it to mark the start of the job. Later calls, including calls made
            after ``stop()``, return this same timestamp and do nothing else.
        """
        if self._start_t is not None:
            # Idempotent: `_last_sample_t` moves on every poll, so the start
            # time is kept in its own attribute. A thread can only be started
            # once, so we must not call `Thread.start()` again either.
            return self._start_t
        begin = float(time.time())
        self._start_t = begin
        # The first sampled interval is measured from the returned timestamp.
        self._last_sample_t = begin
        self._thread.start()
        return begin

    def get_total_gpu_sec(self) -> float:
        """Return the GPU-seconds accumulated so far."""
        # Block read if the accumulator is being updated.
        with self._lock:
            return self._total_gpu_seconds

    def stop(self) -> float:
        """Stop the background thread and account for the final partial interval.

        Returns:
            The timestamp of the final sample, i.e. the end of the integrated
            window. Later calls return the same timestamp without sampling
            again. If the tracker was never started, nothing is accumulated
            and the current time is returned.
        """
        if self._stop_t is not None:
            # Already stopped (e.g. explicit stop() followed by __del__):
            # sampling again would add time after the job has ended.
            return self._stop_t
        if self._start_t is None:
            # Never started: there is no thread to join and no valid
            # `_last_sample_t` to integrate from.
            return float(time.time())
        # Break the polling loop.
        self._stop_event.set()
        # Wait for the polling thread to exit *before* the final sample so the
        # two threads never run `_poll_once` concurrently.
        self._thread.join()
        # Account for the final partial interval.
        self._stop_t = self._poll_once()
        return self._stop_t

    def __del__(self):
        """Best-effort cleanup: stop polling when the actor object is destroyed."""
        try:
            self.stop()
        except Exception:
            # Deliberately swallowed: raising from __del__ is pointless, and
            # the instance may be partially constructed or Ray may be gone.
            pass

    def _poll_loop(self):
        """Background thread that periodically samples cluster GPU capacity."""
        # Sleep for interval, break on stop signal.
        # Final interval is accounted for by stop().
        while not self._stop_event.wait(self._interval):
            self._poll_once()

    def _poll_once(self) -> float:
        """Sample cluster GPU capacity and update the accumulator.

        Returns:
            The timestamp at which this sample was taken.
        """
        now = float(time.time())
        elapsed = now - self._last_sample_t
        self._last_sample_t = now
        gpu_total = ray.cluster_resources().get("GPU", 0.0)
        # Block write if the accumulator is being read.
        with self._lock:
            self._total_gpu_seconds += gpu_total * elapsed
        return now
if __name__ == "__main__":
    # Smoke test: run the tracker for a few seconds and compare the measured
    # GPU-seconds with the value expected from the cluster's GPU count.
    # Setup what we can before we start timers.
    duration = 3
    ray.init()
    gpu_total = ray.cluster_resources().get("GPU", 0.0)
    theoretical_gpu_sec = gpu_total * duration
    global_start = time.time()
    tracker = GPUHoursTracker.remote()
    # start() blocks until the actor is initialized and polling has begun.
    actor_start_ts: float = ray.get(tracker.start.remote())
    print(f"Testing GPUHoursTracker for {duration} seconds")
    time.sleep(duration)
    stop_time = ray.get(tracker.stop.remote())  # type: ignore[attr-defined]
    actor_start_delay = actor_start_ts - global_start
    print(f"Actor start delay: {actor_start_delay:.2f}s")
    measured_gpu_sec: float = ray.get(tracker.get_total_gpu_sec.remote())
    print(f"Measured GPU-sec: {measured_gpu_sec}")
    print(f"Theoretical expected: {theoretical_gpu_sec}")
    elapsed_since_actor_start = stop_time - actor_start_ts
    # GPU-seconds = elapsed seconds * number of GPUs (not * duration).
    print(
        f"Elapsed seconds since polling started: {elapsed_since_actor_start:.2f}s ({elapsed_since_actor_start * gpu_total} GPU-sec)"
    )
    print(f"Elapsed seconds total: {stop_time - global_start:.2f}s")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment