from __future__ import annotations
from typing import Any, Optional
import torch
from torch import nn
import torchaudio
import yaml
from huggingface_hub import hf_hub_download
class MelSpectrogramFeatures(nn.Module):
    def __init__(
        self,
        sample_rate=24000,
        n_fft=1024,
        hop_length=256,
        n_mels=100,
        padding="center",
    ):
        super().__init__()
        if padding != "center":
            raise ValueError("Padding must be 'center'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=True,
            power=1,
        )

    def forward(self, audio, **kwargs):
        mel = self.mel_spec(audio)
        # Log-compress the magnitude mel spectrogram; clip to avoid log(0).
        features = torch.log(torch.clip(mel, min=1e-5))
        return features
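# Shape note: with center padding, a waveform of shape (B, T) maps to log-mel
# features of shape (B, n_mels, T // hop_length + 1). A hypothetical sanity
# check using the defaults above:
#
#   feats = MelSpectrogramFeatures()(torch.randn(1, 24000))  # -> (1, 100, 94)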
class ISTFT(nn.Module):
    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "center"
    ):
        super().__init__()
        if padding != "center":
            raise ValueError("Padding must be 'center'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Use the registered window buffer so it follows the module's
        # device/dtype instead of re-creating a CPU window on every call.
        return torch.istft(
            spec,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
        )
class ISTFTHead(nn.Module):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Only "center" is supported by the simplified
            ISTFT above. Defaults to "center".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "center"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        # Safeguard against excessively large magnitudes.
        mag = torch.clip(mag, max=1e2)
        # Recover the complex spectrogram: cos/sin of the predicted phase give
        # the real and imaginary parts, scaled by the predicted magnitude.
        x = torch.cos(p)
        y = torch.sin(p)
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio
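# Shape note: for input x of shape (B, L, dim), self.out yields (B, L, n_fft + 2),
# which chunks into magnitude and phase tensors of shape (B, n_fft // 2 + 1, L)
# after the transpose. With center=True, torch.istft returns (L - 1) * hop_length
# samples. A hypothetical check (the dim value here is illustrative):
#
#   head = ISTFTHead(dim=512, n_fft=1024, hop_length=256)
#   audio = head(torch.randn(1, 94, 512))  # -> (1, 23808), i.e. (94 - 1) * 256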
class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float): Initial value for the layer scale; values <= 0 disable scaling.
    """

    def __init__(self, dim: int, intermediate_dim: int, layer_scale_init_value: float):
        super().__init__()
        # depthwise conv
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        x = residual + x
        return x
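# The block is shape-preserving: (B, dim, T) in, (B, dim, T) out, thanks to the
# residual connection. A hypothetical check with illustrative sizes:
#
#   block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1 / 8)
#   y = block(torch.randn(1, 512, 200))  # -> (1, 512, 200)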
class VocosBackbone(nn.Module):
    """
    Vocos backbone module built with ConvNeXt blocks. This simplified version accepts
    conditioning kwargs for API compatibility but does not use them.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        bandwidth_id = kwargs.get("bandwidth_id", None)
        x = self.embed(x)
        x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x
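# Shape note: the backbone maps (B, input_channels, T) features to (B, T, dim)
# hidden states (the final transpose moves channels last for the LayerNorm),
# which is exactly the (B, L, H) layout ISTFTHead.forward expects. Hypothetical sizes:
#
#   backbone = VocosBackbone(input_channels=100, dim=512, intermediate_dim=1536, num_layers=8)
#   h = backbone(torch.randn(1, 100, 94))  # -> (1, 94, 512)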
class Vocos(nn.Module):
    """
    The Vocos class represents a Fourier-based neural vocoder for audio synthesis.

    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self,
        feature_extractor: MelSpectrogramFeatures,
        backbone: VocosBackbone,
        head: ISTFTHead,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_hparams(cls, config_path: str) -> Vocos:
        """
        Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = MelSpectrogramFeatures(
            **config["feature_extractor"]["init_args"]
        )
        backbone = VocosBackbone(**config["backbone"]["init_args"])
        head = ISTFTHead(**config["head"]["init_args"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model
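    # A config.yaml compatible with from_hparams might look like the sketch below.
    # The keys mirror the lookups above; the hyperparameter values are illustrative
    # (taken from this file's defaults where they exist), not from any particular
    # checkpoint:
    #
    #   feature_extractor:
    #     init_args:
    #       sample_rate: 24000
    #       n_fft: 1024
    #       hop_length: 256
    #       n_mels: 100
    #       padding: center
    #   backbone:
    #     init_args:
    #       input_channels: 100
    #       dim: 512
    #       intermediate_dim: 1536
    #       num_layers: 8
    #   head:
    #     init_args:
    #       dim: 512
    #       n_fft: 1024
    #       hop_length: 256
    #       padding: center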
    @classmethod
    def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos:
        """
        Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
        """
        config_path = hf_hub_download(
            repo_id=repo_id, filename="config.yaml", revision=revision
        )
        model_path = hf_hub_download(
            repo_id=repo_id, filename="pytorch_model.bin", revision=revision
        )
        model = cls.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model
    @torch.inference_mode()
    def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
        which is then passed through the backbone and the head to reconstruct the audio output.

        Args:
            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        features = self.feature_extractor(audio_input, **kwargs)
        audio_output = self.decode(features, **kwargs)
        return audio_output
    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output


# To use:
# vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# audio, _ = torchaudio.load("path/to/audio.wav")
# reconstructed_audio = vocos(audio)
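# A slightly fuller, runnable sketch of the same flow. The audio path is
# hypothetical, and the resampling assumes the checkpoint expects 24 kHz
# mono input (this file's default sample_rate).
if __name__ == "__main__":
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    audio, sr = torchaudio.load("path/to/audio.wav")  # hypothetical path
    audio = audio.mean(dim=0, keepdim=True)  # downmix to mono: (1, T)
    if sr != 24000:
        audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=24000)
    reconstructed_audio = vocos(audio)  # (1, T'), T' from the ISTFT overlap-add
    torchaudio.save("reconstructed.wav", reconstructed_audio, 24000)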