from __future__ import annotations

from typing import Any, Optional

import torch
from torch import nn
import torchaudio
import yaml
from huggingface_hub import hf_hub_download

class MelSpectrogramFeatures(nn.Module):
    def __init__(
        self,
        sample_rate=24000,
        n_fft=1024,
        hop_length=256,
        n_mels=100,
        padding="center",
    ):
        super().__init__()
        if padding != "center":
            raise ValueError("Padding must be 'center'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=True,
            power=1,
        )

    def forward(self, audio, **kwargs):
        mel = self.mel_spec(audio)
        # Log-compress the mel magnitudes; clipping avoids log(0).
        features = torch.log(torch.clip(mel, min=1e-5))
        return features
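
# Illustrative sketch (not part of the original gist): with the defaults above, a
# one-second 24 kHz mono waveform maps to a log-mel tensor of shape
# (B, n_mels, T_frames), where T_frames ≈ T // hop_length + 1 with center padding.
# feature_extractor = MelSpectrogramFeatures()
# mels = feature_extractor(torch.randn(1, 24000))  # -> roughly (1, 100, 94)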

class ISTFT(nn.Module):
    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "center"
    ):
        super().__init__()
        if padding != "center":
            raise ValueError("Padding must be 'center'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Use the registered window buffer so it follows the module's device/dtype.
        return torch.istft(
            spec,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
        )

class ISTFTHead(nn.Module):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Only "center" is supported by the ISTFT used here.
            Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        # Split the projection into log-magnitude and phase halves.
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        # Safeguard against excessively large magnitudes.
        mag = torch.clip(mag, max=1e2)
        # Phase wrapping happens implicitly here: cos/sin give the real and imaginary parts.
        x = torch.cos(p)
        y = torch.sin(p)
        # Directly produce the complex spectrogram.
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio
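
# Illustrative sketch (not part of the original gist; dim=512 is an assumed value):
# a frame sequence of shape (B, L, dim) is projected to n_fft + 2 = 1026 channels
# (513 log-magnitudes + 513 phases) and inverted to audio.
# head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="center")
# audio = head(torch.randn(1, 94, 512))  # -> roughly (1, (94 - 1) * 256) samples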

class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float): Initial value for the layer scale; a value <= 0 disables scaling.
    """

    def __init__(self, dim: int, intermediate_dim: int, layer_scale_init_value: float):
        super().__init__()
        # depthwise conv
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        x = residual + x
        return x
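
# Illustrative sketch (not part of the original gist; the sizes are assumed values):
# the block is shape-preserving, mapping (B, C, T) -> (B, C, T) around a residual
# connection.
# block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1 / 8)
# out = block(torch.randn(1, 512, 94))  # -> (1, 512, 94)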

class VocosBackbone(nn.Module):
    """
    Vocos backbone module built with ConvNeXt blocks.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # bandwidth_id is accepted for API compatibility but unused by this simplified ConvNeXtBlock.
        bandwidth_id = kwargs.get("bandwidth_id", None)
        x = self.embed(x)
        x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x
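
# Illustrative sketch (not part of the original gist; the sizes are assumed values):
# the backbone consumes (B, input_channels, T) features and returns (B, T, dim)
# hidden states, ready for the ISTFTHead above.
# backbone = VocosBackbone(input_channels=100, dim=512, intermediate_dim=1536, num_layers=8)
# hidden = backbone(torch.randn(1, 100, 94))  # -> (1, 94, 512)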

class Vocos(nn.Module):
    """
    The Vocos class represents a Fourier-based neural vocoder for audio synthesis.

    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self,
        feature_extractor: MelSpectrogramFeatures,
        backbone: VocosBackbone,
        head: ISTFTHead,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_hparams(cls, config_path: str) -> Vocos:
        """
        Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = MelSpectrogramFeatures(
            **config["feature_extractor"]["init_args"]
        )
        backbone = VocosBackbone(**config["backbone"]["init_args"])
        head = ISTFTHead(**config["head"]["init_args"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model
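
    # Illustrative sketch (not part of the original gist) of the config layout that
    # from_hparams expects; the values below are assumptions modeled on the published
    # charactr/vocos-mel-24khz configuration and may differ from the real file:
    #
    # feature_extractor:
    #   init_args:
    #     sample_rate: 24000
    #     n_fft: 1024
    #     hop_length: 256
    #     n_mels: 100
    #     padding: center
    # backbone:
    #   init_args:
    #     input_channels: 100
    #     dim: 512
    #     intermediate_dim: 1536
    #     num_layers: 8
    # head:
    #   init_args:
    #     dim: 512
    #     n_fft: 1024
    #     hop_length: 256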

    @classmethod
    def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos:
        """
        Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
        """
        config_path = hf_hub_download(
            repo_id=repo_id, filename="config.yaml", revision=revision
        )
        model_path = hf_hub_download(
            repo_id=repo_id, filename="pytorch_model.bin", revision=revision
        )
        model = cls.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @torch.inference_mode()
    def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
        which is then passed through the backbone and the head to reconstruct the audio output.

        Args:
            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        features = self.feature_extractor(audio_input, **kwargs)
        audio_output = self.decode(features, **kwargs)
        return audio_output

    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output

# To use:
# vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# audio, _ = torchaudio.load("path/to/audio.wav")
# reconstructed_audio = vocos(audio)
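
# Illustrative sketch (not part of the original gist): the lines above assume the input
# file is already 24 kHz mono; otherwise it would need resampling first, and the
# reconstruction can be written back out with torchaudio. "output.wav" and the 24000
# sample rate are assumptions tied to the default MelSpectrogramFeatures config.
# audio, sr = torchaudio.load("path/to/audio.wav")
# if sr != 24000:
#     audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=24000)
# reconstructed_audio = vocos(audio)
# torchaudio.save("output.wav", reconstructed_audio, 24000)
#
# Pre-computed mel features can also be decoded directly:
# reconstructed_audio = vocos.decode(mels)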