Skip to content

Instantly share code, notes, and snippets.

@lucasnewman
Created February 28, 2025 18:45
Show Gist options
  • Save lucasnewman/c57c75eb66a9f0c12a8b80f1cf955421 to your computer and use it in GitHub Desktop.
Save lucasnewman/c57c75eb66a9f0c12a8b80f1cf955421 to your computer and use it in GitHub Desktop.
from functools import lru_cache
import librosa
import mlx.core as mx
import numpy as np
@lru_cache(maxsize=None)
def hanning(size):
    """
    Build a periodic Hann window of the requested length.

    NumPy's ``np.hanning`` returns a symmetric window; generating
    ``size + 1`` points and discarding the final sample gives the periodic
    (DFT-even) variant. The result is memoized per ``size``, so repeated
    calls with the same length reuse the same array.

    Args:
        size: Number of samples in the window.

    Returns:
        mx.array of shape (size,) holding the window coefficients.
    """
    symmetric_window = np.hanning(size + 1)
    periodic_window = symmetric_window[:-1]
    return mx.array(periodic_window)
def stft(
    x,
    window,
    nperseg=256,
    noverlap=None,
    nfft=None,
    pad_mode="constant",
):
    """
    Compute a centered short-time Fourier transform of a 1-D signal.

    The signal is padded by ``nperseg // 2`` samples on each side, cut into
    frames of ``nperseg`` samples whose starts are ``noverlap`` samples apart,
    multiplied by ``window``, and transformed with a real FFT of length
    ``nfft`` (frames are zero-padded when ``nfft > nperseg``).

    Args:
        x: 1-D mx.array holding the signal.
        window: mx.array of length ``nperseg`` applied to every frame.
        nperseg: Number of samples per frame.
        noverlap: Hop length, i.e. the distance in samples between the starts
            of consecutive frames. Despite its name this is NOT the overlap
            (the overlap is ``nperseg - noverlap``); the name is kept for
            backward compatibility. Defaults to ``nfft // 4``.
        nfft: FFT length, must be >= ``nperseg``. Defaults to ``nperseg``.
        pad_mode: ``"constant"`` (zero padding) or ``"reflect"``.

    Returns:
        mx.array of shape (num_frames, nfft // 2 + 1) with the complex
        spectrum of each frame.

    Raises:
        ValueError: for an unknown ``pad_mode``, a non-positive hop length,
            ``nfft < nperseg``, a signal too short for reflect padding, or a
            signal too short to produce a single frame.
    """
    if nfft is None:
        nfft = nperseg
    if noverlap is None:
        noverlap = nfft // 4
    if nfft < nperseg:
        raise ValueError(f"nfft ({nfft}) must be >= nperseg ({nperseg})")
    if noverlap <= 0:
        # noverlap is used as the hop (stride); 0 would divide by zero below.
        raise ValueError(f"hop length (noverlap) must be positive, got {noverlap}")

    def _pad(x, padding, pad_mode="constant"):
        # Pad both ends of the 1-D signal by `padding` samples.
        if pad_mode == "constant":
            return mx.pad(x, [(padding, padding)])
        elif pad_mode == "reflect":
            # A single reflection (edge sample excluded) needs padding < x.size;
            # otherwise the slices below come back short and the result has the
            # wrong length.
            if padding > 0 and padding >= x.size:
                raise ValueError(
                    f"reflect padding of {padding} requires a signal longer "
                    f"than {padding} samples, got {x.size}"
                )
            prefix = x[1 : padding + 1][::-1]
            suffix = x[-(padding + 1) : -1][::-1]
            return mx.concatenate([prefix, x, suffix])
        else:
            raise ValueError(f"Invalid pad_mode {pad_mode}")

    padding = nperseg // 2
    x = _pad(x, padding, pad_mode)

    # Number of full nperseg-sample frames that fit in the padded signal.
    num_frames = (x.size - nperseg) // noverlap + 1
    if num_frames <= 0:
        raise ValueError(
            f"signal of {x.size} padded samples is too short for a "
            f"{nperseg}-sample frame"
        )

    # Overlapping frame view: row i starts at sample i * noverlap (strides are
    # in elements). Frames are nperseg wide so they match the window length.
    frames = mx.as_strided(x, shape=[num_frames, nperseg], strides=[noverlap, 1])
    # n=nfft zero-pads each windowed frame to the FFT length.
    return mx.fft.rfft(frames * window, n=nfft)
# Demo: compare the MLX stft() above against librosa.stft on random
# audio-shaped data (one second of samples at 44.1 kHz).
N_FFT = 256
HOP_LENGTH = N_FFT // 4

x = mx.random.uniform(shape=(44100,))
# stft() returns (frames, freq); transpose to librosa's (freq, frames) layout.
out = stft(x, hanning(N_FFT), nperseg=N_FFT, noverlap=HOP_LENGTH).transpose(1, 0)
print(out.shape)
# Pass pad_mode explicitly so librosa pads the same way as stft() above;
# librosa's default pad_mode has changed between versions.
out2 = librosa.stft(
    np.array(x), n_fft=N_FFT, hop_length=HOP_LENGTH, pad_mode="constant"
)
print(out2.shape)
print(np.array(out))
print("----")
print(out2)
# Report the comparison result; previously it was computed and discarded.
print("allclose:", mx.allclose(out, mx.array(out2), rtol=1e-5, atol=1e-5).item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment