import torch
from diffusers import DiffusionPipeline, LCMScheduler, UNet2DConditionModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


class TimestepShiftLCMScheduler(LCMScheduler):
    """LCM scheduler whose sampling timesteps are rescaled into [0, shifted_timestep]
    for the UNet forward pass, while the original timesteps are restored for the
    scheduler's internal step computation."""

    def __init__(self, *args, shifted_timestep=250, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        super().set_timesteps(*args, **kwargs)
        self.origin_timesteps = self.timesteps.clone()
        self.shifted_timesteps = (
            self.timesteps * self.config.shifted_timestep / self.config.num_train_timesteps
        ).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        if self.step_index is None:
            self._init_step_index(timestep)
        # Swap the original timesteps back in so the denoising math uses the
        # unshifted schedule, then restore the shifted ones for the next UNet call.
        self.timesteps = self.origin_timesteps
        output = super().step(model_output, timestep, sample, generator, return_dict)
        self.timesteps = self.shifted_timesteps
        return output


# Load models.
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
repo = "ChenDY/NitroFusion"

# NitroSD-Realism: build the SDXL UNet from config and load the distilled weights.
ckpt = "nitrosd-realism_unet.safetensors"
unet = UNet2DConditionModel.from_config(
    base_model_id, subfolder="unet"
).to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

scheduler = TimestepShiftLCMScheduler.from_pretrained(
    base_model_id,
    subfolder="scheduler",
    shifted_timestep=250,
)
scheduler.config.original_inference_steps = 4

base = DiffusionPipeline.from_pretrained(
    base_model_id,
    unet=unet,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# Refiner shares the base pipeline's second text encoder and VAE to save memory.
refiner = DiffusionPipeline.from_pretrained(
    refiner_model_id,
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")

# Fraction of the denoising handled by the base model before handing off to the refiner.
high_noise_frac = 0.8

# Simple interactive prompt loop: base produces latents, refiner finishes the image.
while True:
    prompt = input('NDXL> ')
    image = base(
        prompt=prompt,
        num_inference_steps=1,
        guidance_scale=0,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images
    image = refiner(
        prompt=prompt,
        num_inference_steps=1,
        denoising_start=high_noise_frac,
        image=image,
    ).images[0]
    image.save('temp.png', 'PNG')