Experimental FBG guidance sampler node for ComfyUI
# Referenced from https://github.com/FelixKoulischer/Feedback-Guidance-of-Diffusion-Models/
# FBG paper: https://arxiv.org/abs/2506.06085
# ComfyUI implementation by https://github.com/blepping
# LICENSE: Apache2
# Usage: Place this file in the custom_nodes directory, then restart ComfyUI and refresh your browser.
# It will add a FBGSampler node that can be used with SamplerCustom, etc.
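# How it works (rough summary of the code below): FBG keeps a running
# log posterior estimate per batch item and turns it into a guidance scale
# each step via scale = exp(log_posterior) / (exp(log_posterior) - (1 - pi)),
# clamped to [1, max_guidance_scale]. After each sampler step the log
# posterior is updated based on how close the new latent is to the posterior
# means implied by the cond vs. uncond predictions, so guidance strengthens
# or relaxes from feedback given by the model itself. See the FBG paper
# linked above for the derivation.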
import math
from enum import Enum, auto
from typing import NamedTuple
import torch
from tqdm import tqdm
from comfy import model_sampling, model_patcher
from comfy.samplers import KSAMPLER, cfg_function
from comfy.k_diffusion.sampling import get_ancestral_step
F = torch.nn.functional
class SamplerMode(Enum):
EULER = auto()
PINGPONG = auto()
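# Configuration for the sampler. Note: when t_0 and t_1 are non-zero,
# update_config() derives temp and offset from them automatically (see
# get_offset/get_temp below); the temp/offset fields here are only used
# directly when both t_0 and t_1 are 0.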
class FBGConfig(NamedTuple):
sampler_mode: SamplerMode = SamplerMode.EULER
cfg_start_sigma: float = 9999.0
cfg_end_sigma: float = 0.0
fbg_start_sigma: float = 9999.0
fbg_end_sigma: float = 0.0
fbg_guidance_multiplier: float = 1.0
ancestral_start_sigma: float = 9999.0
ancestral_end_sigma: float = 0.0
cfg_scale: float = 1.0
max_guidance_scale: float = 10.0
max_posterior_scale: float = 3.0
initial_value: float = 0.0
initial_guidance_scale: float = 1.0
guidance_max_change: float = 1000.0
temp: float = 0.0
offset: float = 0.0
pi: float = 0.95
t_0: float = 0.5
t_1: float = 0.4
def batch_mse_loss(a: torch.Tensor, b: torch.Tensor, *, start_dim=1) -> torch.Tensor:
return ((a - b) ** 2).sum(dim=tuple(range(start_dim, a.ndim)))
class FBGSampler:
def __init__(
self,
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
noise_sampler=None,
eta=1.0,
s_noise=1.0,
fbg_config: FBGConfig = FBGConfig(),
**kwargs,
):
self.model_ = model
self.sigmas = sigmas
self.x = x
self.s_in = x.new_ones((x.shape[0],))
self.extra_args = extra_args if extra_args is not None else {}
self.disable = disable
self.callback_ = callback
self.config = fbg_config
self.update_config()
cfg = fbg_config
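# Lower clamp for the log posterior: this inverts the FBG guidance formula
# (scale = e^lp / (e^lp - (1 - pi))) at the point where the total scale
# would reach max_guidance_scale, optionally accounting for the flat
# CFG addition (cfg_scale - 1) when static CFG is also enabled.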
if cfg.cfg_scale > 1 and cfg.cfg_start_sigma > 0:
self.minimal_log_posterior = math.log(
(1.0 - cfg.pi)
* (cfg.max_guidance_scale - cfg.cfg_scale + 1)
/ (cfg.max_guidance_scale - cfg.cfg_scale)
)
else:
self.minimal_log_posterior = math.log(
(1.0 - cfg.pi) * cfg.max_guidance_scale / (cfg.max_guidance_scale - 1.0)
)
self.eta = max(0.0, eta)
self.s_noise = s_noise
self.is_rf = isinstance(
model.inner_model.inner_model.model_sampling, model_sampling.CONST
)
if self.eta == 0:
self.noise_sampler = None
else:
if noise_sampler is None:
def noise_sampler(*_unused):
return torch.randn_like(x)
self.noise_sampler = noise_sampler
@classmethod
def go(
cls,
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
noise_sampler=None,
eta=1.0,
s_noise=1.0,
fbg_config: FBGConfig = FBGConfig(),
**kwargs,
):
return cls(
model,
x,
sigmas,
extra_args=extra_args,
callback=callback,
disable=disable,
noise_sampler=noise_sampler,
eta=eta,
s_noise=s_noise,
fbg_config=fbg_config,
**kwargs,
)()
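# Runs the wrapped model once and captures the cond/uncond denoised
# predictions via a post-CFG hook. The per-step guidance scale is applied
# either by temporarily overriding the guider's cfg value (when a single
# scalar is enough) or by redoing the CFG mix with cfg_function so each
# batch item can get its own scale.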
def model(
self, x: torch.Tensor, sigma: torch.Tensor, *, override_cfg=None, **kwargs: dict
) -> tuple:
sigma = sigma * self.s_in
cond = uncond = None
def post_cfg_function(args):
nonlocal cond, uncond
cond, uncond = args["cond_denoised"], args["uncond_denoised"]
return args["denoised"]
extra_args = self.extra_args.copy()
orig_model_options = extra_args.get("model_options", {})
model_options = orig_model_options.copy()
model_options["disable_cfg1_optimization"] = True
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(
model_options, post_cfg_function
)
inner_model = self.model_.inner_model
if (override_cfg is None or len(override_cfg) < 2) and hasattr(
inner_model, "cfg"
):
orig_cfg = inner_model.cfg
try:
if override_cfg is not None:
inner_model.cfg = override_cfg.detach().item()
denoised = inner_model.predict_noise(
x,
sigma,
model_options=extra_args["model_options"],
seed=extra_args.get("seed"),
)
finally:
inner_model.cfg = orig_cfg
else:
_ = self.model_(x, sigma, **extra_args, **kwargs)
denoised = cfg_function(
inner_model.inner_model,
cond,
uncond,
override_cfg,
x,
sigma,
model_options=orig_model_options,
)
return denoised, cond, uncond
def callback(self, idx, x, sigma, denoised):
if self.callback_ is None:
return
self.callback_({
"i": idx,
"x": x,
"sigma": sigma,
"sigma_hat": sigma,
"denoised": denoised,
})
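# Feedback update for the log posterior (my reading of the FBG paper's
# update rule): compare the new latent x_next against the posterior means
# implied by the cond and uncond predictions at this step,
#   mean_c = (t_next^2 / t_curr^2) * x_curr + ((t_curr^2 - t_next^2) / t_curr^2) * cond
# (and likewise for uncond), then shift the log posterior by
#   -temp / (2 * sigma_square_tilde) * (||x_next - mean_c||^2 - ||x_next - mean_u||^2) + offset,
# clamped between minimal_log_posterior and max_posterior_scale.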
def update_log_posterior(
self,
prev: torch.Tensor,
x_curr: torch.Tensor,
x_next: torch.Tensor,
t_curr: torch.Tensor,
t_next: torch.Tensor,
uncond: torch.Tensor,
cond: torch.Tensor,
) -> torch.Tensor:
t_csq = t_curr**2
t_ndc = t_next**2 / t_csq
t_cmn = t_csq - t_next**2
sigma_square_tilde_t = t_cmn * t_ndc
pred_base = t_ndc * x_curr
uncond_pred_mean = pred_base + t_cmn / t_csq * uncond
cond_pred_mean = pred_base + t_cmn / t_csq * cond
diff = batch_mse_loss(x_next, cond_pred_mean) - batch_mse_loss(
x_next, uncond_pred_mean
)
result = (
prev
- self.config.temp / (2 * sigma_square_tilde_t) * diff
+ self.config.offset
)
return result.clamp_(
self.minimal_log_posterior, self.config.max_posterior_scale
)
def get_sigma_square_tilde(self, sigmas: torch.Tensor) -> torch.Tensor:
s_sq, sn_sq = sigmas[:-1] ** 2, sigmas[1:] ** 2
return ((s_sq - sn_sq) * sn_sq / s_sq).flip(dims=(0,))
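# The two helpers below follow the paper's heuristic for picking offset and
# temp from t_0/t_1 (as I understand it): offset is the constant per-step
# drift that would move the log posterior to the level of a reference
# guidance scale (lambda_ref) after (1 - t_0) of the steps, and temp is then
# scaled from sigma_square_tilde interpolated at step fraction t_1.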
def get_offset(
self,
steps: int,
sigma_square_tilde: torch.Tensor,
*,
lambda_ref=3.0,
decimals=4,
):
cfg = self.config
result = (
1.0
/ ((1.0 - cfg.t_0) * steps)
* math.log((1.0 - cfg.pi) * lambda_ref / (lambda_ref - 1.0))
)
return round(result, decimals)
def get_temp(
self,
steps: int,
offset: float,
sigma_square_tilde: torch.Tensor,
*,
alpha=10.0,
decimals=4,
):
cfg = self.config
t1_lower = int(math.floor(cfg.t_1 * steps))
sst_t1, sst_t1_next = (
sigma_square_tilde[t1_lower],
sigma_square_tilde[t1_lower + 1],
)
a = cfg.t_1 * steps - t1_lower
sst = torch.lerp(sst_t1, sst_t1_next, a)
temp = (2 * sst / alpha * offset).abs().item()
return round(temp, decimals)
def update_config(self):
if self.config.t_0 == 0 and self.config.t_1 == 0:
return
sigmas = self.sigmas
steps = len(sigmas) - 1
sst = self.get_sigma_square_tilde(sigmas)
offset = self.get_offset(steps, sst)
temp = self.get_temp(steps, offset, sst)
kwargs = self.config._asdict() | {"offset": offset, "temp": temp}
self.config = self.config.__class__(**kwargs)
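# One sampler step: either a standard Euler ancestral step (with a separate
# down-step/renoise calculation for rectified flow models) or, in PINGPONG
# mode, a full noise replacement where x is rebuilt from the denoised
# prediction plus fresh noise at sigma_next.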
def sampler_step(self, x, denoised, sigma, sigma_next, eta):
noise_sampler = self.noise_sampler
config = self.config
if noise_sampler is None or eta == 0:
sigma_down = sigma_next
renoise_coeff = sigma_next * 0
elif self.is_rf:
# Adapted from ComfyUI
downstep_ratio = 1 + (sigma_next / sigma - 1) * eta
sigma_down = sigma_next * downstep_ratio
alpha_ip1 = 1 - sigma_next
alpha_down = 1 - sigma_down
if torch.isclose(sigma_down, sigma_next):
sigma_down = sigma_next
renoise_coeff = sigma_next * 0
else:
renoise_coeff = (
sigma_next**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2
) ** 0.5
else:
sigma_down, renoise_coeff = get_ancestral_step(sigma, sigma_next, eta=eta)
if sigma_down == sigma_next:
renoise_coeff = sigma_next * 0
if noise_sampler is not None and renoise_coeff.sum() != 0:
noise = noise_sampler(sigma, sigma_next).mul_(self.s_noise)
else:
noise = None
if noise is not None and config.sampler_mode == SamplerMode.PINGPONG:
if self.is_rf:
x = torch.lerp(denoised, noise, sigma_next)
else:
x = noise.mul_(sigma_next).add_(denoised)
else:
x = torch.lerp(denoised, x, sigma_down / sigma)
if noise is not None:
noise *= renoise_coeff
if not self.is_rf:
x += noise
else:
x = x.mul_(alpha_ip1 / alpha_down).add_(noise)
return x
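# Converts the log posterior into a guidance scale. For illustration (not
# from the code): with pi = 0.95 and log_posterior = -2.7, exp(-2.7) is about
# 0.067, so the FBG scale is roughly 0.067 / (0.067 - 0.05) ~= 3.9 before the
# optional flat CFG addition and the [1, max_guidance_scale] clamp.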
def get_guidance_scale(
self,
log_posterior: torch.Tensor,
guidance_scale_prev: torch.Tensor,
sigma: float,
sigma_next: float,
) -> torch.Tensor:
config = self.config
using_fbg = config.fbg_end_sigma <= sigma <= config.fbg_start_sigma
using_cfg = config.cfg_scale != 1 and (
config.cfg_end_sigma <= sigma <= config.cfg_start_sigma
)
if using_fbg:
guidance_scale = log_posterior.exp()
guidance_scale /= guidance_scale - (1.0 - config.pi)
guidance_scale *= config.fbg_guidance_multiplier
else:
guidance_scale = log_posterior.new_ones(guidance_scale_prev.shape[0])
if using_cfg:
guidance_scale += config.cfg_scale - 1.0
guidance_scale = guidance_scale.clamp(1.0, config.max_guidance_scale).view(
guidance_scale_prev.shape
)
change_diff = guidance_scale - guidance_scale_prev
change_pct = (change_diff / guidance_scale).clamp_(
-config.guidance_max_change, config.guidance_max_change
)
guidance_scale_new = guidance_scale_prev + guidance_scale_prev * change_pct
return guidance_scale_new.clamp_(1.0, config.max_guidance_scale)
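# Main loop: for each sigma pair, derive the per-item guidance scale from
# the current log posterior, run the model with that scale, take a sampler
# step, then update the log posterior from the result so the next step's
# guidance reacts to it.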
def __call__(self):
x = self.x
config = self.config
orig_eta = self.eta
batch = x.shape[0]
log_posterior = x.new_full((batch,), config.initial_value)
dim_1s = (1,) * (x.ndim - 1)
sigmas = self.sigmas
guidance_scale = x.new_full((batch, *dim_1s), config.initial_guidance_scale)
with tqdm(
initial=0, total=len(sigmas) - 1, disable=self.disable, leave=True
) as pbar:
for idx, (sigma, sigma_next) in enumerate(zip(sigmas[:-1], sigmas[1:])):
x_orig = x.clone()
sigma_item, sigma_next_item = (
sigma.max().detach().item(),
sigma_next.min().detach().item(),
)
eta = (
orig_eta
if config.ancestral_end_sigma
<= sigma_item
<= config.ancestral_start_sigma
else 0.0
)
guidance_scale = self.get_guidance_scale(
log_posterior, guidance_scale, sigma_item, sigma_next_item
)
pretty_scales = ", ".join(
f"{scale:.02f}"
for scale in guidance_scale.flatten().detach().tolist()
)
pbar.set_description(
f"FBG {sigma_item:.03f} -> {sigma_next_item:.03f}, guidance: {pretty_scales}",
refresh=False,
)
if idx == 0:
pbar.refresh()
denoised, cond, uncond = self.model(
x, sigma, override_cfg=guidance_scale
)
pbar.update(1)
if sigma_next <= 1e-06:
return denoised
self.callback(idx, x, sigma, denoised)
x = self.sampler_step(x, denoised, sigma, sigma_next, eta)
log_posterior = self.update_log_posterior(
log_posterior, x_orig, x, sigma, sigma_next, uncond, cond
)
pbar.refresh()
return x
class FBGSamplerNode:
CATEGORY = "sampling/custom_sampling/samplers"
RETURN_TYPES = ("SAMPLER",)
FUNCTION = "go"
@classmethod
def INPUT_TYPES(cls):
defaults = FBGConfig()
return {
"required": {
"eta": (
"FLOAT",
{
"default": 0.0,
"min": -1000.0,
"max": 1000.0,
"tooltip": "Must be above zero for ancestral or pingpong sampling to activate.",
},
),
"s_noise": (
"FLOAT",
{
"default": 1.0,
"min": -1000.0,
"max": 1000.0,
"tooltip": "Scale for noise added during ancestral or pingpong sampling.",
},
),
"sampler_mode": (
tuple(SamplerMode.__members__),
{
"default": "EULER",
"tooltip": "Note: Will automatically switch to Euler when ETA is 0 or if outside of the ancestral start/end sigma range. Pingpong completely replaces all the noise, so this will have a very strong effect compared to normal ancestral sampling. It's all or nothing so you can't scale it with ETA like normal ancestral/SDE sampling.",
},
),
"cfg_start_sigma": (
"FLOAT",
{
"default": defaults.cfg_start_sigma,
"min": 0.0,
"max": 9999.0,
"tooltip": "First sigma when the cfg_scale parameter will become activate.",
},
),
"cfg_end_sigma": (
"FLOAT",
{
"default": defaults.cfg_end_sigma,
"min": 0.0,
"max": 9999.0,
"tooltip": "Last sigma the cfg_scale parameter is active.",
},
),
"fbg_start_sigma": (
"FLOAT",
{
"default": defaults.fbg_start_sigma,
"min": 0.0,
"max": 9999.0,
"tooltip": "Tooltip when the fbg_scale parameter will become activate.",
},
),
"fbg_end_sigma": (
"FLOAT",
{
"default": defaults.fbg_end_sigma,
"min": 0.0,
"max": 9999.0,
"tooltip": "Last sigma the fbg_scale parameter is active.",
},
),
"ancestral_start_sigma": (
"FLOAT",
{
"default": defaults.ancestral_start_sigma,
"min": 0.0,
"max": 9999.0,
"tooltip": "First sigma ancestral/pingpong sampling will be active. Note: ETA must also be non-zero.",
},
),
"ancestral_end_sigma": (
"FLOAT",
{
"default": defaults.ancestral_end_sigma,
"min": 0.0,
"max": 9999.0,
"tooltip": "Last sigma ancestral/pingpong sampling will be active.",
},
),
"cfg_scale": (
"FLOAT",
{
"default": defaults.cfg_scale,
"min": -1000.0,
"max": 1000.0,
"tooltip": "This is a flat addition to the guidance scale and will be active in the range of cfg_start_sigma to cfg_end_sigma.",
},
),
"max_guidance_scale": (
"FLOAT",
{
"default": defaults.max_guidance_scale,
"min": -1000.0,
"max": 1000.0,
"tooltip": "Cap for the total guidance scale (both FBG and CFG if active).",
},
),
"initial_guidance_scale": (
"FLOAT",
{
"default": defaults.initial_guidance_scale,
"min": -1000.0,
"max": 1000.0,
"tooltip": "Initial value for guidance scale. Mostly useful for controling where guidance_max_change will start.",
},
),
"guidance_max_change": (
"FLOAT",
{
"default": defaults.guidance_max_change,
"min": -1000.0,
"max": 1000.0,
"tooltip": "A percentage that can be used to limit how much the guidance scale changes per step. If you want to limit it, setting something like 0.5 here is reasonable.",
},
),
"pi": (
"FLOAT",
{
"default": defaults.pi,
"min": -1000.0,
"max": 1000.0,
"tooltip": "Parameter from the original FBG guidance. I don't know that my implementation works correctly so you need to set this fairly low to get an effect, something like 0.2.",
},
),
"t_0": (
"FLOAT",
{
"default": defaults.t_0,
"min": -1000.0,
"max": 1000.0,
"tooltip": "If both t_0 and t_1 are 0, temp and offset values will be used.",
},
),
"t_1": (
"FLOAT",
{
"default": defaults.t_1,
"min": -1000.0,
"max": 1000.0,
"tooltip": "If both t_0 and t_1 are 0, temp and offset values will be used.",
},
),
"temp": (
"FLOAT",
{
"default": defaults.temp,
"min": -1000.0,
"max": 1000.0,
"tooltip": "Only applies if both t_0 and t_1 are 0, otherwise this value is calculated automatically.",
},
),
"offset": (
"FLOAT",
{
"default": defaults.offset,
"min": -1000.0,
"max": 1000.0,
"tooltip": "Only applies if both t_0 and t_1 are 0, otherwise this value is calculated automatically.",
},
),
"log_posterior_initial_value": (
"FLOAT",
{
"default": defaults.initial_value,
"min": -1000.0,
"max": 1000.0,
"tooltip": "You most likely don't need to change this.",
},
),
"fbg_guidance_multiplier": (
"FLOAT",
{
"default": defaults.fbg_guidance_multiplier,
"min": 0.001,
"max": 1000.0,
"tooltip": "Simple multiplier on the FBG guidance scale specifically, calculate before it is added to the combined guidance scale (which may also include CFG).",
},
),
},
}
@classmethod
def go(
cls,
*,
eta,
s_noise,
sampler_mode,
cfg_start_sigma,
cfg_end_sigma,
fbg_start_sigma,
fbg_end_sigma,
ancestral_start_sigma,
ancestral_end_sigma,
cfg_scale,
max_guidance_scale,
log_posterior_initial_value,
initial_guidance_scale,
guidance_max_change,
temp,
offset,
pi,
t_0,
t_1,
fbg_guidance_multiplier,
):
options = {
"eta": eta,
"s_noise": s_noise,
"fbg_config": FBGConfig(
sampler_mode=getattr(SamplerMode, sampler_mode.upper()),
cfg_start_sigma=cfg_start_sigma,
cfg_end_sigma=cfg_end_sigma,
fbg_start_sigma=fbg_start_sigma,
fbg_end_sigma=fbg_end_sigma,
ancestral_start_sigma=ancestral_start_sigma,
ancestral_end_sigma=ancestral_end_sigma,
cfg_scale=cfg_scale,
max_guidance_scale=max_guidance_scale,
initial_guidance_scale=initial_guidance_scale,
guidance_max_change=guidance_max_change,
temp=temp,
offset=offset,
pi=pi,
t_0=t_0,
t_1=t_1,
initial_value=log_posterior_initial_value,
fbg_guidance_multiplier=fbg_guidance_multiplier,
),
}
return (KSAMPLER(FBGSampler.go, extra_options=options),)
NODE_CLASS_MAPPINGS = {
"FBGSampler": FBGSamplerNode,
}
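# Usage sketch (not part of the node itself): after restarting ComfyUI the
# node shows up under sampling/custom_sampling/samplers as "FBGSampler".
# Connect its SAMPLER output to the sampler input of a SamplerCustom or
# SamplerCustomAdvanced node; the FBG-specific knobs (pi, t_0/t_1,
# cfg_scale, max_guidance_scale, etc.) are set on the node itself.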