Created February 28, 2025 18:45
Save lucasnewman/c57c75eb66a9f0c12a8b80f1cf955421 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import lru_cache | |
import librosa | |
import mlx.core as mx | |
import numpy as np | |
@lru_cache(maxsize=None)
def hanning(size):
    """
    Build a periodic (DFT-even) Hann window of the given length.

    A symmetric Hann window of ``size + 1`` points is computed with
    ``np.hanning`` and its final sample is dropped. This gives
    ``0.5 - 0.5 * cos(2 * pi * n / size)`` for ``n = 0 .. size - 1``.
    Results are memoized per ``size``, so repeated calls with the same size
    return the same cached ``mx.array`` object.

    Args:
        size: Number of samples in the window.

    Returns:
        mx.array of shape (size,) holding the window.
    """
    symmetric = np.hanning(size + 1)
    # The last point of the symmetric window repeats the first one; dropping
    # it makes the window periodic, which is what frame-based FFTs expect.
    periodic = symmetric[:-1]
    return mx.array(periodic)
def _pad(x, padding, pad_mode="constant"):
    """Pad a 1-D signal by ``padding`` samples on both ends.

    Args:
        x: 1-D mx.array to pad.
        padding: Number of samples to add before and after ``x``.
        pad_mode: ``"constant"`` pads with zeros via ``mx.pad``.
            ``"reflect"`` mirrors the signal about its first and last samples
            without repeating the edge sample itself.

    Returns:
        mx.array of length ``x.size + 2 * padding``.

    Raises:
        ValueError: If ``pad_mode`` is unknown, or if ``"reflect"`` is used with
            ``padding >= x.size``. In that case there are too few interior
            samples to mirror, and the slices below would silently come back
            short.
    """
    if pad_mode == "constant":
        return mx.pad(x, [(padding, padding)])
    if pad_mode == "reflect":
        if padding >= x.size:
            raise ValueError(
                f"reflect padding of {padding} needs a signal longer than "
                f"{padding} samples, got {x.size}"
            )
        prefix = x[1 : padding + 1][::-1]
        suffix = x[-(padding + 1) : -1][::-1]
        return mx.concatenate([prefix, x, suffix])
    raise ValueError(f"Invalid pad_mode {pad_mode}")


def stft(
    x,
    window,
    nperseg=256,
    noverlap=None,
    nfft=None,
    pad_mode="constant",
):
    """Short-time Fourier transform of a 1-D signal.

    The steps are:

    1. Pad the signal with ``nperseg // 2`` samples on each side.
    2. Cut it into frames of ``nperseg`` samples spaced ``noverlap`` samples
       apart.
    3. Multiply each frame by ``window``.
    4. Apply a real FFT of length ``nfft``. Frames are zero-padded at the end
       when ``nfft > nperseg``.

    NOTE: despite the scipy-style name, ``noverlap`` is the *hop length* (the
    stride between consecutive frame starts), not the overlap between frames.

    Args:
        x: 1-D mx.array holding the signal.
        window: mx.array that must broadcast against the ``(frames, nperseg)``
            frame matrix, e.g. ``hanning(nperseg)``.
        nperseg: Number of signal samples per frame.
        noverlap: Hop length in samples. Defaults to ``nfft // 4``.
        nfft: FFT length. Must be ``>= nperseg``. Defaults to ``nperseg``.
        pad_mode: ``"constant"`` (zeros) or ``"reflect"``.

    Returns:
        Complex mx.array of shape ``(frames, nfft // 2 + 1)``, where
        ``frames = (padded_len - nperseg + hop) // hop`` and
        ``padded_len = x.size + 2 * (nperseg // 2)``.

    Raises:
        ValueError: If ``x`` is not 1-D, ``nfft < nperseg``, the hop length is
            not positive, or padding fails (see ``_pad``).
    """
    if nfft is None:
        nfft = nperseg
    if noverlap is None:
        noverlap = nfft // 4
    hop = noverlap

    if x.ndim != 1:
        raise ValueError(f"stft expects a 1-D signal, got shape {x.shape}")
    if nfft < nperseg:
        raise ValueError(f"nfft ({nfft}) must be >= nperseg ({nperseg})")
    if hop <= 0:
        # Reached by default when nfft < 4; it would otherwise divide by zero.
        raise ValueError(f"hop length (noverlap) must be positive, got {hop}")

    x = _pad(x, nperseg // 2, pad_mode)

    n_frames = (x.size - nperseg + hop) // hop
    # Frames are exactly nperseg wide, so no row reads past the end of the
    # padded buffer. (Using nfft here over-read whenever nfft > nperseg.)
    frames = mx.as_strided(x, shape=[n_frames, nperseg], strides=[hop, 1])
    # rfft with n=nfft zero-pads each windowed frame out to the FFT length.
    return mx.fft.rfft(frames * window, n=nfft)
def main():
    """Compare the MLX ``stft`` against ``librosa.stft`` on random data.

    Prints both spectrogram shapes and values, then whether they agree
    within ``rtol=1e-5, atol=1e-5``.
    """
    # Generate random audio-shaped data: 44100 uniform samples.
    x = mx.random.uniform(shape=(44100,))

    # stft returns (frames, freq_bins); transpose to librosa's
    # (freq_bins, frames) layout so the two can be compared directly.
    out = stft(x, hanning(256)).transpose(1, 0)
    print(out.shape)

    # Pin pad_mode to match stft's default "constant" padding rather than
    # relying on librosa's default, which has differed across releases.
    out2 = librosa.stft(
        np.array(x),
        n_fft=256,
        hop_length=256 // 4,
        pad_mode="constant",
    )
    print(out2.shape)

    print(np.array(out))
    print("----")
    print(out2)

    # Previously this result was computed and thrown away; report it.
    matches = mx.allclose(out, mx.array(out2), rtol=1e-5, atol=1e-5)
    print(f"allclose: {matches.item()}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.