Created April 1, 2025 20:40
-
-
Save lucasnewman/d89fcd7b45c9d21af6081eaa733b77ac to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mlx.core as mx
def istft(
    x: mx.array,  # (freq_bins, num_frames)
    window: mx.array,
    hop_length: int = 256,
    win_length: int = 1024,
):
    """Inverse STFT using a per-frame overlap-add loop.

    Each column of ``x`` is inverse-real-FFT'd into a frame of
    ``win_length`` samples. The frame is multiplied by ``window`` and
    accumulated into the output starting at sample ``i * hop_length``.
    The accumulated signal is then divided by the accumulated window
    wherever that sum is non-zero.

    Args:
        x: One-sided spectrum laid out as (freq_bins, num_frames); the
            inverse FFT is taken along axis 0.
        window: Synthesis window with ``win_length`` samples.
        hop_length: Number of samples between successive frame starts.
        win_length: Length in samples of each reconstructed frame.

    Returns:
        1-D array of length ``(num_frames - 1) * hop_length + win_length``.
    """
    num_frames = x.shape[1]
    total_len = (num_frames - 1) * hop_length + win_length

    # Output buffers (previously sized with the undefined name `t`).
    reconstructed = mx.zeros(total_len)
    window_sum = mx.zeros(total_len)

    # Inverse FFT of every frame, then transpose to (num_frames, win_length).
    # Passing n=win_length makes the frame length explicit; without it irfft
    # yields 2 * (freq_bins - 1) samples, which differs from win_length when
    # win_length is odd.
    frames = mx.fft.irfft(x, n=win_length, axis=0).transpose(1, 0)

    # Sample offsets within one frame; identical for every frame.
    frame_span = mx.arange(win_length)
    for i in range(num_frames):
        # Output positions covered by frame i.
        window_indices = i * hop_length + frame_span
        # Overlap-add the windowed frame; neighbouring frames overlap, so
        # contributions accumulate rather than overwrite.
        reconstructed = reconstructed.at[window_indices].add(frames[i] * window)
        window_sum = window_sum.at[window_indices].add(window)

    # Normalize by the accumulated window. Zero-sum positions are left as-is
    # so they do not become inf/nan.
    # NOTE(review): dividing by sum(window) (not sum(window**2)) reconstructs
    # exactly only if the forward STFT did not also window the frames;
    # torch.istft normalizes by the overlap-added window**2. Confirm which
    # convention callers expect.
    reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed)
    return reconstructed
Versus something like this vectorized version, which performs the whole overlap-add in a single scatter-add:
def istft(
    x: mx.array,  # (freq_bins, num_frames)
    window: mx.array,
    hop_length: int = 256,
    win_length: int = 1024,
):
    """Inverse STFT using a single vectorized overlap-add.

    Equivalent to looping over frames. It builds the output index of every
    (frame, sample) pair up front and accumulates all windowed frames in
    one scatter-add. The accumulated window sum is accumulated the same way
    and used for normalization wherever it is non-zero.

    Args:
        x: One-sided spectrum laid out as (freq_bins, num_frames); the
            inverse FFT is taken along axis 0.
        window: Synthesis window with ``win_length`` samples.
        hop_length: Number of samples between successive frame starts.
        win_length: Length in samples of each reconstructed frame.

    Returns:
        1-D array of length ``(num_frames - 1) * hop_length + win_length``.
    """
    num_frames = x.shape[1]
    total_len = (num_frames - 1) * hop_length + win_length

    # Inverse FFT of every frame, then transpose to (num_frames, win_length).
    # n=win_length keeps the frame length equal to the window length even
    # when win_length is odd.
    frames = mx.fft.irfft(x, n=win_length, axis=0).transpose(1, 0)

    # Row i holds the output positions i * hop_length + [0, win_length);
    # flattening in row-major order lines these up with the flattened frames.
    frame_offsets = mx.arange(num_frames) * hop_length
    indices = (frame_offsets[:, None] + mx.arange(win_length)).flatten()

    updates_reconstructed = (frames * window).flatten()
    # Tiling a 1-D window already yields a flat (num_frames * win_length,)
    # array in the same frame-major order as `indices`.
    updates_window = mx.tile(window, (num_frames,))

    # Scatter-add via `array.at[...].add`, MLX's documented Python scatter-add
    # API (it replaces the previous `mx.scatter_add` calls). Overlapping
    # frames produce repeated indices, and `.at[...].add` sums them rather
    # than keeping only one.
    reconstructed = mx.zeros(total_len).at[indices].add(updates_reconstructed)
    window_sum = mx.zeros(total_len).at[indices].add(updates_window)

    # Normalize by the accumulated window; zero-sum positions are left as-is.
    # NOTE(review): sum(window) rather than sum(window**2) matches an
    # unwindowed forward STFT; torch.istft uses the window**2 envelope.
    # Confirm the intended convention.
    reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed)
    return reconstructed
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment