Skip to content

Instantly share code, notes, and snippets.

@tehybel
Created March 13, 2025 11:12
Show Gist options
  • Save tehybel/656a3f8fa8df20cf8665b1793fb1745a to your computer and use it in GitHub Desktop.
Save tehybel/656a3f8fa8df20cf8665b1793fb1745a to your computer and use it in GitHub Desktop.
Enabling tiled decode for SDXL in OneTrainer
# ... some lines from this file are left out ...
def load(
    self,
    model_type: ModelType,
    model_names: ModelNames,
    weight_dtypes: ModelWeightDtypes,
) -> StableDiffusionXLModel | None:
    """Build an SDXL model, layer LoRA and embeddings on it, and enable tiled VAE decode.

    Creates a fresh ``StableDiffusionXLModel``, loads internal data keyed by
    ``model_names.lora``, attaches the default model spec, optionally loads the
    base model weights, then loads LoRA weights and embeddings on top. Finally,
    if the model has a VAE, tiled decoding is switched on.

    Args:
        model_type: Model variant used to construct the model and pick its
            default model spec.
        model_names: Names of the components to load; ``base_model`` (may be
            ``None`` to skip base weights) and ``lora`` are read here, and the
            whole object is forwarded to the LoRA and embedding loaders.
        weight_dtypes: Weight data types forwarded to the base model loader.

    Returns:
        The populated model. The annotation allows ``None``, but this body
        always returns the model instance.
    """
    base_model_loader = StableDiffusionXLModelLoader()
    lora_model_loader = StableDiffusionXLLoRALoader()
    embedding_loader = StableDiffusionXLEmbeddingLoader()

    model = StableDiffusionXLModel(model_type=model_type)
    self._load_internal_data(model, model_names.lora)
    model.model_spec = self._load_default_model_spec(model_type)

    # Base weights are optional; when skipped, components such as the VAE
    # may be left unpopulated on the model.
    if model_names.base_model is not None:
        base_model_loader.load(model, model_type, model_names, weight_dtypes)
    lora_model_loader.load(model, model_names)
    embedding_loader.load_multiple(model, model_names)

    # Enable tiled VAE decode (presumably diffusers' AutoencoderKL, which
    # splits decoding into tiles to reduce peak memory). Guard the call:
    # without a base model the VAE may be None (TODO confirm against
    # StableDiffusionXLModel.__init__), and calling enable_tiling() on it
    # would raise AttributeError.
    vae = getattr(model, "vae", None)
    if vae is not None:
        vae.enable_tiling()

    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment