Skip to content

Instantly share code, notes, and snippets.

@lucasnewman
Created April 1, 2025 20:40
Show Gist options
  • Save lucasnewman/d89fcd7b45c9d21af6081eaa733b77ac to your computer and use it in GitHub Desktop.
Save lucasnewman/d89fcd7b45c9d21af6081eaa733b77ac to your computer and use it in GitHub Desktop.
import mlx.core as mx
def istft(
    x: mx.array,  # (freq_bins, num_frames)
    window: mx.array,
    hop_length: int = 256,
    win_length: int = 1024,
):
    """Inverse STFT using a per-frame overlap-add loop.

    Args:
        x: Spectrogram laid out as (freq_bins, num_frames); frames run along
            axis 1 and are inverse-transformed with ``mx.fft.irfft`` along
            axis 0.
        window: Synthesis window multiplied into every inverse-transformed
            frame. It must broadcast against one irfft output frame.
            NOTE(review): presumably 1-D of length ``win_length`` with the
            irfft output length equal to ``win_length``; confirm.
        hop_length: Number of output samples between consecutive frame starts.
        win_length: Number of output samples each frame is added into.

    Returns:
        1-D array of length ``(num_frames - 1) * hop_length + win_length``:
        the overlap-added frames divided by the accumulated window values
        wherever that accumulated sum is non-zero (left unscaled elsewhere).
    """
    num_frames = x.shape[1]
    # Output spans from the first frame's start to the last frame's end.
    total_len = (num_frames - 1) * hop_length + win_length
    reconstructed = mx.zeros(total_len)
    window_sum = mx.zeros(total_len)

    # Inverse FFT of every frame at once, then transpose so y[i] is frame i.
    y = mx.fft.irfft(x, axis=0).transpose(1, 0)

    # Sample positions within one frame; loop-invariant, so build once.
    frame_positions = mx.arange(win_length)
    for i in range(num_frames):
        # Output positions covered by frame i.
        window_indices = i * hop_length + frame_positions
        # .at[...].add accumulates into positions shared with neighbouring
        # frames, which is the overlap-add.
        reconstructed = reconstructed.at[window_indices].add(y[i] * window)
        window_sum = window_sum.at[window_indices].add(window)

    # NOTE(review): when the forward STFT also applied the window, the usual
    # inverse divides by the sum of *squared* windows; this divides by the
    # plain window sum. Confirm which matches the forward transform in use.
    # Use 1 as the denominator where the sum is zero: those samples come out
    # unscaled (same result as selecting `reconstructed` there) and no x/0 is
    # ever evaluated, so no inf/NaN is produced even in a discarded branch.
    safe_sum = mx.where(window_sum != 0, window_sum, 1)
    return reconstructed / safe_sum
versus a vectorized version that performs the whole overlap-add as one scatter-add instead of a per-frame loop:
def istft(
    x: mx.array,  # (freq_bins, num_frames)
    window: mx.array,
    hop_length: int = 256,
    win_length: int = 1024,
):
    """Inverse STFT using a single vectorized overlap-add.

    Args:
        x: Spectrogram laid out as (freq_bins, num_frames); frames run along
            axis 1 and are inverse-transformed with ``mx.fft.irfft`` along
            axis 0.
        window: Synthesis window multiplied into every inverse-transformed
            frame. NOTE(review): presumably 1-D of length ``win_length``, and
            the irfft output length must equal ``win_length`` so that the
            flattened updates line up with the flattened indices; confirm.
        hop_length: Number of output samples between consecutive frame starts.
        win_length: Number of output samples each frame is added into.

    Returns:
        1-D array of length ``(num_frames - 1) * hop_length + win_length``:
        the overlap-added frames divided by the accumulated window values
        wherever that accumulated sum is non-zero (left unscaled elsewhere).
    """
    num_frames = x.shape[1]
    # Output spans from the first frame's start to the last frame's end.
    total_len = (num_frames - 1) * hop_length + win_length

    # Inverse FFT of every frame at once, then transpose so y[i] is frame i.
    y = mx.fft.irfft(x, axis=0).transpose(1, 0)

    # time_indices[i, j] = i * hop_length + j: the output sample that sample
    # j of frame i is added into.
    frame_offsets = mx.arange(num_frames) * hop_length
    time_indices = frame_offsets[:, None] + mx.arange(win_length)
    indices = time_indices.flatten()

    # Row-major flattening keeps each update aligned with its index above.
    frame_updates = (y * window).flatten()
    window_updates = mx.tile(window, (num_frames,)).flatten()

    # array.at[...].add is MLX's scatter-add entry point: values landing on
    # the same index are summed, which is exactly the overlap-add.
    reconstructed = mx.zeros(total_len).at[indices].add(frame_updates)
    window_sum = mx.zeros(total_len).at[indices].add(window_updates)

    # NOTE(review): when the forward STFT also applied the window, the usual
    # inverse divides by the sum of *squared* windows; this divides by the
    # plain window sum. Confirm which matches the forward transform in use.
    # Use 1 as the denominator where the sum is zero: those samples come out
    # unscaled (same result as selecting `reconstructed` there) and no x/0 is
    # ever evaluated, so no inf/NaN is produced even in a discarded branch.
    safe_sum = mx.where(window_sum != 0, window_sum, 1)
    return reconstructed / safe_sum
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment