Skip to content

Instantly share code, notes, and snippets.

@blepping
Last active May 28, 2025 00:26
Show Gist options
  • Save blepping/b372ef6c5412080af136aad942d9d76c to your computer and use it in GitHub Desktop.
Save blepping/b372ef6c5412080af136aad942d9d76c to your computer and use it in GitHub Desktop.
PingPong sampler for ComfyUI
# By https://github.com/blepping
# LICENSE: Apache2
# Usage: Place this file in the custom_nodes directory and restart ComfyUI+refresh browser.
# It will add a PingPongSampler node that can be used with SamplerCustom, etc.
import random
import torch
from tqdm.auto import trange
from comfy import model_sampling
from comfy.samplers import KSAMPLER
import nodes
BLEND_MODES = None
def _ensure_blend_modes():
global BLEND_MODES
if BLEND_MODES is not None:
return
bleh = getattr(nodes, "_blepping_integrations", {}).get("bleh")
if bleh is not None:
BLEND_MODES = bleh.py.latent_utils.BLENDING_MODES
else:
BLEND_MODES = {"lerp": torch.lerp, "a_only": lambda a, _b, _t: a, "b_only": lambda _a, b, _t: b}
class ModelProxy:
def __init__(self, model, last_x, last_sigma, last_denoised):
self.__model = model
self.__last_x = last_x
self.__last_sigma = last_sigma
self.__last_denoised = last_denoised
def __call__(self, x, sigma, *args, **kwargs):
if torch.allclose(sigma.to(self.__last_sigma), self.__last_sigma) and torch.allclose(x.to(self.__last_x), self.__last_x):
return self.__last_denoised.to(x, copy=True)
return self.__model(x, sigma, *args, **kwargs)
def __getattr__(self, k):
return getattr(self.__model, k)
class PingPongSampler:
def __init__(self, model, x, sigmas, *args, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, pingpong_options=None, **kwargs):
self.args = args
self.kwargs = kwargs
self.model_ = model
self.sigmas = sigmas
self.x = x
self.s_in = x.new_ones((x.shape[0],))
self.extra_args = extra_args.copy() if extra_args is not None else {}
self.seed = self.extra_args.pop("seed", 42)
self.disable = disable
self.callback_ = callback
if pingpong_options is None:
pingpong_options= {}
self.first_ancestral_step = pingpong_options.get("first_ancestral_step", 0)
self.last_ancestral_step = pingpong_options.get("last_ancestral_step", 0)
self.pingpong_blend = pingpong_options.get("pingpong_blend")
sampler_opt = pingpong_options.get("external_sampler")
if self.pingpong_blend != 1.0 and sampler_opt is None:
raise ValueError("Sampler input must be connect when pingpong_blend isn't 1.0")
self.external_sampler = sampler_opt
self.step_blend_function = pingpong_options.get("step_blend_function", torch.lerp)
self.blend_function = pingpong_options.get("blend_function", torch.lerp)
self.s_noise = s_noise
self.is_rf = isinstance(model.inner_model.inner_model.model_sampling, model_sampling.CONST)
if noise_sampler is None:
def noise_sampler(*_unused):
return torch.randn_like(x)
self.noise_sampler = noise_sampler
@classmethod
def go(cls, model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, pingpong_options=None, **kwargs):
return cls(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, noise_sampler=noise_sampler, s_noise=s_noise, pingpong_options=pingpong_options, **kwargs)()
def model(self, x, sigma, **kwargs):
return self.model_(x, sigma * self.s_in, **self.extra_args, **kwargs)
def callback(self, idx, x, sigma, denoised):
if self.callback_ is None:
return
self.callback_({
"i": idx,
"x": x,
"sigma": sigma,
"sigma_hat": sigma,
"denoised": denoised,
})
def __call__(self):
x = self.x
noise_sampler = self.noise_sampler
astart_step = self.first_ancestral_step
aend_step = self.last_ancestral_step
last_step_idx = len(self.sigmas) - 2
step_count = len(self.sigmas) - 1
if astart_step < 0:
astart_step = step_count + astart_step
if aend_step < 0:
aend_step = step_count + aend_step
astart_step = min(last_step_idx, max(0, astart_step))
aend_step = min(last_step_idx, max(0, aend_step))
s_noise = self.s_noise
seed_offset = 10
for idx in trange(step_count, disable=self.disable):
sigma, sigma_next = self.sigmas[idx:idx + 2]
orig_x = x
denoised = self.model(orig_x, sigma)
self.callback(idx, x, sigma, denoised)
use_ancestral = astart_step <= idx <= aend_step
if sigma_next <= 1e-06:
return denoised
if not use_ancestral:
x = self.step_blend_function(denoised, x, sigma_next / sigma)
continue
if self.pingpong_blend != 1.0:
alt_x = self.external_sampler.sampler_function(
ModelProxy(self.model_, x, sigma, denoised),
orig_x.clone(),
self.sigmas[idx:idx + 2].clone(),
*self.args,
disable=True,
callback=None,
extra_args=self.extra_args | {"seed": self.seed + seed_offset},
**self.external_sampler.extra_options,
**self.kwargs,
)
seed_offset += 10
if self.pingpong_blend <= 0:
x = alt_x
continue
noise = noise_sampler(sigma, sigma_next).mul_(self.s_noise)
if self.is_rf:
x = self.step_blend_function(denoised, noise, sigma_next)
else:
x = denoised + noise * sigma_next
if self.pingpong_blend != 1.0:
x = self.blend_function(alt_x, x, self.pingpong_blend)
del alt_x
return x
class PingPongSamplerNode:
CATEGORY = "sampling/custom_sampling/samplers"
RETURN_TYPES = ("SAMPLER",)
FUNCTION = "go"
@classmethod
def INPUT_TYPES(cls):
_ensure_blend_modes()
return {
"required": {
"s_noise": ("FLOAT", {"default": 1.0, "min": -1000.0, "max": 1000.0}),
"first_ancestral_step": ("INT", {"default": 0, "min": -10000, "max": 10000}),
"last_ancestral_step": ("INT", {"default": -1, "min": -10000, "max": 10000}),
"pingpong_blend": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Allows blending pingpong sampling with a different sampler. Only has an effect during the ancestral_step range. If set to a value below 1.0 (100% pingpong) then sampler_opt must be attached."}),
"blend_mode": (tuple(BLEND_MODES), {"default": "lerp", "tooltip": "Blend mode to use when blending pingpong sampling with the external sampler. See tooltip for pingpong_blend. Can integrate with ComfyUI-bleh to add more blend modes."}),
"step_blend_mode": (tuple(BLEND_MODES), {"default": "lerp", "tooltip": "Blend mode to use for pingpong steps. Changing this is likely a bad idea. Does not apply for ancestral steps on non-flow models. Can integrate with ComfyUI-bleh to add more blend modes."}),
},
"optional": {
"sampler_opt": ("SAMPLER", {"tooltip": "Optional when pingpong_blend is 1.0. Result of a pingpong step will be blended with output from this sampler with the configured ratio. Calls the sampler on a single step so will not work well with samplers that care about state (I.E. history samplers such as deis, res_multistep, etc)."}),
},
}
@classmethod
def go(cls, *, s_noise: float, first_ancestral_step: int, last_ancestral_step: int, pingpong_blend: float, blend_mode: str, step_blend_mode: str, sampler_opt = None):
options = {
"s_noise": s_noise,
"pingpong_options": {
"first_ancestral_step": first_ancestral_step,
"last_ancestral_step": last_ancestral_step,
"pingpong_blend": pingpong_blend,
"blend_function": BLEND_MODES[blend_mode],
"step_blend_function": BLEND_MODES[step_blend_mode],
"external_sampler": sampler_opt,
},
}
return (KSAMPLER(PingPongSampler.go, extra_options=options),)
class RestlessSchedulerNode:
DESCRIPTION = "HACK: A weird scheduler that will randomly jump around a list of sigmas you input. Not recommended. Breaks most multi-step and history samplers. Works okay-ish with Pingpong."
CATEGORY = "sampling/custom_sampling/schedulers"
RETURN_TYPES = ("SIGMAS",)
FUNCTION = "go"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"sigmas": ("SIGMAS",),
"seed": (
"INT",
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "Seed to use for generating schedule.",
},
),
"shrink_factor": ("FLOAT", {
"default": 0.3,
"tooltip": "Amount the window for restless scheduling shrinks by per iteration.",
}),
"first_restless_step": ("INT", {
"default": 3, "min": 1,
"tooltip": "First step (0-based) to include for restless scheduling. Must be greater than 1 and less than last_restless_step.",
}),
"last_restless_step": ("INT", {
"default": -4, "min": -10000, "max": 10000,
"tooltip": "Last step (0-based) to include for restless scheduling. Can be negative to count from the end, but you cannot target the last sigma in the list.",
}),
},
}
@classmethod
def go(cls, *, sigmas: torch.Tensor, seed: int, shrink_factor: float, first_restless_step: int, last_restless_step: int) -> tuple:
n_sigmas = len(sigmas)
if n_sigmas < 3:
return (sigmas,)
if last_restless_step < 0:
last_restless_step = n_sigmas + last_restless_step
if last_restless_step <= first_restless_step:
raise ValueError("Last restless step <= first restless step!")
if last_restless_step >= n_sigmas - 1:
raise ValueError("Last restless step cannot include the final sigma")
orig_sigmas = sigmas
random.seed(seed)
result = sigmas[:first_restless_step].tolist()
end_chunk = sigmas[last_restless_step + 1:].tolist()
sigmas = sigmas[first_restless_step:last_restless_step + 1].tolist()
n_sigmas = len(sigmas)
shrinkage = 0.0
curr_idx = None
while (window_size := int((n_sigmas - 1) - shrinkage)) > 0:
next_idx = random.randint(0, window_size)
if next_idx == curr_idx:
next_idx += 1
result.append(sigmas[int(shrinkage) + next_idx])
curr_idx = next_idx
shrinkage += shrink_factor
result += end_chunk
return (torch.tensor(result, dtype=torch.float32, device="cpu").to(orig_sigmas),)
NODE_CLASS_MAPPINGS = {
"PingPongSampler": PingPongSamplerNode,
"RestlessScheduler": RestlessSchedulerNode,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment