crowsonkb / syre.py
Last active July 18, 2025 18:18
Implementation of the syre weight decay algorithm from "Remove Symmetries to Control Model Expressivity and Improve Optimization" (https://arxiv.org/abs/2408.15495).
"""Implementation of the syre weight decay algorithm from "Remove Symmetries to Control Model
Expressivity and Improve Optimization" (https://arxiv.org/abs/2408.15495)."""
import math
import torch
from torch import optim
import triton
import triton.language as tl
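The gist itself implements this with Triton kernels and an optimizer wrapper. As a rough, hedged sketch of the idea the title points at (plain weight decay pulls every parameter toward 0, a point fixed by the network's weight-space symmetries, so decaying toward a small fixed random reference instead removes that invariance), assuming nothing about the paper's exact formulation; the class name and ref_scale parameter below are hypothetical:

import torch

class RandomReferenceDecay:  # hypothetical name, not the gist's API
    """Decoupled weight decay toward a small, fixed random reference point."""

    def __init__(self, params, weight_decay=1e-2, ref_scale=1e-2):
        self.params = list(params)
        self.weight_decay = weight_decay
        # The reference is drawn once and never updated.
        self.refs = [torch.randn_like(p) * ref_scale for p in self.params]

    @torch.no_grad()
    def step(self, lr):
        # Apply p <- p - lr * weight_decay * (p - ref) after the optimizer's own step.
        for p, ref in zip(self.params, self.refs):
            p.sub_(p - ref, alpha=lr * self.weight_decay)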
crowsonkb / energy_matching.py
Last active July 8, 2025 10:28
The energy matching loss
"""The energy matching loss.
Energy matching regresses an energy function to match a target energy function at the points in the
dataset. ("Energy" refers to an unnormalized negative log probability: for a sequence model it is
the sum of the cross-entropy losses of a sequence's completion tokens plus some constant. Two energy
functions are considered "the same" by the energy matching loss if they differ by an arbitrary
constant.) This is useful for on-policy reinforcement learning or off-policy preference tuning of
sequence models, where the target energies are:
[the sequences' energies under the reference model] - [the sequences' rewards] / beta,
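Going only by the description above, one plausible form of such a loss centers both sides within the batch so the arbitrary constant drops out; the helper names completion_energy and energy_matching_loss are hypothetical, not the gist's:

import torch
from torch.nn import functional as F

def completion_energy(logits, tokens, mask):
    # Energy = sum of cross-entropy losses over the completion tokens, i.e. an
    # unnormalized negative log probability (defined only up to a constant).
    logp = torch.log_softmax(logits, dim=-1)
    token_logp = logp.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    return -(token_logp * mask).sum(dim=-1)

def energy_matching_loss(energy, ref_energy, reward, beta):
    # Target energy per the description: reference energy minus reward / beta.
    target = ref_energy - reward / beta
    # Energies that differ by a constant are "the same", so center both within
    # the batch before regressing one onto the other.
    return F.mse_loss(energy - energy.mean(), target - target.mean())

Here energy would come from the model being tuned and ref_energy from the frozen reference model, matching the on-policy and off-policy use cases the description mentions.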
crowsonkb / lfq.py
Last active June 18, 2025 01:51
Lookup Free Quantization (LFQ) for PyTorch.
"""Lookup Free Quantization (LFQ) for PyTorch."""
from dataclasses import dataclass
from itertools import product
import math
from typing import Optional
import torch
from torch import distributed as dist, nn
from torch.distributed import nn as dnn
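The preview stops at the imports. As a hedged sketch of the core of lookup-free quantization: each latent dimension is quantized independently to a sign, so the codebook of all sign patterns is implicit and needs no lookup table; the gist presumably adds entropy terms and the distributed pieces its imports suggest:

import torch

def lfq_quantize(z):
    # Quantize each latent dimension independently to {-1, +1}; the implicit
    # codebook is the set of all sign patterns, so no lookup table is needed.
    q = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z))
    # Straight-through estimator: use q in the forward pass, send gradients to z.
    return z + (q - z).detach()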
crowsonkb / plackett_luce.py
Last active November 5, 2024 21:33
Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties.
"""Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""
from itertools import chain
from typing import List, Optional, Tuple
import torch
def plackett_luce_loss(
scores: torch.Tensor,
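The preview cuts off inside the signature. For reference, a hedged sketch of the plain Plackett-Luce negative log-likelihood for one full ranking without ties (the gist generalizes this to partial rankings and ties; the function name and argument layout here are illustrative):

import torch

def plackett_luce_nll(scores, ranking):
    # scores: (n_items,) unnormalized item scores; ranking: item indices, best first.
    # log-likelihood = sum over positions of score[chosen] - logsumexp(remaining scores).
    s = scores[ranking]
    # Reversed cumulative logsumexp gives logsumexp over the not-yet-placed items.
    denom = torch.logcumsumexp(s.flip(0), dim=0).flip(0)
    return -(s - denom).sum()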
crowsonkb / kld_noise_generator.py
Created October 25, 2024 17:15
Generates a smoothly varying standard normal time series.
"""Generates a smoothly varying standard normal time series."""
import numpy as np
import scipy.linalg
import torch
class KLDNoiseGenerator(torch.nn.Module):
"""Generates a smoothly varying standard normal time series.
crowsonkb / ring_attn.py
Created October 10, 2024 16:19
Ring attention for PyTorch.
"""Ring attention for PyTorch.
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py.
"""
import flash_attn.flash_attn_interface as fai
import torch
import torch.distributed as dist
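In ring attention each rank runs flash attention against its local key/value block, the blocks rotate around the ring, and the per-block results are merged using each block's log-sum-exp. A hedged, single-process sketch of just that merge step, assuming out tensors of shape (batch, heads, seq, dim) and lse tensors of shape (batch, heads, seq):

import torch

def merge_attn_blocks(outs, lses):
    # Each (out_i, lse_i) pair is softmax attention computed against block i only;
    # reweighting by exp(lse_i - lse_total) recovers attention over all blocks.
    lse = torch.logsumexp(torch.stack(lses), dim=0)
    out = sum(o * torch.exp(l - lse).unsqueeze(-1) for o, l in zip(outs, lses))
    return out, lse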
crowsonkb / mos.py
Last active April 11, 2024 21:23
Mixture of Softmaxes
"""Mixture of Softmaxes"""
import torch
from torch.nn import functional as F
class MixtureOfSoftmaxes(torch.autograd.Function):
@staticmethod
def forward(ctx, x, p):
with torch.cuda.amp.autocast(enabled=False):
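The preview shows a custom autograd.Function that disables autocast, presumably so the mixture is computed in float32 with a hand-written backward. Functionally, a mixture of softmaxes averages several softmax distributions under mixture weights; a hedged sketch of that forward computation, with argument shapes that are assumptions rather than the gist's contract:

import torch

def mixture_of_softmaxes(x, p):
    # x: (..., k, vocab) per-component logits; p: (..., k) mixture weights summing to 1.
    # Returns log-probabilities of the mixture: log(sum_k p_k * softmax(x_k)).
    log_probs = torch.log_softmax(x.float(), dim=-1)
    return torch.logsumexp(log_probs + torch.log(p.float()).unsqueeze(-1), dim=-2)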
"""Grouped linear layer using https://github.com/tgale96/grouped_gemm."""
from dataclasses import dataclass
import warnings
import torch
from torch import nn
try:
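This preview lost its header during extraction; the docstring says it wraps https://github.com/tgale96/grouped_gemm. As a hedged reference for what a grouped linear layer computes: each contiguous group of rows gets its own weight matrix, which the library fuses into a single grouped GEMM. The function below is illustrative plain PyTorch, not the gist's or the library's API:

import torch

def grouped_linear_reference(x, weights, group_sizes):
    # x: (n_tokens, d_in); weights: list of (d_in, d_out); group_sizes: rows per group.
    # Each contiguous group of rows is multiplied by its own weight matrix; the
    # grouped_gemm library fuses these per-group matmuls into one kernel.
    outs, start = [], 0
    for w, n in zip(weights, group_sizes):
        outs.append(x[start:start + n] @ w)
        start += n
    return torch.cat(outs)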
crowsonkb / spo_loss.py
Last active June 10, 2024 15:38
Scalar Preference Optimization
"""Scalar Preference Optimization."""
import torch
from torch.nn import functional as F
def logp_completion(logits, tokens, mask):
"""Compute the log probabilities of completions given their prompts.
Args:
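The preview ends inside the docstring of logp_completion. Going only by that signature and docstring, a hedged sketch of what such a helper typically computes (not the gist's actual body), assuming logits are already aligned with tokens:

import torch

def logp_completion_sketch(logits, tokens, mask):
    # logits: (batch, seq, vocab); tokens, mask: (batch, seq). Sums the log
    # probabilities of the completion tokens (mask == 1) for each sequence.
    logp = torch.log_softmax(logits, dim=-1)
    token_logp = logp.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    return (token_logp * mask).sum(dim=-1)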
crowsonkb / reinforce.py
Last active June 30, 2023 19:12
REINFORCE with exponential moving average baseline
"""REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely
Differentiable Monte Carlo Estimator (https://arxiv.org/abs/1802.05098)."""
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional, Union
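Per the docstring, this combines the DiCE estimator with an exponential moving average baseline. A hedged sketch of the two pieces, with illustrative names: the MagicBox operator evaluates to 1 in the forward pass but reproduces score-function gradients when differentiated, and the baseline would be an EMA of observed rewards, e.g. baseline = decay * baseline + (1 - decay) * reward.mean() after each batch:

import torch

def magic_box(logp):
    # DiCE MagicBox: evaluates to 1 in the forward pass, but differentiating it
    # yields magic_box(logp) * grad(logp), giving correct score-function
    # gradients to any order.
    return torch.exp(logp - logp.detach())

def reinforce_dice_loss(logp, reward, baseline):
    # Surrogate loss whose gradient is -(reward - baseline) * grad(logp).
    return -(magic_box(logp) * (reward - baseline)).mean()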