A decorator that implements the algorithm explained in "From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers"
import functools
import math

import torch


def taylor_seer_approximation(WARMUP_STEPS=1, SKIP_INTERVAL_STEPS=1, compute_step_map=None, n_derivatives=2):
    """
    A decorator that approximates the forward pass of an nn.Module to reduce computation.

    Args:
        WARMUP_STEPS: number of steps to compute the actual forward pass before starting to approximate
        SKIP_INTERVAL_STEPS: after warmup, compute the actual forward pass every SKIP_INTERVAL_STEPS steps
        compute_step_map: a list of booleans indicating, for each step, whether to compute the actual
            forward pass; if compute_step_map is passed, WARMUP_STEPS and SKIP_INTERVAL_STEPS are ignored
        n_derivatives: the number of derivatives to approximate

    Returns:
        A decorator function that can be applied to an nn.Module
    """
    # make sure the warmup and skip interval are at least 1
    WARMUP_STEPS = max(int(WARMUP_STEPS), 1)
    SKIP_INTERVAL_STEPS = max(int(SKIP_INTERVAL_STEPS), 1)

    # 'order' of the taylor approximation is the value itself, plus however many
    # derivatives we are approximating
    ORDER = n_derivatives + 1
    def decorator(cls):
        original_init = cls.__init__
        original_forward = cls.forward

        @functools.wraps(cls.__init__)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            self.state = {
                'dY_prev': [None] * ORDER,
                'dY_current': [None] * ORDER,
            }
            self.step_count = -1
            self.finite_difference_window = 0

        def _should_compute_full(step):
            if compute_step_map is not None:
                return compute_step_map[step % len(compute_step_map)]
            # compute the actual forward pass for the first warmup steps,
            # and then every SKIP_INTERVAL_STEPS after that
            return (step <= WARMUP_STEPS or
                    (step - WARMUP_STEPS) % SKIP_INTERVAL_STEPS == 0)
        def _approximate_derivative(Y, dY_prev, elapsed_steps):
            """
            Approximate the derivatives of Y using the previous derivatives.

            Args:
                Y: current value of the feature, i.e. Y = f(X) where f could be a transformer or linear layer
                dY_prev: the values of the derivatives of Y, computed elapsed_steps steps ago
                elapsed_steps: number of steps between Y and dY_prev
            """
            dY_current = [None] * ORDER
            dY_current[0] = Y
            for i in range(1, ORDER):
                if dY_prev[i - 1] is not None:
                    # estimate the current derivative using the finite difference method:
                    # equation (7) from the paper, along with the scaling factor N^i from equation (9).
                    # the N^i is applied implicitly because the division by elapsed_steps happens inside the loop
                    # https://github.com/Shenyi-Z/TaylorSeer/issues/11
                    dY_current[i] = (dY_current[i - 1] - dY_prev[i - 1]) / elapsed_steps
            return dY_current

        def _approximate_value(dY_current, elapsed_steps):
            """
            Approximate the current value of Y using our current estimate of the derivatives
            and the number of timesteps that have passed since the derivatives were computed.

            Args:
                dY_current: the values of the derivatives of Y
                elapsed_steps: number of steps since dY_current was computed
            """
            # taylor series approximation
            return sum(
                (1 / math.factorial(i)) * dY_current[i] * (elapsed_steps ** i)
                for i in range(ORDER)
            )
        @functools.wraps(cls.forward)
        def new_forward(self, *args, **kwargs):
            self.step_count += 1
            self.finite_difference_window += 1
            if _should_compute_full(self.step_count):
                # compute the actual forward pass
                Y = original_forward(self, *args, **kwargs)
                assert isinstance(Y, torch.Tensor)
                self.state['dY_prev'] = self.state['dY_current']
                # calculate and update the derivatives based on the present model output and the previous derivatives
                self.state['dY_current'] = _approximate_derivative(Y, self.state['dY_prev'], self.finite_difference_window)
                # reset the finite difference window
                self.finite_difference_window = 0
                print(f"step {self.step_count}, compute_full: {_should_compute_full(self.step_count)} state: {self.state}")
                return Y
            else:
                # approximate the value of the forward pass using the derivatives computed
                # 'finite_difference_window' steps ago
                assert all(i is not None for i in self.state['dY_current'])
                print(f"step {self.step_count}, compute_full: {_should_compute_full(self.step_count)} state: {self.state}")
                return _approximate_value(self.state['dY_current'], self.finite_difference_window)

        # replace the original methods
        cls.__init__ = new_init
        cls.forward = new_forward
        return cls

    return decorator
if __name__ == "__main__":

    # compute the actual forward pass for the first 3 steps, then approximate for the next 7 steps
    compute_step_map = [
        True,   # step 0
        True,   # step 1
        True,   # step 2
        False,  # step 3
        False,  # step 4
        False,  # step 5
        False,  # step 6
        False,  # step 7
        False,  # step 8
        False,  # step 9
    ]

    @taylor_seer_approximation(compute_step_map=compute_step_map, n_derivatives=2)
    class X_Squared(torch.nn.Module):
        def forward(self, x):
            return x ** 2

    x_squared = X_Squared()
    for i in range(10):
        X = torch.tensor(i, dtype=torch.float32)
        Y = x_squared(X)
        print(f"step {i}, X: {X}, Y: {Y}")
        print("-------------------------")
thanks for taking a look!! I am trying to reproduce results from the reference implementation, which does all of its bookkeeping using the diffusion step number, which goes from 0 to 28 in steps of 1. So I think this part is correct.
this is great! so I think the root issue here stems from how timesteps actually work in a diffusion model. when we generate an image with Flux, we run through some low integer number of steps (~28).

but, each timestep corresponds to some actual value of t in the underlying function that we're modeling with diffusion steps; t generally goes from 1 -> 0 as timesteps go from 0 -> 28 (or whatever). This is mildly confusing, but think of the model as attempting to reverse the process of gradually adding noise to an image over time, such that if the model is some f and the image is x, f(x, t=1) is noise and f(x, t=0) is the image.

to further add complexity, these timesteps don't map linearly to the diffusion steps; generally the earlier diffusion steps have a smaller delta t than the later steps. you can pull the actual numbers out of the scheduler you're using; on our default Flux implementation, `timesteps[14]` (i.e. the timestep halfway through the generation) is equal to 0.75.

all of this is to say, `self.finite_difference_window` should be calculated by storing something like `self.last_timestep_actually_run` whenever you call `original_forward`, and then `self.finite_difference_window = current_timestep - self.last_timestep_actually_run`.
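To make that suggestion concrete, here is a minimal, self-contained sketch (not the gist's code) of bookkeeping that keys the finite differences off the scheduler's continuous timestep values rather than the integer step count. The class name and the `timestep` argument are hypothetical; `timestep` is assumed to be the actual t value pulled from whatever scheduler is in use, not the loop index.

```python
import math


class TimestepAwareTaylorState:
    """Tracks Taylor-series derivatives keyed off continuous timestep values (hypothetical sketch)."""

    def __init__(self, n_derivatives=2):
        self.order = n_derivatives + 1
        self.dY = [None] * self.order
        self.last_timestep_actually_run = None  # t value at the last full forward pass

    def update(self, Y, timestep):
        """Call after an actual forward pass to refresh the finite-difference derivatives."""
        dY_new = [None] * self.order
        dY_new[0] = Y
        if self.last_timestep_actually_run is not None:
            # elapsed time in the model's t units (not the integer step count);
            # the sign convention just has to match the one used in approximate()
            dt = timestep - self.last_timestep_actually_run
            for i in range(1, self.order):
                if self.dY[i - 1] is not None:
                    dY_new[i] = (dY_new[i - 1] - self.dY[i - 1]) / dt
        self.dY = dY_new
        self.last_timestep_actually_run = timestep

    def approximate(self, timestep):
        """Taylor-extrapolate the module output at `timestep` from the last full pass."""
        dt = timestep - self.last_timestep_actually_run
        return sum(
            (1 / math.factorial(i)) * self.dY[i] * dt ** i
            for i in range(self.order)
            if self.dY[i] is not None
        )
```

Hooking this into the decorator above would amount to passing the current timestep into `new_forward` (e.g. as an extra keyword argument) and replacing the `self.finite_difference_window` counter with the `dt` bookkeeping shown here, so that derivative estimation and extrapolation both use the same elapsed-t units.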