Some ComfyUI nodes for ACE-Step
# By https://github.com/blepping
# License: Apache2
#
# Place this file in your custom_nodes directory and it should load automatically.
import math

import torch
SILENCE = torch.tensor((
    (-0.6462, -1.2132, -1.3026, -1.2432, -1.2455, -1.2162, -1.2184, -1.2114, -1.2153, -1.2144, -1.2130, -1.2115, -1.2063, -1.1918, -1.1154, -0.7924),
    ( 0.0473, -0.3690, -0.6507, -0.5677, -0.6139, -0.5863, -0.5783, -0.5746, -0.5748, -0.5763, -0.5774, -0.5760, -0.5714, -0.5560, -0.5393, -0.3263),
    (-1.3019, -1.9225, -2.0812, -2.1188, -2.1298, -2.1227, -2.1080, -2.1133, -2.1096, -2.1077, -2.1118, -2.1141, -2.1168, -2.1134, -2.0720, -1.7442),
    (-4.4184, -5.5253, -5.7387, -5.7961, -5.7819, -5.7850, -5.7980, -5.8083, -5.8197, -5.8202, -5.8231, -5.8305, -5.8313, -5.8153, -5.6875, -4.7317),
    ( 1.5986, 2.0669, 2.0660, 2.0476, 2.0330, 2.0271, 2.0252, 2.0268, 2.0289, 2.0260, 2.0261, 2.0252, 2.0240, 2.0220, 1.9828, 1.6429),
    (-0.4177, -0.9632, -1.0095, -1.0597, -1.0462, -1.0640, -1.0607, -1.0604, -1.0641, -1.0636, -1.0631, -1.0594, -1.0555, -1.0466, -1.0139, -0.8284),
    (-0.7686, -1.0507, -1.3932, -1.4880, -1.5199, -1.5377, -1.5333, -1.5320, -1.5307, -1.5319, -1.5360, -1.5383, -1.5398, -1.5381, -1.4961, -1.1732),
    ( 0.0199, -0.0880, -0.4010, -0.3936, -0.4219, -0.4026, -0.3907, -0.3940, -0.3961, -0.3947, -0.3941, -0.3929, -0.3889, -0.3741, -0.3432, -0.1690),
), dtype=torch.float32, device="cpu")[None, ..., None]
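# Shape (1, 8, 16, 1): one value per (latent channel, frequency band), presumably
# captured from an encoded silent ACE-Step latent; it broadcasts over the batch
# and temporal dimensions.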


def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)):
    min_val, max_val = (
        latent.amin(dim=dim, keepdim=True),
        latent.amax(dim=dim, keepdim=True),
    )
    normalized = (latent - min_val).div_(max_val - min_val)
    return (
        normalized.mul_(target_max - target_min)
        .add_(target_min)
        .clamp_(target_min, target_max)
    )
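# Usage sketch: map a latent into [0, 1] for display, normalizing each channel and
# frequency band independently along the temporal dimension:
#   normalized = normalize_to_scale(latent, 0.0, 1.0, dim=(-1,))
# Caveat: if a reduced slice is constant, max_val == min_val and the in-place
# division yields inf/NaN; inputs are assumed to have some variation.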

TEMPORAL_SCALE_FACTOR = 44100 / 512 / 8
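# Roughly 10.77 latent frames per second of audio: 44100 (sample rate) / 512 / 8,
# where the divisors are presumably the spectrogram hop length and the autoencoder's
# temporal compression factor. This is the "roughly 11 pixels per second" mentioned
# in the tooltip below.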


class SilentLatentNode:
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("LATENT",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1, "tooltip": "Number of seconds to generate. Ignored if the optional latent input is connected."}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "Batch size to generate. Ignored if the optional latent input is connected."}),
            },
            "optional": {
                "ref_latent_opt": ("LATENT", {"tooltip": "When connected, the other parameters are ignored and the latent output will match the length/batch size of the reference."}),
            },
        }

    @classmethod
    def go(cls, *, seconds: float, batch_size: int, ref_latent_opt=None) -> tuple:
        if ref_latent_opt is not None:
            latent = torch.zeros(ref_latent_opt["samples"].shape, device="cpu", dtype=torch.float32)
        else:
            length = int(seconds * TEMPORAL_SCALE_FACTOR)
            latent = torch.zeros(batch_size, 8, 16, length, device="cpu", dtype=torch.float32)
        latent += SILENCE
        return ({"samples": latent, "type": "audio"},)
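# Note: adding SILENCE to the zero tensor broadcasts the (1, 8, 16, 1) profile over
# batch and time, so the result should decode to near-silence rather than the noise
# a plain all-zeros latent tends to produce.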


class VisualizeLatentNode:
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("IMAGE",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "latent": ("LATENT",),
                "scale_secs": (
                    "INT",
                    {
                        "default": 0, "min": 0, "max": 1000,
                        "tooltip": "Horizontal scale: the number of pixels that corresponds to one second of audio. Use 0 for no scaling, which is roughly 11 pixels per second.",
                    },
                ),
                "scale_vertical": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 1024,
                        "tooltip": "Pixel expansion factor for channels (or frequency bands if swap_channels_freqs is enabled).",
                    },
                ),
                "swap_channels_freqs": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Swaps the order of channels and frequency bands in the vertical dimension. When enabled, scale_vertical applies to frequency bands.",
                    },
                ),
                "normalize_dims": (
                    "STRING",
                    {
                        "default": "-1",
                        "tooltip": "Dimensions to normalize the latent scale over, as a comma-separated list. The default normalizes each channel and frequency band independently per batch item; try -3, -2, -1 to preserve the relative differences between them.",
                    },
                ),
                "mode": (
                    ("split", "combined", "brg", "rgb", "bgr", "split_flip", "combined_flip", "brg_flip", "rgb_flip", "bgr_flip"), {
                        "default": "split",
                        "tooltip": "split shows a monochrome view of each channel/band, combined shows their average. _flip inverts the energy in the channel (i.e. white -> black). The remaining modes pack the latent channels into the RGB channels of the preview image.",
                    },
                ),
            },
        }

    @classmethod
    def go(cls, *, latent, scale_secs, scale_vertical, swap_channels_freqs, normalize_dims, mode) -> tuple:
        normalize_dims = normalize_dims.strip()
        normalize_dims = () if not normalize_dims else tuple(int(dim) for dim in normalize_dims.split(","))
        samples = latent["samples"].to(dtype=torch.float32, device="cpu")
        if samples.ndim != 4:
            raise ValueError("Expected an ACE-Step latent with 4 dimensions")
        color_mode = mode not in {"split", "combined", "split_flip", "combined_flip"}
        batch, channels, freqs, temporal = samples.shape
        samples = normalize_to_scale(samples, 0.0, 1.0, dim=normalize_dims)
        if mode.endswith("_flip"):
            samples = 1.0 - samples
        if swap_channels_freqs:
            samples = samples.movedim(2, 1)
        if mode.startswith("combined"):
            samples = samples.mean(dim=1, keepdim=True)
        if scale_vertical != 1:
            samples = samples.repeat_interleave(scale_vertical, dim=2)
        if not color_mode:
            # Stack the channels vertically: (batch, channels * freqs, temporal).
            samples = samples.reshape(batch, -1, temporal)
        if scale_secs > 0:
            new_temporal = round((temporal / TEMPORAL_SCALE_FACTOR) * scale_secs)
            samples = torch.nn.functional.interpolate(
                samples.unsqueeze(1) if not color_mode else samples,
                size=(samples.shape[-2], new_temporal),
                mode="nearest-exact",
            )
            if not color_mode:
                samples = samples.squeeze(1)
        if not color_mode:
            # Expand the monochrome view to three identical RGB channels.
            return (samples[..., None].expand(*samples.shape, 3),)
        # Color modes: zero-pad the channel count to a multiple of three, stack each
        # group of three channels vertically, then move channels last for RGB output.
        rgb_count = math.ceil(samples.shape[1] / 3)
        channels_pad = rgb_count * 3 - samples.shape[1]
        samples = torch.cat((samples, samples.new_zeros(samples.shape[0], channels_pad, *samples.shape[-2:])), dim=1)
        samples = torch.cat(samples.chunk(rgb_count, dim=1), dim=2).movedim(1, -1)
        if mode.startswith("bgr"):
            samples = samples.flip(-1)
        elif mode.startswith("brg"):
            samples = samples.roll(-1, -1)
        return (samples,)
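# Example wiring (sketch): route a sampled ACE-Step LATENT into this node and its
# IMAGE output into ComfyUI's PreviewImage node. With the defaults you get one
# monochrome band per latent channel (16 frequency rows each), stacked vertically.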


class SplitOutLyricsNode:
    DESCRIPTION = "Allows splitting out lyrics and lyrics strength from ACE-Step CONDITIONING objects. Note that you will only be able to join them back again if the conditioning has the same shape."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING_ACE_LYRICS")

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "conditioning": ("CONDITIONING",),
                "add_fake_pooled": ("BOOLEAN", {"default": True}),
            },
        }

    @classmethod
    def go(cls, *, conditioning, add_fake_pooled) -> tuple:
        tags_result, lyrics_result = [], []
        for cond_t, cond_d in conditioning:
            cond_d = cond_d.copy()
            cond_lyr = cond_d.pop("conditioning_lyrics", None)
            cond_lyrstr = cond_d.pop("lyrics_strength", None)
            if cond_lyr is None:
                raise ValueError("Input conditioning does not contain ACE lyrics")
            if add_fake_pooled:
                cond_d["pooled_output"] = cond_t.new_zeros(1, 1)
            tags_result.append([cond_t.clone(), cond_d])
            lyrics_result.append({"conditioning_lyrics": cond_lyr.clone(), "lyrics_strength": cond_lyrstr})
        return (tags_result, lyrics_result)
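# ComfyUI CONDITIONING is a list of [tensor, dict] pairs; ACE-Step keeps the encoded
# lyrics tensor and lyrics_strength inside that dict, which is what this node pops
# out into the separate CONDITIONING_ACE_LYRICS value.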


class JoinLyricsNode:
    DESCRIPTION = "Allows joining CONDITIONING_ACE_LYRICS back into CONDITIONING. Will overwrite any lyrics that already exist. The lyrics must have the same shape as the conditioning they were split from."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("CONDITIONING",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "conditioning_tags": ("CONDITIONING",),
                "conditioning_lyrics": ("CONDITIONING_ACE_LYRICS",),
            },
        }

    @classmethod
    def go(cls, *, conditioning_tags, conditioning_lyrics) -> tuple:
        ct_len, cl_len = len(conditioning_tags), len(conditioning_lyrics)
        if ct_len != cl_len:
            raise ValueError(f"Different lengths for conditioning tags ({ct_len}) vs conditioning lyrics ({cl_len})")
        if ct_len > 0 and conditioning_lyrics[0].get("conditioning_lyrics") is None:
            raise ValueError("conditioning_lyrics is missing items, cannot combine with it.")
        # Note: the dict union operator (|) requires Python 3.9 or later.
        result = [
            [
                cond_t.clone(),
                cond_d.copy() | {
                    "conditioning_lyrics": cond_l["conditioning_lyrics"].clone(),
                    "lyrics_strength": cond_l["lyrics_strength"],
                    "pooled_output": None,
                },
            ]
            for (cond_t, cond_d), cond_l in zip(conditioning_tags, conditioning_lyrics)
        ]
        return (result,)
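# Typical flow (sketch, encode node name assumed): encode tags and lyrics with
# ComfyUI's TextEncodeAceStepAudio, split with CondSplitOutLyrics, apply tag-only
# conditioning tricks to the CONDITIONING output, then rejoin the lyrics with
# CondJoinLyrics before sampling.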


NODE_CLASS_MAPPINGS = {
    "ACETricks SilentLatent": SilentLatentNode,
    "ACETricks VisualizeLatent": VisualizeLatentNode,
    "ACETricks CondSplitOutLyrics": SplitOutLyricsNode,
    "ACETricks CondJoinLyrics": JoinLyricsNode,
}