from typing import Tuple, Dict, Optional

import os
import time
from concurrent.futures import ThreadPoolExecutor

import GPUtil
import mlflow
import torch
from fastapi import FastAPI, HTTPException
from filelock import FileLock, Timeout
from pydantic import BaseModel

# pip install gputil torch filelock fastapi uvicorn mlflow

# Constants
LOCK_DIR = os.path.expanduser("~/.gpu_locks")
LOCK_EXTENSION = ".lock"
WAIT_TIME = 10  # Seconds to wait between attempts to find a free GPU

app = FastAPI()
executor = ThreadPoolExecutor(max_workers=4)  # Limit the number of concurrent training tasks
tasks_status: Dict[str, str] = {}  # Maps run_id -> task status


def create_lock_dir() -> None:
    os.makedirs(LOCK_DIR, exist_ok=True)


def get_lock_file_path(gpu_id: int) -> str:
    return os.path.join(LOCK_DIR, f"gpu_{gpu_id}{LOCK_EXTENSION}")


def get_available_gpu() -> Tuple[int, FileLock]:
    """Block until a GPU is both idle and unlocked, then return its id and held lock."""
    create_lock_dir()
    while True:
        available_gpus = GPUtil.getAvailable(
            order="first", limit=1, maxLoad=0.05, maxMemory=0.05, includeNan=False
        )
        for gpu_id in available_gpus:
            lock_file = get_lock_file_path(gpu_id)
            lock = FileLock(lock_file)
            try:
                lock.acquire(timeout=0)  # Try to acquire the lock without waiting
                return gpu_id, lock
            except Timeout:
                continue
        print("No available GPUs. Waiting...")
        time.sleep(WAIT_TIME)


class TrainTask(BaseModel):
    learning_rate: float = 0.01  # Learning rate for the optimizer
    epochs: int = 10  # Number of epochs to train
    gpu_id: int = -1  # GPU ID to use; -1 means automatic allocation
    run_name: Optional[str] = None  # Optional run name for MLflow


class DummyContextManager:
    """Placeholder used instead of a FileLock when training on the CPU."""

    def __enter__(self):
        pass  # No setup needed

    def __exit__(self, exc_type, exc_value, traceback):
        pass  # No cleanup needed

    def __bool__(self) -> bool:
        return False  # Falsy, meaning "no lock is held"


def train_model(run_id: str, task: TrainTask):
    tasks_status[run_id] = "running"
    try:
        args = task
        if torch.cuda.is_available():
            if args.gpu_id == -1:
                # Automatic allocation: returns a GPU id with its lock already acquired
                gpu_id, lock = get_available_gpu()
            else:
                gpu_id = args.gpu_id
                lock_file = get_lock_file_path(gpu_id)
                lock = FileLock(lock_file)
                try:
                    lock.acquire(timeout=0)
                except Timeout:
                    print(f"GPU {gpu_id} is currently occupied. Waiting...")
                    while True:
                        try:
                            lock.acquire(timeout=0)
                            break
                        except Timeout:
                            time.sleep(WAIT_TIME)
            device = torch.device(f"cuda:{gpu_id}")
        else:
            device = torch.device("cpu")
            lock = DummyContextManager()

        print(f"Using device {device}")

        with lock:
            try:
                with mlflow.start_run(
                    run_id=run_id,
                    tags={"Device": str(device)},
                ):
                    # Example model and training loop
                    model = torch.nn.Linear(10, 1).to(device)
                    optimizer = torch.optim.SGD(
                        model.parameters(), lr=args.learning_rate
                    )
                    criterion = torch.nn.MSELoss()

                    # Log parameters
                    mlflow.log_param("learning_rate", args.learning_rate)
                    mlflow.log_param("epochs", args.epochs)

                    # Dummy data
                    data = torch.randn(100, 10).to(device)
                    target = torch.randn(100, 1).to(device)

                    # Training loop
                    for epoch in range(args.epochs):
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer.step()
                        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

                        # Log metrics
                        mlflow.log_metric("loss", loss.item(), step=epoch)
            except Exception as e:
                print(f"An error occurred: {e}")
                # The run context has already exited at this point, so attach the
                # error to the run explicitly by id instead of to an implicit new run.
                mlflow.MlflowClient().log_param(run_id, "error", str(e))
            finally:
                if lock:
                    # The FileLock was acquired twice (explicitly above and again by
                    # the reentrant "with lock:" block), so drop the extra hold here;
                    # the "with" exit then releases the last one and frees the GPU.
                    lock.release()
                    print(f"Released lock for GPU {gpu_id}")
                else:
                    print(
                        "No lock to release; no lock is created when training on the CPU."
                    )
    except Exception as e:
        print(f"An error occurred: {e}")
    finally:
        tasks_status[run_id] = "completed"


@app.post("/train")
def submit_training(task: TrainTask):
    # Alternatives:
    # client = mlflow.MlflowClient()
    # run = client.create_run(
    #     experiment_id=mlflow.tracking.fluent._get_experiment_id(),
    #     run_name=task.run_name,
    # )
    # Create the run up front so the client gets a run_id immediately;
    # train_model resumes the same run by id in a worker thread.
    with mlflow.start_run(run_name=task.run_name) as run:
        run_id = run.info.run_id
    tasks_status[run_id] = "pending"
    executor.submit(train_model, run_id, task)
    return {"message": "Training task has been submitted", "run_id": run_id}


@app.get("/status/{run_id}")
def get_task_status(run_id: str):
    if run_id not in tasks_status:
        raise HTTPException(status_code=404, detail="Run ID not found")
    try:
        run = mlflow.get_run(run_id)
        return {
            "run_id": run_id,
            "status": run.info.status,
            "start_time": run.info.start_time,
            "end_time": run.info.end_time,
            "metrics": run.data.metrics,
            "params": run.data.params,
            "tags": run.data.tags,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Could not retrieve run status: {e}"
        )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
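

# Example usage (a sketch, assuming the server above is running locally on
# port 8000 and MLflow is pointed at its default local tracking store):
#
#   # Submit a training task; gpu_id=-1 lets the server pick a free GPU
#   curl -X POST http://localhost:8000/train \
#        -H "Content-Type: application/json" \
#        -d '{"learning_rate": 0.01, "epochs": 5, "run_name": "demo"}'
#
#   # Check progress and logged metrics with the returned run_id
#   curl http://localhost:8000/status/<run_id>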