import torch
from diffusers import LCMScheduler
from diffusers import DiffusionPipeline, UNet2DConditionModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


class TimestepShiftLCMScheduler(LCMScheduler):
    """LCM scheduler that exposes a rescaled ("shifted") timestep schedule.

    After ``set_timesteps`` the public ``self.timesteps`` holds the parent's
    schedule multiplied by ``shifted_timestep / num_train_timesteps`` and
    truncated to integers, while the parent's unmodified schedule is kept in
    ``self.origin_timesteps``. ``step`` temporarily restores the unmodified
    schedule while it delegates to ``LCMScheduler.step``.
    """

    def __init__(self, *args, shifted_timestep=250, **kwargs):
        """Create the scheduler and record ``shifted_timestep`` in its config.

        Args:
            *args: Forwarded unchanged to ``LCMScheduler.__init__``.
            shifted_timestep: Numerator of the scale factor applied to the
                schedule in ``set_timesteps``.
            **kwargs: Forwarded unchanged to ``LCMScheduler.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        """Build the parent schedule, then publish its shifted version.

        All arguments are forwarded to ``LCMScheduler.set_timesteps``.
        """
        super().set_timesteps(*args, **kwargs)
        # Keep an independent copy of the schedule the parent produced.
        self.origin_timesteps = self.timesteps.clone()
        # Scale by shifted_timestep / num_train_timesteps; .long() truncates.
        self.shifted_timesteps = (
            self.timesteps * self.config.shifted_timestep /
            self.config.num_train_timesteps
        ).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        """Run one ``LCMScheduler.step`` against the unshifted schedule.

        Args:
            model_output: Forwarded to ``LCMScheduler.step``.
            timestep: Forwarded to ``LCMScheduler.step``; also used to
                initialise the step index on the first call.
            sample: Forwarded to ``LCMScheduler.step``.
            generator: Forwarded to ``LCMScheduler.step``.
            return_dict: Forwarded to ``LCMScheduler.step``.

        Returns:
            Whatever ``LCMScheduler.step`` returns.
        """
        # The step index is resolved while self.timesteps still holds the
        # shifted values -- presumably because the caller passes a timestep
        # taken from the shifted schedule (TODO confirm against the pipeline).
        if self.step_index is None:
            self._init_step_index(timestep)
        self.timesteps = self.origin_timesteps
        try:
            return super().step(model_output, timestep, sample, generator, return_dict)
        finally:
            # Always re-expose the shifted schedule, even if the parent step
            # raised, so the scheduler is not left in the swapped state.
            self.timesteps = self.shifted_timesteps

# Load model.
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
repo = "ChenDY/NitroFusion"

# NitroSD-Realism
ckpt = "nitrosd-realism_unet.safetensors"
# Instantiate the UNet from the base model's config, then fill it with the
# checkpoint downloaded from `repo`. The config is loaded explicitly with
# `load_config`: passing a model id straight to `from_config` is deprecated
# in diffusers.
unet_config = UNet2DConditionModel.load_config(base_model_id, subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config).to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

# Scheduler built from the base model's scheduler config, with the extra
# `shifted_timestep` setting consumed by TimestepShiftLCMScheduler.
scheduler = TimestepShiftLCMScheduler.from_pretrained(
    base_model_id, subfolder="scheduler",
    shifted_timestep=250,
)
# Override the loaded config value after construction.
scheduler.config.original_inference_steps = 4

# Loading options shared by both pipelines: fp16 safetensors weights.
_fp16_kwargs = {
    "torch_dtype": torch.float16,
    "variant": "fp16",
    "use_safetensors": True,
}

# Base pipeline, using the UNet and scheduler prepared earlier in this script
# in place of the ones it would load itself.
base = DiffusionPipeline.from_pretrained(
    base_model_id, unet=unet, scheduler=scheduler, **_fp16_kwargs
).to("cuda")

# Refiner pipeline; it reuses the base pipeline's second text encoder and VAE
# rather than loading its own copies.
refiner = DiffusionPipeline.from_pretrained(
    refiner_model_id,
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    **_fp16_kwargs,
).to("cuda")

# Handed to the base pipeline as `denoising_end` and to the refiner as
# `denoising_start` in the generation loop.
high_noise_frac = 0.8

# Interactive loop: read a prompt, run base then refiner, save the result.
while True:
    try:
        prompt = input('NDXL> ')
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt ends the session cleanly instead of
        # dying with a traceback; finish the prompt line first.
        print()
        break
    # Base pass: returns latents rather than a decoded image so they can be
    # fed directly into the refiner.
    image = base(
        prompt=prompt,
        num_inference_steps=1,
        guidance_scale=0,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images
    # Refiner pass: starts from the base latents and yields the final image.
    image = refiner(
        prompt=prompt,
        num_inference_steps=1,
        denoising_start=high_noise_frac,
        image=image,
    ).images[0]
    # Each iteration overwrites the same output file.
    image.save('temp.png', 'PNG')