Skip to content

Instantly share code, notes, and snippets.

@alexarmbr
Created April 23, 2025 17:21
Show Gist options
  • Save alexarmbr/decc803be417b2639889933aa9c17ce0 to your computer and use it in GitHub Desktop.
Save alexarmbr/decc803be417b2639889933aa9c17ce0 to your computer and use it in GitHub Desktop.
A decorator that implements the algorithm explained in "From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers"
import functools
import torch
import math
def taylor_seer_approximation(WARMUP_STEPS=1, SKIP_INTERVAL_STEPS=1, compute_step_map=None, n_derivatives=2, verbose=False):
    """
    A class decorator that approximates the forward pass of an nn.Module to reduce computation.

    Implements the forecasting scheme from "From Reusing to Forecasting: Accelerating
    Diffusion Models with TaylorSeers". On "full" steps the wrapped module's real
    ``forward`` runs, and its output is cached together with finite-difference
    estimates of its derivatives with respect to the step index. On all other steps
    the output is extrapolated from that cache with a truncated Taylor series.

    Args:
        WARMUP_STEPS: the real forward pass runs on calls 0..WARMUP_STEPS inclusive
            (calls are counted from 0), i.e. on the first WARMUP_STEPS + 1 calls.
            Clamped to >= 1.
        SKIP_INTERVAL_STEPS: after warmup, the real forward pass runs on every
            SKIP_INTERVAL_STEPS-th call. Clamped to >= 1; a value of 1 means every
            call runs the real forward pass.
        compute_step_map: optional non-empty sequence of booleans. Entry
            ``step % len(compute_step_map)`` decides whether call number ``step``
            runs the real forward pass. When given, WARMUP_STEPS and
            SKIP_INTERVAL_STEPS are ignored.
        n_derivatives: number of derivatives to estimate (the Taylor order).
            Clamped to >= 0. Until enough full passes have run to estimate all of
            them, extrapolation uses only the derivatives estimated so far.
        verbose: if True, print the step number, the compute decision and the
            cached state on every call.

    Returns:
        A decorator for an nn.Module subclass. It wraps ``__init__`` and ``forward``
        and adds a ``reset_taylor_seer_state()`` method. That method restarts the
        step counter and drops cached values, e.g. between independent sampling runs.

    Raises:
        ValueError: if ``compute_step_map`` is given but empty.
    """
    # make sure the warmup and skip interval are at least 1
    WARMUP_STEPS = max(int(WARMUP_STEPS), 1)
    SKIP_INTERVAL_STEPS = max(int(SKIP_INTERVAL_STEPS), 1)
    if compute_step_map is not None and len(compute_step_map) == 0:
        # an empty map would otherwise fail inside forward() with ZeroDivisionError
        raise ValueError("compute_step_map must contain at least one entry")
    # 'order' of the taylor approximation is the value itself, plus however many
    # derivatives we are approximating
    ORDER = max(int(n_derivatives), 0) + 1

    def decorator(cls):
        original_init = cls.__init__
        original_forward = cls.forward

        def reset_taylor_seer_state(self):
            """Drop all cached outputs/derivatives and restart the step counter."""
            self.state = {
                'dY_prev': [None] * ORDER,
                'dY_current': [None] * ORDER,
            }
            # incremented at the start of every forward call, so the first call is step 0
            self.step_count = -1
            # number of calls since the last full forward pass
            self.finite_difference_window = 0

        @functools.wraps(original_init)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            reset_taylor_seer_state(self)

        def _should_compute_full(step):
            """Return whether call number `step` (0-based) runs the real forward pass."""
            if compute_step_map is not None:
                return compute_step_map[step % len(compute_step_map)]
            # full pass for steps 0..WARMUP_STEPS, then every SKIP_INTERVAL_STEPS-th step after that
            return step <= WARMUP_STEPS or (step - WARMUP_STEPS) % SKIP_INTERVAL_STEPS == 0

        def _approximate_derivative(Y, dY_prev, elapsed_steps):
            """
            Approximate the derivatives of Y from the previously cached derivatives.

            Args:
                Y: current value of the feature, i.e. Y=f(X) where f could be a transformer or linear layer
                dY_prev: the derivative list cached at the previous full step
                elapsed_steps: number of steps between Y and dY_prev
            Returns:
                A list of length ORDER: [Y, dY, d2Y, ...]. Entries that cannot be
                estimated yet (too few previous full steps) are None; they always
                form a suffix of the list.
            """
            dY_current = [None] * ORDER
            dY_current[0] = Y
            for i in range(1, ORDER):
                if dY_prev[i - 1] is None:
                    # higher orders depend on this one, so none of them are available either
                    break
                # estimate current derivative using finite difference method
                # equation (7) from the paper, along with scaling factor N^i from equation (9)
                # the N^i is done implicitly by elapsed_steps being inside the loop
                # https://github.com/Shenyi-Z/TaylorSeer/issues/11
                dY_current[i] = (dY_current[i - 1] - dY_prev[i - 1]) / elapsed_steps
            return dY_current

        def _approximate_value(dY_current, elapsed_steps):
            """
            Extrapolate Y `elapsed_steps` steps past the last full step with a Taylor series.

            Only the leading non-None derivatives are used, so the expansion is
            truncated to the order that has actually been estimated.
            """
            terms = []
            for i, dY in enumerate(dY_current):
                if dY is None:
                    break
                terms.append((1 / math.factorial(i)) * dY * (elapsed_steps ** i))
            return sum(terms)

        @functools.wraps(original_forward)
        def new_forward(self, *args, **kwargs):
            self.step_count += 1
            self.finite_difference_window += 1
            compute_full = _should_compute_full(self.step_count)
            if compute_full:
                # compute actual forward pass
                Y = original_forward(self, *args, **kwargs)
                if not isinstance(Y, torch.Tensor):
                    # the Taylor arithmetic needs a single tensor output
                    raise TypeError(
                        "taylor_seer_approximation expects forward() to return a "
                        f"torch.Tensor, got {type(Y).__name__}"
                    )
                self.state['dY_prev'] = self.state['dY_current']
                # calculate and update derivative based on present model output and previous derivatives
                self.state['dY_current'] = _approximate_derivative(
                    Y, self.state['dY_prev'], self.finite_difference_window
                )
                # reset the finite difference window
                self.finite_difference_window = 0
                if verbose:
                    print(f"step {self.step_count}, compute_full: {compute_full} state: {self.state}")
                return Y
            # approximate the value of the forward pass using the derivatives
            # computed 'finite_difference_window' steps ago
            if self.state['dY_current'][0] is None:
                raise RuntimeError(
                    f"taylor_seer_approximation: step {self.step_count} was scheduled for "
                    "approximation, but no full forward pass has been computed yet"
                )
            if verbose:
                print(f"step {self.step_count}, compute_full: {compute_full} state: {self.state}")
            return _approximate_value(self.state['dY_current'], self.finite_difference_window)

        # Replace methods
        cls.__init__ = new_init
        cls.forward = new_forward
        cls.reset_taylor_seer_state = reset_taylor_seer_state
        return cls

    return decorator
if __name__ == "__main__":
    # Demo: run the real forward pass for steps 0-2, then let the decorator
    # extrapolate the output for steps 3-9 from the cached derivatives.
    n_full_steps, n_approx_steps = 3, 7
    compute_step_map = [True] * n_full_steps + [False] * n_approx_steps

    @taylor_seer_approximation(compute_step_map=compute_step_map, n_derivatives=2)
    class X_Squared(torch.nn.Module):
        def forward(self, x):
            return x**2

    x_squared = X_Squared()
    for step in range(len(compute_step_map)):
        X = torch.tensor(step, dtype=torch.float32)
        Y = x_squared(X)
        print(f"step {step}, X: {X}, Y: {Y}")
        print("-------------------------")
@daanelson
Copy link

This is great! I think the root issue here stems from how timesteps actually work in a diffusion model. When we generate an image with Flux, we run through a small, integer number of steps (~28).

But each step also corresponds to an actual value of t in the underlying function that we're modeling with diffusion; t generally goes from 1 -> 0 as the step index goes from 0 -> 28 (or whatever). This is mildly confusing, but think of the model as attempting to reverse the process of gradually adding noise to an image over time, such that if the model is some f and the image is x, f(x, t=1) is noise and f(x, t=0) is the image.

To add further complexity, these timesteps don't map linearly to the diffusion steps; generally, the earlier diffusion steps have a smaller delta t than the later steps. You can pull the actual numbers out of the scheduler you're using; in our default Flux implementation, timesteps[14] (i.e., the timestep halfway through the generation) is equal to 0.75.

All of this is to say that self.finite_difference_window should be calculated by storing something like self.last_timestep_actually_run whenever you call original_forward, and then setting self.finite_difference_window = current_timestep - self.last_timestep_actually_run.

@alexarmbr
Copy link
Author

Thanks for taking a look! I am trying to reproduce results from the reference implementation, which does all of its bookkeeping using the index of the diffusion step, which goes from 0 -> 28 in steps of 1. So I think this part is correct.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment