Experimental FBG guidance sampler node for ComfyUI
# Referenced from https://github.com/FelixKoulischer/Feedback-Guidance-of-Diffusion-Models/
# FBG paper: https://arxiv.org/abs/2506.06085
# ComfyUI implementation by https://github.com/blepping
# LICENSE: Apache2
# Usage: Place this file in the custom_nodes directory, then restart ComfyUI and refresh your browser.
# It will add an FBGSampler node that can be used with SamplerCustom, etc.
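#
# How it works (brief sketch, based on the FBG paper and the code below): the
# sampler keeps a per-sample running log-posterior estimate that is updated
# each step from how well the conditional vs. unconditional predictions
# explain the new latent (update_log_posterior), then maps that estimate to a
# guidance scale via g = exp(lp) / (exp(lp) - (1 - pi)), optionally adds a
# flat CFG offset, and clamps the result to max_guidance_scale
# (get_guidance_scale).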
import math
from enum import Enum, auto
from typing import NamedTuple

import torch
from tqdm import tqdm

from comfy import model_sampling, model_patcher
from comfy.samplers import KSAMPLER, cfg_function
from comfy.k_diffusion.sampling import get_ancestral_step

F = torch.nn.functional

class SamplerMode(Enum):
    EULER = auto()
    PINGPONG = auto()

class FBGConfig(NamedTuple):
    sampler_mode: SamplerMode = SamplerMode.EULER
    cfg_start_sigma: float = 9999.0
    cfg_end_sigma: float = 0.0
    fbg_start_sigma: float = 9999.0
    fbg_end_sigma: float = 0.0
    fbg_guidance_multiplier: float = 1.0
    ancestral_start_sigma: float = 9999.0
    ancestral_end_sigma: float = 0.0
    cfg_scale: float = 1.0
    max_guidance_scale: float = 10.0
    max_posterior_scale: float = 3.0
    initial_value: float = 0.0
    initial_guidance_scale: float = 1.0
    guidance_max_change: float = 1000.0
    temp: float = 0.0
    offset: float = 0.0
    pi: float = 0.95
    t_0: float = 0.5
    t_1: float = 0.4

def batch_mse_loss(a: torch.Tensor, b: torch.Tensor, *, start_dim=1) -> torch.Tensor:
    return ((a - b) ** 2).sum(dim=tuple(range(start_dim, a.ndim)))
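
# Note: for a typical (batch, channels, height, width) latent this reduces
# over everything except the batch dimension and returns a (batch,) tensor of
# per-sample squared-error sums. Example: if a == b + 1 with shape
# (2, 3, 4, 4), batch_mse_loss(a, b) is a shape (2,) tensor of 48.0 values.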

class FBGSampler:
    def __init__(
        self,
        model,
        x,
        sigmas,
        extra_args=None,
        callback=None,
        disable=None,
        noise_sampler=None,
        eta=1.0,
        s_noise=1.0,
        fbg_config: FBGConfig = FBGConfig(),
        **kwargs,
    ):
        self.model_ = model
        self.sigmas = sigmas
        self.x = x
        self.s_in = x.new_ones((x.shape[0],))
        self.extra_args = extra_args if extra_args is not None else {}
        self.disable = disable
        self.callback_ = callback
        self.config = fbg_config
        self.update_config()
        cfg = fbg_config
        # Lowest log-posterior value that still maps to a guidance scale of
        # max_guidance_scale, accounting for the flat CFG offset when it is
        # active.
        if cfg.cfg_scale > 1 and cfg.cfg_start_sigma > 0:
            self.minimal_log_posterior = math.log(
                (1.0 - cfg.pi)
                * (cfg.max_guidance_scale - cfg.cfg_scale + 1)
                / (cfg.max_guidance_scale - cfg.cfg_scale)
            )
        else:
            self.minimal_log_posterior = math.log(
                (1.0 - cfg.pi) * cfg.max_guidance_scale / (cfg.max_guidance_scale - 1.0)
            )
        self.eta = max(0.0, eta)
        self.s_noise = s_noise
        self.is_rf = isinstance(
            model.inner_model.inner_model.model_sampling, model_sampling.CONST
        )
        if self.eta == 0:
            self.noise_sampler = None
        else:
            if noise_sampler is None:

                def noise_sampler(*_unused):
                    return torch.randn_like(x)

            self.noise_sampler = noise_sampler

    @classmethod
    def go(
        cls,
        model,
        x,
        sigmas,
        extra_args=None,
        callback=None,
        disable=None,
        noise_sampler=None,
        eta=1.0,
        s_noise=1.0,
        fbg_config: FBGConfig = FBGConfig(),
        **kwargs,
    ):
        return cls(
            model,
            x,
            sigmas,
            extra_args=extra_args,
            callback=callback,
            disable=disable,
            noise_sampler=noise_sampler,
            eta=eta,
            s_noise=s_noise,
            fbg_config=fbg_config,
            **kwargs,
        )()

    def model(
        self, x: torch.Tensor, sigma: torch.Tensor, *, override_cfg=None, **kwargs: dict
    ) -> tuple:
        sigma = sigma * self.s_in
        cond = uncond = None

        def post_cfg_function(args):
            nonlocal cond, uncond
            cond, uncond = args["cond_denoised"], args["uncond_denoised"]
            return args["denoised"]

        extra_args = self.extra_args.copy()
        orig_model_options = extra_args.get("model_options", {})
        model_options = orig_model_options.copy()
        model_options["disable_cfg1_optimization"] = True
        extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(
            model_options, post_cfg_function
        )
        inner_model = self.model_.inner_model
        if (override_cfg is None or len(override_cfg) < 2) and hasattr(
            inner_model, "cfg"
        ):
            orig_cfg = inner_model.cfg
            try:
                if override_cfg is not None:
                    inner_model.cfg = override_cfg.detach().item()
                denoised = inner_model.predict_noise(
                    x,
                    sigma,
                    model_options=extra_args["model_options"],
                    seed=extra_args.get("seed"),
                )
            finally:
                inner_model.cfg = orig_cfg
        else:
            _ = self.model_(x, sigma, **extra_args, **kwargs)
            denoised = cfg_function(
                inner_model.inner_model,
                cond,
                uncond,
                override_cfg,
                x,
                sigma,
                model_options=orig_model_options,
            )
        return denoised, cond, uncond
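
    # The post-CFG hook above captures the conditional and unconditional
    # denoised predictions for the FBG posterior update. With a single-value
    # guidance scale and a model wrapper that exposes .cfg, the scale is
    # applied directly through predict_noise; otherwise the captured
    # predictions are recombined with cfg_function using the per-sample
    # override scale.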

    def callback(self, idx, x, sigma, denoised):
        if self.callback_ is None:
            return
        self.callback_({
            "i": idx,
            "x": x,
            "sigma": sigma,
            "sigma_hat": sigma,
            "denoised": denoised,
        })

    def update_log_posterior(
        self,
        prev: torch.Tensor,
        x_curr: torch.Tensor,
        x_next: torch.Tensor,
        t_curr: torch.Tensor,
        t_next: torch.Tensor,
        uncond: torch.Tensor,
        cond: torch.Tensor,
    ) -> torch.Tensor:
        t_csq = t_curr**2
        t_ndc = t_next**2 / t_csq
        t_cmn = t_csq - t_next**2
        sigma_square_tilde_t = t_cmn * t_ndc
        pred_base = t_ndc * x_curr
        uncond_pred_mean = pred_base + t_cmn / t_csq * uncond
        cond_pred_mean = pred_base + t_cmn / t_csq * cond
        diff = batch_mse_loss(x_next, cond_pred_mean) - batch_mse_loss(
            x_next, uncond_pred_mean
        )
        result = (
            prev
            - self.config.temp / (2 * sigma_square_tilde_t) * diff
            + self.config.offset
        )
        return result.clamp_(
            self.minimal_log_posterior, self.config.max_posterior_scale
        )
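
    # Feedback update for the running log-posterior estimate: the difference
    # in how well the conditional vs. unconditional posterior means explain
    # the new sample moves the estimate each step (plus a constant per-step
    # offset), and the result is clamped to
    # [minimal_log_posterior, max_posterior_scale] so the derived guidance
    # scale stays within max_guidance_scale. A lower log-posterior translates
    # into stronger guidance in get_guidance_scale.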

    def get_sigma_square_tilde(self, sigmas: torch.Tensor) -> torch.Tensor:
        s_sq, sn_sq = sigmas[:-1] ** 2, sigmas[1:] ** 2
        return ((s_sq - sn_sq) * sn_sq / s_sq).flip(dims=(0,))

    def get_offset(
        self,
        steps: int,
        sigma_square_tilde: torch.Tensor,
        *,
        lambda_ref=3.0,
        decimals=4,
    ):
        cfg = self.config
        result = (
            1.0
            / ((1.0 - cfg.t_0) * steps)
            * math.log((1.0 - cfg.pi) * lambda_ref / (lambda_ref - 1.0))
        )
        return round(result, decimals)

    def get_temp(
        self,
        steps: int,
        offset: float,
        sigma_square_tilde: torch.Tensor,
        *,
        alpha=10.0,
        decimals=4,
    ):
        cfg = self.config
        t1_lower = int(math.floor(cfg.t_1 * steps))
        sst_t1, sst_t1_next = (
            sigma_square_tilde[t1_lower],
            sigma_square_tilde[t1_lower + 1],
        )
        a = cfg.t_1 * steps - t1_lower
        sst = torch.lerp(sst_t1, sst_t1_next, a)
        temp = (2 * sst / alpha * offset).abs().item()
        return round(temp, decimals)

    def update_config(self):
        if self.config.t_0 == 0 and self.config.t_1 == 0:
            return
        sigmas = self.sigmas
        steps = len(sigmas) - 1
        sst = self.get_sigma_square_tilde(sigmas)
        offset = self.get_offset(steps, sst)
        temp = self.get_temp(steps, offset, sst)
        kwargs = self.config._asdict() | {"offset": offset, "temp": temp}
        self.config = self.config.__class__(**kwargs)
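
    # Unless both t_0 and t_1 are zero, temp and offset are derived from the
    # sigma schedule following the reference implementation's heuristic:
    # get_offset targets a reference guidance level (lambda_ref) after
    # roughly the (1 - t_0) fraction of the steps, and get_temp scales the
    # feedback term using sigma_square_tilde around the t_1 fraction of the
    # schedule. Only when both are zero are the user-supplied temp/offset
    # values used directly.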

    def sampler_step(self, x, denoised, sigma, sigma_next, eta):
        noise_sampler = self.noise_sampler
        config = self.config
        if noise_sampler is None or eta == 0:
            sigma_down = sigma_next
            renoise_coeff = sigma_next * 0
        elif self.is_rf:
            # Adapted from ComfyUI.
            downstep_ratio = 1 + (sigma_next / sigma - 1) * eta
            sigma_down = sigma_next * downstep_ratio
            alpha_ip1 = 1 - sigma_next
            alpha_down = 1 - sigma_down
            if torch.isclose(sigma_down, sigma_next):
                sigma_down = sigma_next
                renoise_coeff = sigma_next * 0
            else:
                renoise_coeff = (
                    sigma_next**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2
                ) ** 0.5
        else:
            sigma_down, renoise_coeff = get_ancestral_step(sigma, sigma_next, eta=eta)
            if sigma_down == sigma_next:
                renoise_coeff = sigma_next * 0
        if noise_sampler is not None and renoise_coeff.sum() != 0:
            noise = noise_sampler(sigma, sigma_next).mul_(self.s_noise)
        else:
            noise = None
        if noise is not None and config.sampler_mode == SamplerMode.PINGPONG:
            if self.is_rf:
                x = torch.lerp(denoised, noise, sigma_next)
            else:
                x = noise.mul_(sigma_next).add_(denoised)
        else:
            x = torch.lerp(denoised, x, sigma_down / sigma)
            if noise is not None:
                noise *= renoise_coeff
                if not self.is_rf:
                    x += noise
                else:
                    x = x.mul_(alpha_ip1 / alpha_down).add_(noise)
        return x
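
    # Step modes: with eta == 0 this reduces to a plain Euler step toward the
    # denoised prediction. With eta > 0 it becomes an ancestral step (using a
    # separate down-step and re-noise coefficient, with special handling for
    # rectified-flow models), and PINGPONG mode instead rebuilds x entirely
    # from the denoised prediction plus fresh noise at sigma_next.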

    def get_guidance_scale(
        self,
        log_posterior: torch.Tensor,
        guidance_scale_prev: torch.Tensor,
        sigma: float,
        sigma_next: float,
    ) -> torch.Tensor:
        config = self.config
        using_fbg = config.fbg_end_sigma <= sigma <= config.fbg_start_sigma
        using_cfg = config.cfg_scale != 1 and (
            config.cfg_end_sigma <= sigma <= config.cfg_start_sigma
        )
        if using_fbg:
            guidance_scale = log_posterior.exp()
            guidance_scale /= guidance_scale - (1.0 - config.pi)
            guidance_scale *= config.fbg_guidance_multiplier
        else:
            guidance_scale = log_posterior.new_ones(guidance_scale_prev.shape[0])
        if using_cfg:
            guidance_scale += config.cfg_scale - 1.0
        guidance_scale = guidance_scale.clamp(1.0, config.max_guidance_scale).view(
            guidance_scale_prev.shape
        )
        change_diff = guidance_scale - guidance_scale_prev
        change_pct = (change_diff / guidance_scale).clamp_(
            -config.guidance_max_change, config.guidance_max_change
        )
        guidance_scale_new = guidance_scale_prev + guidance_scale_prev * change_pct
        return guidance_scale_new.clamp_(1.0, config.max_guidance_scale)
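
    # guidance_max_change acts as a per-step rate limit: the raw target scale
    # is computed from the log-posterior (plus the flat CFG offset when
    # active), the relative change toward it is clamped to
    # +/- guidance_max_change and applied on top of the previous scale, and
    # the result is clamped again to [1, max_guidance_scale].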

    def __call__(self):
        x = self.x
        config = self.config
        orig_eta = self.eta
        batch = x.shape[0]
        log_posterior = x.new_full((batch,), config.initial_value)
        dim_1s = (1,) * (x.ndim - 1)
        sigmas = self.sigmas
        guidance_scale = x.new_full((batch, *dim_1s), config.initial_guidance_scale)
        with tqdm(
            initial=0, total=len(sigmas) - 1, disable=self.disable, leave=True
        ) as pbar:
            for idx, (sigma, sigma_next) in enumerate(zip(sigmas[:-1], sigmas[1:])):
                x_orig = x.clone()
                sigma_item, sigma_next_item = (
                    sigma.max().detach().item(),
                    sigma_next.min().detach().item(),
                )
                eta = (
                    orig_eta
                    if config.ancestral_end_sigma
                    <= sigma_item
                    <= config.ancestral_start_sigma
                    else 0.0
                )
                guidance_scale = self.get_guidance_scale(
                    log_posterior, guidance_scale, sigma_item, sigma_next_item
                )
                pretty_scales = ", ".join(
                    f"{scale:.02f}"
                    for scale in guidance_scale.flatten().detach().tolist()
                )
                pbar.set_description(
                    f"FBG {sigma_item:.03f} -> {sigma_next_item:.03f}, guidance: {pretty_scales}",
                    refresh=False,
                )
                if idx == 0:
                    pbar.refresh()
                denoised, cond, uncond = self.model(
                    x, sigma, override_cfg=guidance_scale
                )
                pbar.update(1)
                if sigma_next <= 1e-06:
                    return denoised
                self.callback(idx, x, sigma, denoised)
                x = self.sampler_step(x, denoised, sigma, sigma_next, eta)
                log_posterior = self.update_log_posterior(
                    log_posterior, x_orig, x, sigma, sigma_next, uncond, cond
                )
            pbar.refresh()
        return x

class FBGSamplerNode:
    CATEGORY = "sampling/custom_sampling/samplers"
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "go"

    @classmethod
    def INPUT_TYPES(cls):
        defaults = FBGConfig()
        return {
            "required": {
                "eta": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": -1000.0,
                        "max": 1000.0,
                        "tooltip": "Must be above zero for ancestral or pingpong sampling to activate.",
                    },
                ),
                "s_noise": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": -1000.0,
                        "max": 1000.0,
                        "tooltip": "Scale for noise added during ancestral or pingpong sampling.",
                    },
                ),
                "sampler_mode": (
                    tuple(SamplerMode.__members__),
                    {
                        "default": "EULER",
                        "tooltip": "Note: Will automatically switch to Euler when ETA is 0 or when outside the ancestral start/end sigma range. Pingpong completely replaces the noise each step, so it has a much stronger effect than normal ancestral sampling. It's all or nothing, so it can't be scaled with ETA like normal ancestral/SDE sampling.",
                    },
                ),
"cfg_start_sigma": ( | |
"FLOAT", | |
{ | |
"default": defaults.cfg_start_sigma, | |
"min": 0.0, | |
"max": 9999.0, | |
"tooltip": "First sigma when the cfg_scale parameter will become activate.", | |
}, | |
), | |
"cfg_end_sigma": ( | |
"FLOAT", | |
{ | |
"default": defaults.cfg_end_sigma, | |
"min": 0.0, | |
"max": 9999.0, | |
"tooltip": "Last sigma the cfg_scale parameter is active.", | |
}, | |
), | |
"fbg_start_sigma": ( | |
"FLOAT", | |
{ | |
"default": defaults.fbg_start_sigma, | |
"min": 0.0, | |
"max": 9999.0, | |
"tooltip": "Tooltip when the fbg_scale parameter will become activate.", | |
}, | |
), | |
"fbg_end_sigma": ( | |
"FLOAT", | |
{ | |
"default": defaults.fbg_end_sigma, | |
"min": 0.0, | |
"max": 9999.0, | |
"tooltip": "Last sigma the fbg_scale parameter is active.", | |
}, | |
), | |
"ancestral_start_sigma": ( | |
"FLOAT", | |
{ | |
"default": defaults.ancestral_start_sigma, | |
"min": 0.0, | |
"max": 9999.0, | |
"tooltip": "First sigma ancestral/pingpong sampling will be active. Note: ETA must also be non-zero.", | |
}, | |
), | |
"ancestral_end_sigma": ( | |
"FLOAT", | |
{ | |
"default": defaults.ancestral_end_sigma, | |
"min": 0.0, | |
"max": 9999.0, | |
"tooltip": "Last sigma ancestral/pingpong sampling will be active.", | |
}, | |
), | |
"cfg_scale": ( | |
"FLOAT", | |
{ | |
"default": defaults.cfg_scale, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "This is a flat addition to the guidance scale and will be active in the range of cfg_start_sigma to cfg_end_sigma.", | |
}, | |
), | |
"max_guidance_scale": ( | |
"FLOAT", | |
{ | |
"default": defaults.max_guidance_scale, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "Cap for the total guidance scale (both FBG and CFG if active).", | |
}, | |
), | |
"initial_guidance_scale": ( | |
"FLOAT", | |
{ | |
"default": defaults.initial_guidance_scale, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "Initial value for guidance scale. Mostly useful for controling where guidance_max_change will start.", | |
}, | |
), | |
"guidance_max_change": ( | |
"FLOAT", | |
{ | |
"default": defaults.guidance_max_change, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "A percentage that can be used to limit how much the guidance scale changes per step. If you want to limit it, setting something like 0.5 here is reasonable.", | |
}, | |
), | |
"pi": ( | |
"FLOAT", | |
{ | |
"default": defaults.pi, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "Parameter from the original FBG guidance. I don't know that my implementation works correctly so you need to set this fairly low to get an effect, something like 0.2.", | |
}, | |
), | |
"t_0": ( | |
"FLOAT", | |
{ | |
"default": defaults.t_0, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "If both t_0 and t_1 are 0, temp and offset values will be used.", | |
}, | |
), | |
"t_1": ( | |
"FLOAT", | |
{ | |
"default": defaults.t_1, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "If both t_0 and t_1 are 0, temp and offset values will be used.", | |
}, | |
), | |
"temp": ( | |
"FLOAT", | |
{ | |
"default": defaults.temp, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "Only applies if both t_0 and t_1 are 0, otherwise this value is calculated automatically.", | |
}, | |
), | |
"offset": ( | |
"FLOAT", | |
{ | |
"default": defaults.offset, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "Only applies if both t_0 and t_1 are 0, otherwise this value is calculated automatically.", | |
}, | |
), | |
"log_posterior_initial_value": ( | |
"FLOAT", | |
{ | |
"default": defaults.initial_value, | |
"min": -1000.0, | |
"max": 1000.0, | |
"tooltip": "You most likely don't need to change this.", | |
}, | |
), | |
"fbg_guidance_multiplier": ( | |
"FLOAT", | |
{ | |
"default": defaults.fbg_guidance_multiplier, | |
"min": 0.001, | |
"max": 1000.0, | |
"tooltip": "Simple multiplier on the FBG guidance scale specifically, calculate before it is added to the combined guidance scale (which may also include CFG).", | |
}, | |
), | |
}, | |
} | |

    @classmethod
    def go(
        cls,
        *,
        eta,
        s_noise,
        sampler_mode,
        cfg_start_sigma,
        cfg_end_sigma,
        fbg_start_sigma,
        fbg_end_sigma,
        ancestral_start_sigma,
        ancestral_end_sigma,
        cfg_scale,
        max_guidance_scale,
        log_posterior_initial_value,
        initial_guidance_scale,
        guidance_max_change,
        temp,
        offset,
        pi,
        t_0,
        t_1,
        fbg_guidance_multiplier,
    ):
        options = {
            "eta": eta,
            "s_noise": s_noise,
            "fbg_config": FBGConfig(
                sampler_mode=getattr(SamplerMode, sampler_mode.upper()),
                cfg_start_sigma=cfg_start_sigma,
                cfg_end_sigma=cfg_end_sigma,
                fbg_start_sigma=fbg_start_sigma,
                fbg_end_sigma=fbg_end_sigma,
                ancestral_start_sigma=ancestral_start_sigma,
                ancestral_end_sigma=ancestral_end_sigma,
                cfg_scale=cfg_scale,
                max_guidance_scale=max_guidance_scale,
                initial_guidance_scale=initial_guidance_scale,
                guidance_max_change=guidance_max_change,
                temp=temp,
                offset=offset,
                pi=pi,
                t_0=t_0,
                t_1=t_1,
                initial_value=log_posterior_initial_value,
                fbg_guidance_multiplier=fbg_guidance_multiplier,
            ),
        }
        return (KSAMPLER(FBGSampler.go, extra_options=options),)

NODE_CLASS_MAPPINGS = {
    "FBGSampler": FBGSamplerNode,
}