Some ComfyUI nodes for ACE
# By https://github.com/blepping
# License: Apache2
#
# Place this file in your custom_nodes directory and it should load automatically.
import math
import torch
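# Per-channel, per-frequency-band values approximating an encoded silent ACE-Steps
# audio latent (an assumption based on the name and how SilentLatentNode uses it).
# Indexed below to shape (1, 8, 16, 1) so it broadcasts over the batch and temporal
# dimensions of a (batch, 8, 16, length) latent.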
SILENCE = torch.tensor((
(-0.6462, -1.2132, -1.3026, -1.2432, -1.2455, -1.2162, -1.2184, -1.2114, -1.2153, -1.2144, -1.2130, -1.2115, -1.2063, -1.1918, -1.1154, -0.7924),
( 0.0473, -0.3690, -0.6507, -0.5677, -0.6139, -0.5863, -0.5783, -0.5746, -0.5748, -0.5763, -0.5774, -0.5760, -0.5714, -0.5560, -0.5393, -0.3263),
(-1.3019, -1.9225, -2.0812, -2.1188, -2.1298, -2.1227, -2.1080, -2.1133, -2.1096, -2.1077, -2.1118, -2.1141, -2.1168, -2.1134, -2.0720, -1.7442),
(-4.4184, -5.5253, -5.7387, -5.7961, -5.7819, -5.7850, -5.7980, -5.8083, -5.8197, -5.8202, -5.8231, -5.8305, -5.8313, -5.8153, -5.6875, -4.7317),
( 1.5986, 2.0669, 2.0660, 2.0476, 2.0330, 2.0271, 2.0252, 2.0268, 2.0289, 2.0260, 2.0261, 2.0252, 2.0240, 2.0220, 1.9828, 1.6429),
(-0.4177, -0.9632, -1.0095, -1.0597, -1.0462, -1.0640, -1.0607, -1.0604, -1.0641, -1.0636, -1.0631, -1.0594, -1.0555, -1.0466, -1.0139, -0.8284),
(-0.7686, -1.0507, -1.3932, -1.4880, -1.5199, -1.5377, -1.5333, -1.5320, -1.5307, -1.5319, -1.5360, -1.5383, -1.5398, -1.5381, -1.4961, -1.1732),
( 0.0199, -0.0880, -0.4010, -0.3936, -0.4219, -0.4026, -0.3907, -0.3940, -0.3961, -0.3947, -0.3941, -0.3929, -0.3889, -0.3741, -0.3432, -0.169),
), dtype=torch.float32, device="cpu")[None, ..., None]
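# Rescale `latent` so its values span [target_min, target_max], taking the min/max
# over the dimensions given by `dim`.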
def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)):
    min_val, max_val = (
        latent.amin(dim=dim, keepdim=True),
        latent.amax(dim=dim, keepdim=True),
    )
    normalized = (latent - min_val).div_(max_val - min_val)
    return (
        normalized.mul_(target_max - target_min)
        .add_(target_min)
        .clamp_(target_min, target_max)
    )
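# Latent frames per second of audio: roughly 10.77. The 44100 is the sample rate;
# the 512 and 8 are presumably the audio VAE's hop size and temporal downsampling
# factors (an assumption, the constant itself is what the original code uses).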
TEMPORAL_SCALE_FACTOR = 44100 / 512 / 8
class SilentLatentNode:
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("LATENT",)
    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1, "tooltip": "Number of seconds to generate. Ignored if the optional latent input is connected."}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "Batch size to generate. Ignored if the optional latent input is connected."}),
            },
            "optional": {
                "ref_latent_opt": ("LATENT", {"tooltip": "When connected, the other parameters are ignored and the latent output will match the length/batch size of the reference."}),
            },
        }
    @classmethod
    def go(cls, *, seconds: float, batch_size: int, ref_latent_opt=None) -> tuple:
        if ref_latent_opt is not None:
            latent = torch.zeros(ref_latent_opt["samples"].shape, device="cpu", dtype=torch.float32)
        else:
            length = int(seconds * TEMPORAL_SCALE_FACTOR)
            latent = torch.zeros(batch_size, 8, 16, length, device="cpu", dtype=torch.float32)
        # Shift the all-zero latent to the silent reference values.
        latent += SILENCE
        return ({"samples": latent, "type": "audio"},)
class VisualizeLatentNode:
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("IMAGE",)
    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "latent": ("LATENT",),
                "scale_secs": (
                    "INT",
                    {
                        "default": 0, "min": 0, "max": 1000,
                        "tooltip": "Horizontal scale: the number of pixels that corresponds to one second of audio. Use 0 for no scaling, which is roughly 11 pixels per second.",
                    },
                ),
                "scale_vertical": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 1024,
                        "tooltip": "Pixel expansion factor for channels (or frequency bands if swap_channels_freqs is enabled).",
                    },
                ),
                "swap_channels_freqs": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Swaps the order of channels and frequency bands in the vertical dimension. When enabled, scale_vertical applies to frequency bands.",
                    },
                ),
                "normalize_dims": (
                    "STRING",
                    {
                        "default": "-1",
                        "tooltip": "Dimensions the latent scale is normalized over. Must be a comma-separated list of integers. The default setting normalizes the channels and frequency bands independently per batch; you can try -3, -2, -1 if you want to see the relative differences better.",
                    },
                ),
                "mode": (
                    ("split", "combined", "brg", "rgb", "bgr", "split_flip", "combined_flip", "brg_flip", "rgb_flip", "bgr_flip"), {
                        "default": "split",
                        "tooltip": "Split shows a monochrome view of each channel/freq; combined shows the average. Flip means invert the energy in the channel (i.e. white -> black). The other modes put the latent channels into the RGB channels of the preview image.",
                    },
                ),
            },
        }
    @classmethod
    def go(cls, *, latent, scale_secs, scale_vertical, swap_channels_freqs, normalize_dims, mode) -> tuple:
        normalize_dims = normalize_dims.strip()
        normalize_dims = () if not normalize_dims else tuple(int(dim) for dim in normalize_dims.split(","))
        samples = latent["samples"].to(dtype=torch.float32, device="cpu")
        if samples.ndim != 4:
            raise ValueError("Expected an ACE-Steps latent with 4 dimensions")
        color_mode = mode not in {"split", "combined", "split_flip", "combined_flip"}
        batch, channels, freqs, temporal = samples.shape
        # Map the latent into the 0..1 range expected for image output.
        samples = normalize_to_scale(samples, 0.0, 1.0, dim=normalize_dims)
        if mode.endswith("_flip"):
            samples = 1.0 - samples
        if swap_channels_freqs:
            samples = samples.movedim(2, 1)
        if mode.startswith("combined"):
            samples = samples.mean(dim=1, keepdim=True)
        if scale_vertical != 1:
            samples = samples.repeat_interleave(scale_vertical, dim=2)
        if not color_mode:
            # Stack the channels vertically into one monochrome image per batch item.
            samples = samples.reshape(batch, -1, temporal)
        if scale_secs > 0:
            new_temporal = round((temporal / TEMPORAL_SCALE_FACTOR) * scale_secs)
            samples = torch.nn.functional.interpolate(
                samples.unsqueeze(1) if not color_mode else samples,
                size=(samples.shape[-2], new_temporal),
                mode="nearest-exact",
            )
            if not color_mode:
                samples = samples.squeeze(1)
        if not color_mode:
            return (samples[..., None].expand(*samples.shape, 3),)
        # Color mode: pad the channel count to a multiple of 3, stack each group of
        # three vertically, then move the groups into the image's RGB dimension.
        rgb_count = math.ceil(samples.shape[1] / 3)
        channels_pad = rgb_count * 3 - samples.shape[1]
        samples = torch.cat((samples, samples.new_zeros(samples.shape[0], channels_pad, *samples.shape[-2:])), dim=1)
        samples = torch.cat(samples.chunk(rgb_count, dim=1), dim=2).movedim(1, -1)
        if mode.startswith("bgr"):
            samples = samples.flip(-1)
        elif mode.startswith("brg"):
            samples = samples.roll(-1, -1)
        return (samples,)
class SplitOutLyricsNode:
    DESCRIPTION = "Allows splitting out lyrics and lyrics strength from ACE-Steps CONDITIONING objects. Note that you will only be able to join them back again if the conditioning is the same shape."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING_ACE_LYRICS")
    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "conditioning": ("CONDITIONING",),
                "add_fake_pooled": ("BOOLEAN", {"default": True}),
            },
        }
    @classmethod
    def go(cls, *, conditioning, add_fake_pooled) -> tuple:
        tags_result, lyrics_result = [], []
        for cond_t, cond_d in conditioning:
            cond_d = cond_d.copy()
            cond_lyr = cond_d.pop("conditioning_lyrics", None)
            cond_lyrstr = cond_d.pop("lyrics_strength", None)
            if add_fake_pooled:
                cond_d["pooled_output"] = cond_t.new_zeros(1, 1)
            tags_result.append([cond_t.clone(), cond_d])
            lyrics_result.append({
                "conditioning_lyrics": cond_lyr.clone() if cond_lyr is not None else None,
                "lyrics_strength": cond_lyrstr,
            })
        return (tags_result, lyrics_result)
class JoinLyricsNode:
    DESCRIPTION = "Allows joining CONDITIONING_ACE_LYRICS back into CONDITIONING. Will overwrite any lyrics that exist. Must be the same shape as the conditioning the lyrics were split from."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("CONDITIONING",)
    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "conditioning_tags": ("CONDITIONING",),
                "conditioning_lyrics": ("CONDITIONING_ACE_LYRICS",),
            },
        }
    @classmethod
    def go(cls, *, conditioning_tags, conditioning_lyrics) -> tuple:
        ct_len, cl_len = len(conditioning_tags), len(conditioning_lyrics)
        if ct_len != cl_len:
            raise ValueError(f"Different lengths for tags ({ct_len}) vs conditioning lyrics ({cl_len})")
        if ct_len > 0 and conditioning_lyrics[0].get("conditioning_lyrics") is None:
            raise ValueError("conditioning_lyrics is missing items, cannot combine with it.")
        result = [
            [
                cond_t.clone(),
                cond_d.copy() | {
                    "conditioning_lyrics": cond_l["conditioning_lyrics"].clone(),
                    "lyrics_strength": cond_l["lyrics_strength"],
                    "pooled_output": None,
                },
            ]
            for (cond_t, cond_d), cond_l in zip(conditioning_tags, conditioning_lyrics)
        ]
        return (result,)
NODE_CLASS_MAPPINGS = {
    "ACETricks SilentLatent": SilentLatentNode,
    "ACETricks VisualizeLatent": VisualizeLatentNode,
    "ACETricks CondSplitOutLyrics": SplitOutLyricsNode,
    "ACETricks CondJoinLyrics": JoinLyricsNode,
}
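# Example wiring (a rough sketch of intended use, not executed by this file):
#   ACETricks SilentLatent  -> a sampler's latent input, sized either by `seconds`
#                              or by matching the optional reference latent.
#   any ACE LATENT          -> ACETricks VisualizeLatent -> PreviewImage, to inspect
#                              the channels/frequency bands over time.
#   ACE CONDITIONING        -> ACETricks CondSplitOutLyrics (edit tags and lyrics
#                              separately) -> ACETricks CondJoinLyrics -> sampler.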