@lucasnewman
Created May 5, 2025 00:05
diff --git a/mlx_audio/tts/models/spark/bicodec.py b/mlx_audio/tts/models/spark/bicodec.py
index 8de944b..1e2a320 100644
--- a/mlx_audio/tts/models/spark/bicodec.py
+++ b/mlx_audio/tts/models/spark/bicodec.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -8,7 +8,7 @@ import torch
from omegaconf import DictConfig
from safetensors.torch import load_file
-from mlx_audio.codec.models.vocos.mel import log_mel_spectrogram
+from mlx_audio.codec.models.vocos.mel import hanning, mel_filters, stft
from mlx_audio.tts.models.spark.modules.encoder_decoder.feat_decoder import Decoder
from mlx_audio.tts.models.spark.modules.encoder_decoder.feat_encoder import Encoder
from mlx_audio.tts.models.spark.modules.encoder_decoder.wave_generator import (
@@ -19,6 +19,35 @@ from mlx_audio.tts.models.spark.modules.speaker.speaker_encoder import SpeakerEn
from mlx_audio.tts.models.spark.utils.file import load_config
from mlx_audio.tts.utils import get_model_path
+def log_mel_spectrogram(
+    audio: mx.array,
+    sample_rate: int = 16_000,
+    n_mels: int = 128,
+    n_fft: int = 1024,
+    f_min: int = 10,
+    f_max: Optional[int] = None,
+    hop_length: int = 320,
+    win_length: int = 640,
+    padding: int = 0,
+):
+    if not isinstance(audio, mx.array):
+        audio = mx.array(audio)
+    if padding > 0:
+        audio = mx.pad(audio, (0, padding))
+    freqs = stft(audio, hanning(win_length), nperseg=n_fft, noverlap=hop_length)
+    magnitudes = freqs[:-1, :].abs()
+    filters = mel_filters(
+        sample_rate=sample_rate,
+        n_fft=n_fft,
+        n_mels=n_mels,
+        f_min=f_min,
+        f_max=f_max,
+        norm="slaney",
+        mel_scale="slaney",
+    )
+    mel_spec = magnitudes @ filters.T
+    log_spec = mx.maximum(mel_spec, 1e-5).log()
+    return mx.expand_dims(log_spec, axis=0)
class BiCodec(nn.Module):
"""
@@ -101,7 +130,7 @@ class BiCodec(nn.Module):
if hasattr(module, "sanitize"):
weights = module.sanitize(weights)
- model.load_weights(list(weights.items()), strict=False)
+ model.load_weights(list(weights.items()))
return model
@@ -204,6 +233,9 @@ class BiCodec(nn.Module):
n_mels=self.mel_params["num_mels"],
n_fft=self.mel_params["n_fft"],
hop_length=self.mel_params["hop_length"],
+ win_length=self.mel_params["win_length"],
+ f_min=self.mel_params["f_min"],
+ f_max=self.mel_params["f_max"],
)
mels.append(mel)
return mx.concatenate(mels, axis=0)
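
A minimal sketch (not part of the patch) of calling the inlined log_mel_spectrogram helper once the change above is applied; the waveform is placeholder noise and the keyword values simply mirror the defaults added in the hunk.

import mlx.core as mx

from mlx_audio.tts.models.spark.bicodec import log_mel_spectrogram

wav = mx.random.normal((16_000,))  # roughly 1 s of placeholder 16 kHz mono audio
mel = log_mel_spectrogram(
    wav,
    sample_rate=16_000,
    n_mels=128,
    n_fft=1024,
    hop_length=320,
    win_length=640,
)
print(mel.shape)  # expected to be about (1, num_frames, 128) given the trailing expand_dims
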
diff --git a/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py b/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py
index 8de4769..4011425 100644
--- a/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py
+++ b/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py
@@ -104,7 +104,7 @@ class FSQ(nn.Module):
if return_indices:
self.codebook_size = self._levels.prod().item()
implicit_codebook = self._indices_to_codes(mx.arange(self.codebook_size))
- self.implicit_codebook = implicit_codebook
+ self._implicit_codebook = implicit_codebook
self.allowed_dtypes = allowed_dtypes
self.force_quantization_f32 = force_quantization_f32
diff --git a/mlx_audio/tts/models/spark/modules/residual_fsq.py b/mlx_audio/tts/models/spark/modules/residual_fsq.py
index 21bf453..517b83d 100644
--- a/mlx_audio/tts/models/spark/modules/residual_fsq.py
+++ b/mlx_audio/tts/models/spark/modules/residual_fsq.py
@@ -73,7 +73,7 @@ class ResidualFSQ(nn.Module):
self.codebook_size = self.layers[0].codebook_size
- self.scales = mx.array(scales)
+ self._scales = mx.array(scales)
self.quantize_dropout = quantize_dropout and num_quantizers > 1
@@ -143,7 +143,7 @@ class ResidualFSQ(nn.Module):
# scale the codes
# Reshape scales for broadcasting: q 1 1 d
scales = mx.reshape(
- self.scales, (self.scales.shape[0], 1, 1, self.scales.shape[1])
+ self._scales, (self._scales.shape[0], 1, 1, self._scales.shape[1])
)
all_codes = all_codes * scales
@@ -229,7 +229,7 @@ class ResidualFSQ(nn.Module):
null_indices = mx.full(x.shape[:2], -1, dtype=mx.int32)
# go through the layers
- for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self.scales)):
+ for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self._scales)):
if (
should_quantize_dropout
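
The attribute renames in the two quantizer modules (implicit_codebook -> _implicit_codebook, scales -> _scales) appear to pair with dropping strict=False from load_weights in bicodec.py: in MLX, module attributes whose names start with an underscore are treated as private and are not collected by parameters(), so these derived buffers should no longer show up in the set of weights a strict load expects the checkpoint to provide. A rough, self-contained sketch of that behaviour; the class and values here are illustrative, not taken from the patch.

# Illustrative only: underscore-prefixed array attributes are not registered as
# parameters in MLX, so a strict load_weights will not expect them in the checkpoint.
import mlx.core as mx
import mlx.nn as nn


class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = mx.zeros((4, 4))  # public: collected by parameters()
        self._scales = mx.ones((2,))    # private: skipped by parameters()


m = Example()
print(list(m.parameters().keys()))  # expected: ['weight'] (no '_scales')
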