from typing import Tuple, Dict, Optional

import os
import time
from concurrent.futures import ThreadPoolExecutor

import GPUtil
import mlflow
import torch
from fastapi import FastAPI, HTTPException
from filelock import FileLock, Timeout
from pydantic import BaseModel

# pip install gputil torch filelock fastapi uvicorn mlflow

# Constants
LOCK_DIR = os.path.expanduser("~/.gpu_locks")
LOCK_EXTENSION = ".lock"
WAIT_TIME = 10  # Seconds to wait between attempts to find a free GPU

app = FastAPI()
executor = ThreadPoolExecutor(max_workers=4)  # Limit the number of concurrent training tasks
tasks_status: Dict[str, str] = {}  # Maps run_id -> task status


def create_lock_dir() -> None:
    os.makedirs(LOCK_DIR, exist_ok=True)


def get_lock_file_path(gpu_id: int) -> str:
    return os.path.join(LOCK_DIR, f"gpu_{gpu_id}{LOCK_EXTENSION}")


def get_available_gpu() -> Tuple[int, FileLock]:
    """Block until a GPU is both idle and unlocked, then return its id and held lock."""
    create_lock_dir()
    while True:
        available_gpus = GPUtil.getAvailable(
            order="first", limit=1, maxLoad=0.05, maxMemory=0.05, includeNan=False
        )
        for gpu_id in available_gpus:
            lock_file = get_lock_file_path(gpu_id)
            lock = FileLock(lock_file)
            try:
                lock.acquire(timeout=0)  # Try to acquire the lock without waiting
                return gpu_id, lock
            except Timeout:
                continue
        print("No available GPUs. Waiting...")
        time.sleep(WAIT_TIME)


class TrainTask(BaseModel):
    learning_rate: float = 0.01  # Learning rate for the optimizer
    epochs: int = 10  # Number of epochs to train
    gpu_id: int = -1  # GPU ID to use; -1 means automatic allocation
    run_name: Optional[str] = None  # Optional run name for MLflow


class DummyContextManager:
    """Placeholder used instead of a FileLock when training on the CPU."""

    def __enter__(self):
        pass  # No setup needed

    def __exit__(self, exc_type, exc_value, traceback):
        pass  # No cleanup needed

    def __bool__(self) -> bool:
        return False  # Falsy, meaning "no lock is held"


def train_model(run_id: str, task: TrainTask):
    tasks_status[run_id] = "running"
    try:
        args = task
        if torch.cuda.is_available():
            if args.gpu_id == -1:
                # Automatic allocation: returns a GPU id with its lock already acquired
                gpu_id, lock = get_available_gpu()
            else:
                gpu_id = args.gpu_id
                lock_file = get_lock_file_path(gpu_id)
                lock = FileLock(lock_file)
                try:
                    lock.acquire(timeout=0)
                except Timeout:
                    print(f"GPU {gpu_id} is currently occupied. Waiting...")
                    while True:
                        try:
                            lock.acquire(timeout=0)
                            break
                        except Timeout:
                            time.sleep(WAIT_TIME)
            device = torch.device(f"cuda:{gpu_id}")
        else:
            device = torch.device("cpu")
            lock = DummyContextManager()

        print(f"Using device {device}")

        with lock:
            try:
                with mlflow.start_run(
                    run_id=run_id,
                    tags={"Device": str(device)},
                ):
                    # Example model and training loop
                    model = torch.nn.Linear(10, 1).to(device)
                    optimizer = torch.optim.SGD(
                        model.parameters(), lr=args.learning_rate
                    )
                    criterion = torch.nn.MSELoss()

                    # Log parameters
                    mlflow.log_param("learning_rate", args.learning_rate)
                    mlflow.log_param("epochs", args.epochs)

                    # Dummy data
                    data = torch.randn(100, 10).to(device)
                    target = torch.randn(100, 1).to(device)

                    # Training loop
                    for epoch in range(args.epochs):
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer.step()
                        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

                        # Log metrics
                        mlflow.log_metric("loss", loss.item(), step=epoch)
            except Exception as e:
                print(f"An error occurred: {e}")
                # The run context has already exited at this point, so attach the
                # error to the run explicitly by id instead of to an implicit new run.
                mlflow.MlflowClient().log_param(run_id, "error", str(e))
            finally:
                if lock:
                    # The FileLock was acquired twice (explicitly above and again by
                    # the reentrant "with lock:" block), so drop the extra hold here;
                    # the "with" exit then releases the last one and frees the GPU.
                    lock.release()
                    print(f"Released lock for GPU {gpu_id}")
                else:
                    print(
                        "No lock to release; no lock is created when training on the CPU."
                    )
    except Exception as e:
        print(f"An error occurred: {e}")
    finally:
        tasks_status[run_id] = "completed"


@app.post("/train")
def submit_training(task: TrainTask):
    # Alternatives:
    # client = mlflow.MlflowClient()
    # run = client.create_run(
    #     experiment_id=mlflow.tracking.fluent._get_experiment_id(),
    #     run_name=task.run_name,
    # )
    # Create the run up front so the client gets a run_id immediately;
    # train_model resumes the same run by id in a worker thread.
    with mlflow.start_run(run_name=task.run_name) as run:
        run_id = run.info.run_id
    tasks_status[run_id] = "pending"
    executor.submit(train_model, run_id, task)
    return {"message": "Training task has been submitted", "run_id": run_id}


@app.get("/status/{run_id}")
def get_task_status(run_id: str):
    if run_id not in tasks_status:
        raise HTTPException(status_code=404, detail="Run ID not found")
    try:
        run = mlflow.get_run(run_id)
        return {
            "run_id": run_id,
            "status": run.info.status,
            "start_time": run.info.start_time,
            "end_time": run.info.end_time,
            "metrics": run.data.metrics,
            "params": run.data.params,
            "tags": run.data.tags,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Could not retrieve run status: {e}"
        )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
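

# Example usage (a sketch, assuming the server above is running locally on
# port 8000 and MLflow is pointed at its default local tracking store):
#
#   # Submit a training task; gpu_id=-1 lets the server pick a free GPU
#   curl -X POST http://localhost:8000/train \
#        -H "Content-Type: application/json" \
#        -d '{"learning_rate": 0.01, "epochs": 5, "run_name": "demo"}'
#
#   # Check progress and logged metrics with the returned run_id
#   curl http://localhost:8000/status/<run_id>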