Randomized pivoted Cholesky in PyTorch: low-memory version
# Author: Vlad Niculae <[email protected]>
# License: MIT
# Memory-efficient randomized pivoted Cholesky factorization
# Algorithm 2.3 from Epperly et al., https://arxiv.org/abs/2410.03969
import torch
torch.set_printoptions(precision=2, linewidth=120, sci_mode=False)


class LazyKernelMatrix:
    """RBF kernel matrix evaluated lazily: entries are computed on demand,
    so the full n-by-n matrix is never materialized."""

    def __init__(self, X, length_scale, dtype=None):
        n, d = X.shape
        self.X = X
        self.length_scale = length_scale
        self.normsq = (X ** 2).sum(dim=-1)
        self.dtype = dtype if dtype is not None else X.dtype
        self.device = X.device
        self.shape = (n, n)

    def diag(self):
        # for the RBF kernel, k(x, x) = 1 for every point
        return torch.ones(self.shape[0], device=self.device, dtype=self.dtype)

    def __getitem__(self, *args):
        # K[rows, cols] returns the full |rows| x |cols| block (outer
        # indexing), unlike torch fancy indexing, which would pair indices up.
        rowix, colix = args[0]
        Xrow = self.X[rowix].to(dtype=self.dtype)
        Xcol = self.X[colix].to(dtype=self.dtype)
        # accumulate ||x - y||^2 / 2 via ||x||^2 - 2 x.y + ||y||^2
        distsq = -Xrow @ Xcol.T
        distsq += self.normsq[rowix].unsqueeze(-1) / 2
        distsq += self.normsq[colix].unsqueeze(-2) / 2
        return torch.exp(-distsq / self.length_scale ** 2)
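

# A quick sanity check one might run (hypothetical helper of my own, never
# called below): the lazy entries should match a dense RBF kernel computed
# with torch.cdist.
def _check_lazy_kernel():
    X = torch.randn(16, 4)
    K = LazyKernelMatrix(X, length_scale=3.0)
    rows = torch.arange(16)
    dense = torch.exp(-torch.cdist(X, X) ** 2 / (2 * 3.0 ** 2))
    assert torch.allclose(K[rows, rows], dense, atol=1e-5)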


def rejection(H, max_accept=None):
    # Rejection step of block rpcholesky: accept pivots of H with probability
    # (residual diagonal) / (original diagonal), computing the Cholesky factor
    # of the accepted rows/columns along the way. H is modified in place.
    m = H.shape[0]
    ix = []
    L = torch.zeros_like(H)
    if torch.trace(H) < 0:  # cannot proceed: return an empty factor
        ix = torch.tensor(ix, dtype=torch.long, device=H.device)
        return L[ix][:, ix], ix
    u = torch.clip(torch.diag(H), min=0)
    unif = torch.rand_like(u) * u
    for j in range(m):
        if unif[j] < H[j, j]:  # accepted
            ix.append(j)
            L[j:, j] = H[j:, j] / torch.sqrt(H[j, j])
            H[j+1:, j+1:] -= torch.outer(L[j+1:, j], L[j+1:, j])
    if max_accept is not None:
        ix = ix[:max_accept]
    ix = torch.tensor(ix, dtype=torch.long, device=H.device)
    L = L[ix][:, ix]
    return L, ix
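

# A small check of the invariant `rejection` maintains (hypothetical helper,
# not called anywhere): restricted to the accepted pivots, L is an exact
# Cholesky factor of the original H.
def _check_rejection():
    A = torch.randn(8, 8)
    H = A @ A.T  # random PSD matrix
    L, ix = rejection(H.clone())  # clone: rejection modifies H in place
    if len(ix) > 0:
        assert torch.allclose(L @ L.T, H[ix][:, ix], atol=1e-4)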


def rpcholesky(K, block_size=100, n_rounds=10, n_atoms=None, chunk_size=None, tolerance=1e-5):
    n = K.shape[0]
    chunk_size = block_size if chunk_size is None else chunk_size
    n_atoms = min(n, block_size * n_rounds) if n_atoms is None else n_atoms
    diag = K.diag()  # residual diagonal, used as the sampling distribution
    initial_trace = diag.sum()
    L = torch.zeros(n_atoms, n_atoms, dtype=K.dtype, device=K.device)
    ix = torch.zeros(n_atoms, dtype=torch.long, device=K.device)

    # zeroth round: special case, since there are size-zero tensors.
    cands = torch.multinomial(diag, num_samples=block_size, replacement=True)
    H = K[cands, cands]
    L_new, accepted = rejection(H, max_accept=n_atoms)
    cands = cands[accepted]
    k = cands.shape[0]
    L[:k, :k] = L_new
    ix[:k] = cands
    # downdate the residual diagonal in chunks, never forming K in full
    for start in range(0, n, chunk_size):
        cease = min(start + chunk_size, n)
        K_new = K[start:cease, cands]
        Gt = torch.linalg.solve_triangular(L_new, K_new.T, upper=False)
        diag[start:cease] -= (Gt ** 2).sum(dim=0)
        diag[start:cease] = torch.clip(diag[start:cease], min=0)
    # rounds 1 up to n_rounds
    for it in range(1, n_rounds):
        print(f"residual trace: {diag.sum().item():.4g}, target: {tolerance * initial_trace.item():.4g}")
        if diag.sum() <= tolerance * initial_trace:
            print("Variance accounted for up to numerical precision, stopping.")
            break
        # sample candidates proportionally to the residual diagonal
        cands = torch.multinomial(diag, num_samples=block_size, replacement=True)
        K_old_new = K[ix[:k], cands]
        B = torch.linalg.solve_triangular(L[:k, :k], K_old_new, upper=False).T
        # residual kernel on the candidates, given the pivots chosen so far
        H = K[cands, cands] - B @ B.T
        L_new, accepted = rejection(H, max_accept=n_atoms - k)
        # assert torch.isfinite(L_new).all()
        if len(accepted) == 0:
            print("No more candidates could be added")
            break
        cands = cands[accepted]
        B = B[accepted]
        m = cands.shape[0]
        # update cholesky factor and diagonal;
        # K_old_new now holds K[ix, ix]^{-1} K[ix, cands]
        K_old_new = torch.linalg.solve_triangular(L[:k, :k], B, upper=False, left=False).T
        L[k:k+m, :k] = B
        L[k:k+m, k:k+m] = L_new
        # print(torch.linalg.norm(L[:k+m, :k+m] - torch.linalg.cholesky(K[ix[:k+m], ix[:k+m]])))
        for start in range(0, n, chunk_size):
            cease = min(start + chunk_size, n)
            K_old = K[start:cease, ix[:k]]
            K_new = K[start:cease, cands] - K_old @ K_old_new
            Gt = torch.linalg.solve_triangular(L_new, K_new.T, upper=False)
            diag[start:cease] -= (Gt ** 2).sum(dim=0)
            diag[start:cease] = torch.clip(diag[start:cease], min=0)
        ix[k:k+m] = cands
        k += m
    return ix[:k], L[:k, :k]


def main():
    X = torch.randn(100, 8)
    K = LazyKernelMatrix(X, length_scale=3)
    ix, L = rpcholesky(K, block_size=10, n_rounds=15, n_atoms=89)
    print(len(ix))
    # all selected pivots must be distinct
    assert len(torch.unique(ix)) == len(ix)
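

# Usage sketch for the returned factorization (hypothetical helper, names of
# my own choosing): the pivots ix and factor L define the Nystrom
# approximation K_hat = F @ F.T with F = K[:, ix] @ L^{-T}; its residual
# trace should agree with the diagnostic printed during the rounds.
def _check_approximation():
    X = torch.randn(100, 8)
    K = LazyKernelMatrix(X, length_scale=3)
    ix, L = rpcholesky(K, block_size=10, n_rounds=15)
    rows = torch.arange(K.shape[0])
    F = torch.linalg.solve_triangular(L, K[rows, ix].T, upper=False).T  # = K[:, ix] L^{-T}
    resid = K[rows, rows] - F @ F.T  # dense matrix, only for this small check
    print("residual trace:", torch.trace(resid).item())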


if __name__ == '__main__':
    main()