Created
May 5, 2025 00:05
-
-
Save lucasnewman/56a057d8ba42f1e0d3d0b4ed3eafc1cf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/mlx_audio/tts/models/spark/bicodec.py b/mlx_audio/tts/models/spark/bicodec.py | |
index 8de944b..1e2a320 100644 | |
--- a/mlx_audio/tts/models/spark/bicodec.py | |
+++ b/mlx_audio/tts/models/spark/bicodec.py | |
@@ -1,5 +1,5 @@ | |
from pathlib import Path | |
-from typing import Any, Dict | |
+from typing import Any, Dict, Optional | |
import mlx.core as mx | |
import mlx.nn as nn | |
@@ -8,7 +8,7 @@ import torch | |
from omegaconf import DictConfig | |
from safetensors.torch import load_file | |
-from mlx_audio.codec.models.vocos.mel import log_mel_spectrogram | |
+from mlx_audio.codec.models.vocos.mel import hanning, mel_filters, stft | |
from mlx_audio.tts.models.spark.modules.encoder_decoder.feat_decoder import Decoder | |
from mlx_audio.tts.models.spark.modules.encoder_decoder.feat_encoder import Encoder | |
from mlx_audio.tts.models.spark.modules.encoder_decoder.wave_generator import ( | |
@@ -19,6 +19,35 @@ from mlx_audio.tts.models.spark.modules.speaker.speaker_encoder import SpeakerEn | |
from mlx_audio.tts.models.spark.utils.file import load_config | |
from mlx_audio.tts.utils import get_model_path | |
+def log_mel_spectrogram( | |
+ audio: mx.array, | |
+ sample_rate: int = 16_000, | |
+ n_mels: int = 128, | |
+ n_fft: int = 1024, | |
+ f_min: int = 10, | |
+ f_max: Optional[int] = None, | |
+ hop_length: int = 320, | |
+ win_length: int = 640, | |
+ padding: int = 0, | |
+): | |
+ if not isinstance(audio, mx.array): | |
+ audio = mx.array(audio) | |
+ if padding > 0: | |
+ audio = mx.pad(audio, (0, padding)) | |
+ freqs = stft(audio, hanning(win_length), nperseg=n_fft, noverlap=hop_length) | |
+ magnitudes = freqs[:-1, :].abs() | |
+ filters = mel_filters( | |
+ sample_rate=sample_rate, | |
+ n_fft=n_fft, | |
+ n_mels=n_mels, | |
+ f_min=f_min, | |
+ f_max=f_max, | |
+ norm="slaney", | |
+ mel_scale="slaney", | |
+ ) | |
+ mel_spec = magnitudes @ filters.T | |
+ log_spec = mx.maximum(mel_spec, 1e-5).log() | |
+ return mx.expand_dims(log_spec, axis=0) | |
class BiCodec(nn.Module): | |
""" | |
@@ -101,7 +130,7 @@ class BiCodec(nn.Module): | |
if hasattr(module, "sanitize"): | |
weights = module.sanitize(weights) | |
- model.load_weights(list(weights.items()), strict=False) | |
+ model.load_weights(list(weights.items())) | |
return model | |
@@ -204,6 +233,9 @@ class BiCodec(nn.Module): | |
n_mels=self.mel_params["num_mels"], | |
n_fft=self.mel_params["n_fft"], | |
hop_length=self.mel_params["hop_length"], | |
+ win_length=self.mel_params["win_length"], | |
+ f_min=self.mel_params["f_min"], | |
+ f_max=self.mel_params["f_max"], | |
) | |
mels.append(mel) | |
return mx.concatenate(mels, axis=0) | |
diff --git a/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py b/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py | |
index 8de4769..4011425 100644 | |
--- a/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py | |
+++ b/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py | |
@@ -104,7 +104,7 @@ class FSQ(nn.Module): | |
if return_indices: | |
self.codebook_size = self._levels.prod().item() | |
implicit_codebook = self._indices_to_codes(mx.arange(self.codebook_size)) | |
- self.implicit_codebook = implicit_codebook | |
+ self._implicit_codebook = implicit_codebook | |
self.allowed_dtypes = allowed_dtypes | |
self.force_quantization_f32 = force_quantization_f32 | |
diff --git a/mlx_audio/tts/models/spark/modules/residual_fsq.py b/mlx_audio/tts/models/spark/modules/residual_fsq.py | |
index 21bf453..517b83d 100644 | |
--- a/mlx_audio/tts/models/spark/modules/residual_fsq.py | |
+++ b/mlx_audio/tts/models/spark/modules/residual_fsq.py | |
@@ -73,7 +73,7 @@ class ResidualFSQ(nn.Module): | |
self.codebook_size = self.layers[0].codebook_size | |
- self.scales = mx.array(scales) | |
+ self._scales = mx.array(scales) | |
self.quantize_dropout = quantize_dropout and num_quantizers > 1 | |
@@ -143,7 +143,7 @@ class ResidualFSQ(nn.Module): | |
# scale the codes | |
# Reshape scales for broadcasting: q 1 1 d | |
scales = mx.reshape( | |
- self.scales, (self.scales.shape[0], 1, 1, self.scales.shape[1]) | |
+ self._scales, (self._scales.shape[0], 1, 1, self._scales.shape[1]) | |
) | |
all_codes = all_codes * scales | |
@@ -229,7 +229,7 @@ class ResidualFSQ(nn.Module): | |
null_indices = mx.full(x.shape[:2], -1, dtype=mx.int32) | |
# go through the layers | |
- for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self.scales)): | |
+ for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self._scales)): | |
if ( | |
should_quantize_dropout |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment