Skip to content

Instantly share code, notes, and snippets.

@alexarmbr
Created April 23, 2025 17:21
Show Gist options
  • Save alexarmbr/decc803be417b2639889933aa9c17ce0 to your computer and use it in GitHub Desktop.
Save alexarmbr/decc803be417b2639889933aa9c17ce0 to your computer and use it in GitHub Desktop.
A decorator that implements the algorithm explained in "From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers"
import functools
import torch
import math
def taylor_seer_approximation(WARMUP_STEPS=1, SKIP_INTERVAL_STEPS=1, compute_step_map=None, n_derivatives=2):
    """
    Build a class decorator that replaces some forward passes of an nn.Module
    with a Taylor-series forecast of the module's output (TaylorSeer).

    The decorated class keeps a per-instance step counter. On "full" steps the
    real ``forward`` runs and finite-difference estimates of the output's
    derivatives (with respect to the step index) are refreshed. On all other
    steps the output is extrapolated from the most recent estimates and the
    real ``forward`` is not called.

    Args:
        WARMUP_STEPS: Steps ``0..WARMUP_STEPS`` (inclusive, 0-indexed) always run
            the real forward pass. Clamped to be at least 1.
        SKIP_INTERVAL_STEPS: After warmup, the real forward pass runs on every
            step where ``(step - WARMUP_STEPS) % SKIP_INTERVAL_STEPS == 0``.
            Clamped to be at least 1.
        compute_step_map: Optional non-empty sequence of booleans; entry
            ``step % len(compute_step_map)`` says whether that step runs the
            real forward pass. When given, WARMUP_STEPS and SKIP_INTERVAL_STEPS
            are ignored.
        n_derivatives: How many derivatives to estimate (must be >= 0). A
            derivative of order k only becomes available after k + 1 full
            passes; until then the forecast uses the orders estimated so far.

    Returns:
        A decorator that patches ``__init__`` and ``forward`` of the class it is
        applied to (in place) and returns that same class.

    Raises:
        ValueError: If ``n_derivatives`` is negative or ``compute_step_map`` is
            empty.
    """
    # make sure the warmup and skip interval are at least 1
    WARMUP_STEPS = max(int(WARMUP_STEPS), 1)
    SKIP_INTERVAL_STEPS = max(int(SKIP_INTERVAL_STEPS), 1)

    # Fail at decoration time instead of with an IndexError / ZeroDivisionError
    # in the middle of the first forward pass.
    if n_derivatives < 0:
        raise ValueError(f"n_derivatives must be >= 0, got {n_derivatives}")
    if compute_step_map is not None and len(compute_step_map) == 0:
        raise ValueError("compute_step_map must not be empty")

    # 'order' of the taylor approximation is the value itself, plus however many
    # derivatives we are approximating
    ORDER = n_derivatives + 1

    def decorator(cls):
        original_init = cls.__init__
        original_forward = cls.forward

        @functools.wraps(cls.__init__)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # Slot i of each list holds the i-th derivative estimate (slot 0 is
            # the output itself); None means "not estimated yet".
            self.state = {
                'dY_prev': [None] * ORDER,
                'dY_current': [None] * ORDER,
            }
            # Starts at -1 so that the first call to forward is step 0.
            self.step_count = -1
            # Number of steps since the last full forward pass.
            self.finite_difference_window = 0

        def _should_compute_full(step):
            """Return a truthy value if `step` must run the real forward pass."""
            if compute_step_map is not None:
                return compute_step_map[step % len(compute_step_map)]
            # we compute the actual forward pass for the first warmup steps
            # and then we compute the forward pass every skip_interval after that
            return (step <= WARMUP_STEPS
                    or (step - WARMUP_STEPS) % SKIP_INTERVAL_STEPS == 0)

        def _approximate_derivative(Y, dY_prev, elapsed_steps):
            """
            Approximate the derivatives of Y using the previous derivatives

            Args:
                Y: current value of the feature, i.e. Y=f(X) where f could be a transformer or linear layer
                dY_prev: derivative estimates from the previous full forward pass
                elapsed_steps: number of steps between Y and dY_prev

            Returns:
                A list of length ORDER; entry i is the i-th derivative estimate,
                or None when dY_prev does not yet hold the (i-1)-th estimate.
            """
            dY_current = [None] * ORDER
            dY_current[0] = Y
            for i in range(1, ORDER):
                if dY_prev[i-1] is not None:
                    # estimate current derivative using finite difference method
                    # equation (7) from the paper, along with scaling factor N^i from equation (9)
                    # the N^i is done implicitly by elapsed_steps being inside the loop
                    # https://github.com/Shenyi-Z/TaylorSeer/issues/11
                    dY_current[i] = (dY_current[i-1] - dY_prev[i-1]) / elapsed_steps
            return dY_current

        def _approximate_value(dY_current, elapsed_steps):
            """
            Approximate the current value of Y using our current estimate of the derivatives
            and the # of timesteps that have passed since the derivatives were computed

            Args:
                dY_current: the derivative estimates of Y (entry 0 is Y itself)
                elapsed_steps: number of steps since dY_current was computed

            Returns:
                The Taylor-series forecast built from every non-None entry of
                dY_current.
            """
            # taylor series approximation; orders that have not been estimated
            # yet (None) are left out, which lowers the order of the forecast
            # instead of failing.
            return sum(
                (1 / math.factorial(i)) * dY * (elapsed_steps**i)
                for i, dY in enumerate(dY_current)
                if dY is not None
            )

        @functools.wraps(cls.forward)
        def new_forward(self, *args, **kwargs):
            self.step_count += 1
            self.finite_difference_window += 1
            compute_full = _should_compute_full(self.step_count)
            if compute_full:
                # compute actual forward pass
                Y = original_forward(self, *args, **kwargs)
                if not isinstance(Y, torch.Tensor):
                    raise TypeError(
                        "taylor_seer_approximation requires forward() to return a "
                        f"torch.Tensor, got {type(Y).__name__}"
                    )
                self.state['dY_prev'] = self.state['dY_current']
                # calculate and update derivative based on present model output and previous derivatives
                self.state['dY_current'] = _approximate_derivative(Y, self.state['dY_prev'], self.finite_difference_window)
                # reset the finite difference window
                self.finite_difference_window = 0
                print(f"step {self.step_count}, compute_full: {compute_full} state: {self.state}")
                return Y

            # approximate the value of the forward pass using the derivative computed 'finite_difference_window' steps ago
            if self.state['dY_current'][0] is None:
                # Nothing to extrapolate from: no full forward pass has run yet.
                raise RuntimeError(
                    f"step {self.step_count} is scheduled to be approximated, but no "
                    "full forward pass has been computed yet"
                )
            print(f"step {self.step_count}, compute_full: {compute_full} state: {self.state}")
            return _approximate_value(self.state['dY_current'], self.finite_difference_window)

        # Replace methods
        cls.__init__ = new_init
        cls.forward = new_forward
        return cls

    return decorator
if __name__ == "__main__":
    # Demo schedule: run the real forward pass on the first 3 calls, then let
    # the Taylor forecast stand in for the remaining 7 calls.
    compute_step_map = [True] * 3 + [False] * 7

    # Toy module f(x) = x^2, so forecasts can be compared against known values.
    @taylor_seer_approximation(compute_step_map=compute_step_map, n_derivatives=2)
    class X_Squared(torch.nn.Module):
        def forward(self, x):
            return x**2

    x_squared = X_Squared()
    for i in range(len(compute_step_map)):
        X = torch.tensor(i, dtype=torch.float32)
        Y = x_squared(X)
        print(f"step {i}, X: {X}, Y: {Y}")
        print("-------------------------")
@alexarmbr
Copy link
Author

thanks for taking a look!! I am trying to reproduce results from the reference implementation which does all of the bookkeeping using the # of the diffusion step, which goes from 0->28 in steps of 1. So I think this part is correct.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment