Created December 10, 2024 19:50
-
-
Save deckar01/7a8bbda3554d5e7dd6b316185366a4e8 to your computer and use it in GitHub Desktop.
NitroDiffusion + One Step Refiner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from diffusers import LCMScheduler | |
from diffusers import DiffusionPipeline, UNet2DConditionModel | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
class TimestepShiftLCMScheduler(LCMScheduler):
    """LCM scheduler that exposes a proportionally shifted timestep schedule.

    After ``set_timesteps`` the public ``timesteps`` attribute holds the shifted
    schedule ``t * shifted_timestep / num_train_timesteps`` (truncated to an
    integer by ``.long()``). The pipeline presumably iterates ``self.timesteps``,
    so the UNet and ``step`` receive shifted values. Inside ``step`` the
    original, unshifted schedule is swapped back in for the duration of the
    parent ``LCMScheduler.step`` call, then the shifted schedule is restored.

    NOTE(review): ``__init__`` accepts ``*args, **kwargs`` rather than the
    explicit LCMScheduler parameters. diffusers' ``from_config`` matches config
    keys against the ``__init__`` signature, so values loaded via
    ``from_pretrained`` may not reach ``LCMScheduler.__init__``. Confirm that the
    resulting betas / ``num_train_timesteps`` are the intended ones.
    """

    def __init__(self, *args, shifted_timestep=250, **kwargs):
        """Create the scheduler.

        Args:
            *args, **kwargs: forwarded unchanged to ``LCMScheduler.__init__``.
            shifted_timestep: the value the largest training timestep
                (``num_train_timesteps``) is scaled down to.
        """
        super().__init__(*args, **kwargs)
        # Register the extra argument explicitly so that it is readable as
        # ``self.config.shifted_timestep`` in ``set_timesteps``.
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        """Build the LCM schedule, then replace ``self.timesteps`` with a shifted copy.

        Both schedules are kept:
        - ``origin_timesteps`` is the schedule produced by the parent.
        - ``shifted_timesteps`` is that schedule scaled by
          ``shifted_timestep / num_train_timesteps`` and truncated by ``.long()``.
        """
        super().set_timesteps(*args, **kwargs)
        self.origin_timesteps = self.timesteps.clone()
        self.shifted_timesteps = (
            self.timesteps * self.config.shifted_timestep
            / self.config.num_train_timesteps
        ).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        """Run one LCM step for ``timestep``, a value from the shifted schedule.

        Returns whatever ``LCMScheduler.step`` returns. ``self.timesteps`` is
        guaranteed to hold the shifted schedule again when this method exits,
        including when the parent step raises.
        """
        if self.step_index is None:
            # Resolve the step index while ``self.timesteps`` is still the
            # shifted schedule that ``timestep`` was taken from. This assumes
            # ``_init_step_index`` locates ``timestep`` by value in
            # ``self.timesteps``; confirm against the diffusers version in use.
            self._init_step_index(timestep)
        self.timesteps = self.origin_timesteps
        try:
            return super().step(model_output, timestep, sample, generator, return_dict)
        finally:
            # Always swap the shifted schedule back in. Otherwise an exception
            # in the parent step leaves the scheduler holding the unshifted one.
            self.timesteps = self.shifted_timesteps
# Load model.
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
repo = "ChenDY/NitroFusion"
# NitroSD-Realism
ckpt = "nitrosd-realism_unet.safetensors"

# Instantiate the UNet architecture from the SDXL base config only, then fill it
# with the NitroSD-Realism weights downloaded from the NitroFusion repo.
# load_config + from_config replaces the deprecated pattern of passing a repo id
# directly to from_config.
unet_config = UNet2DConditionModel.load_config(base_model_id, subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config).to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

scheduler = TimestepShiftLCMScheduler.from_pretrained(
    base_model_id, subfolder="scheduler",
    shifted_timestep=250,
)
# Update the config through the public ConfigMixin API. Assigning an attribute
# on the frozen config object would leave config["original_inference_steps"]
# stale.
scheduler.register_to_config(original_inference_steps=4)

# Base SDXL pipeline with the NitroSD UNet and the shifted-timestep scheduler
# swapped in.
base = DiffusionPipeline.from_pretrained(
    base_model_id,
    unet=unet,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# The refiner reuses the base pipeline's second text encoder and VAE instances
# instead of loading its own copies.
refiner = DiffusionPipeline.from_pretrained(
    refiner_model_id,
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")
# Fraction of the denoising schedule handled by the base model. The refiner
# resumes from the same point (denoising_start) on the base model's latents.
high_noise_frac = 0.8
while True:
    try:
        prompt = input('NDXL> ')
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt ends the session cleanly instead of
        # crashing with a traceback. The newline keeps the shell prompt tidy.
        print()
        break
    # One base step, stopped at high_noise_frac, returning latents, not pixels.
    latents = base(
        prompt=prompt,
        num_inference_steps=1,
        guidance_scale=0,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images
    # One refiner step that finishes denoising the base latents.
    image = refiner(
        prompt=prompt,
        num_inference_steps=1,
        denoising_start=high_noise_frac,
        image=latents,
    ).images[0]
    # Every generation overwrites the same output file.
    image.save('temp.png', 'PNG')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.