Randomized pivoted Cholesky in PyTorch: low-memory version
# Author: Vlad Niculae <[email protected]>
# License: MIT

# Memory-efficient randomized pivoted Cholesky factorization.
# Algorithm 2.3 from Epperly et al., https://arxiv.org/abs/2410.03969
# Returns pivot indices ix and a lower-triangular factor L such that
# L @ L.T is the Cholesky factorization of K[ix][:, ix], while only ever
# evaluating small slices of the kernel matrix K.

import torch

torch.set_printoptions(precision=2, linewidth=120, sci_mode=False)
class LazyKernelMatrix:
    """RBF kernel matrix exp(-||x - y||^2 / (2 * length_scale^2)), evaluated lazily."""

    def __init__(self, X, length_scale, dtype=None):
        n, d = X.shape
        self.X = X
        self.length_scale = length_scale
        self.dtype = dtype if dtype is not None else X.dtype
        self.normsq = (X ** 2).sum(dim=-1).to(dtype=self.dtype)
        self.device = X.device
        self.shape = (n, n)

    def diag(self):
        # the RBF kernel has k(x, x) = 1 everywhere on the diagonal
        return torch.ones(self.shape[0], device=self.device, dtype=self.dtype)

    def __getitem__(self, *args):
        rowix, colix = args[0]
        Xrow = self.X[rowix].to(dtype=self.dtype)
        Xcol = self.X[colix].to(dtype=self.dtype)
        # ||x - y||^2 / 2 = ||x||^2 / 2 + ||y||^2 / 2 - <x, y>
        distsq = -Xrow @ Xcol.T
        distsq += self.normsq[rowix].unsqueeze(-1) / 2
        distsq += self.normsq[colix].unsqueeze(-2) / 2
        return torch.exp(-distsq / self.length_scale ** 2)
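
# Sanity-check sketch (not part of the original gist; `_check_lazy_kernel` is a
# hypothetical helper): the lazy kernel above should agree with the dense RBF
# kernel exp(-||x - y||^2 / (2 * length_scale^2)) computed via torch.cdist.
def _check_lazy_kernel(X, length_scale=3.0):
    K = LazyKernelMatrix(X, length_scale=length_scale)
    all_ix = torch.arange(X.shape[0])
    dense = torch.exp(-torch.cdist(X, X) ** 2 / (2 * length_scale ** 2))
    assert torch.allclose(K[all_ix, all_ix], dense, atol=1e-4)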
def rejection(H, max_accept=None):
    # Rejection sampling step: run a pivoted Cholesky pass over the residual
    # block H, accepting pivot j with probability H[j, j] / diag(H)[j], i.e.,
    # the variance left at j relative to its value before this pass.
    # Note: H is modified in place.
    m = H.shape[0]
    ix = []
    L = torch.zeros_like(H)
    if torch.trace(H) < 0:  # cannot proceed: return empty factor
        return L[ix][:, ix], ix
    u = torch.clip(torch.diag(H), min=0)
    unif = torch.rand_like(u) * u
    for j in range(m):
        if unif[j] < H[j, j]:  # accepted
            ix.append(j)
            L[j:, j] = H[j:, j] / torch.sqrt(H[j, j])
            H[j+1:, j+1:] -= torch.outer(L[j+1:, j], L[j+1:, j])
    if max_accept is not None:
        ix = ix[:max_accept]
    ix = torch.tensor(ix, dtype=torch.long, device=H.device)
    L = L[ix][:, ix]
    return L, ix
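
# Illustration sketch (an addition; `_check_rejection` is hypothetical): on an
# exact PSD matrix, `rejection` returns a lower-triangular factor L satisfying
# L @ L.T == H[ix][:, ix] for the accepted pivots ix, since accepted columns
# follow the usual Cholesky recurrence and rejected ones are simply skipped.
def _check_rejection(m=6, seed=0):
    torch.manual_seed(seed)
    A = torch.randn(m, m)
    H = A @ A.T + m * torch.eye(m)  # well-conditioned PSD matrix
    H0 = H.clone()                  # rejection modifies H in place
    L, ix = rejection(H)
    assert torch.allclose(L @ L.T, H0[ix][:, ix], atol=1e-4)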
def rpcholesky(K, block_size=100, n_rounds=10, n_atoms=None, chunk_size=None, tolerance=1e-5):
    n = K.shape[0]
    chunk_size = block_size if chunk_size is None else chunk_size
    n_atoms = min(n, block_size * n_rounds) if n_atoms is None else n_atoms
    n_chunks = (n + chunk_size - 1) // chunk_size  # ceil(n / chunk_size), so the tail chunk is covered
    diag = K.diag()
    initial_trace = diag.sum()
    L = torch.zeros(n_atoms, n_atoms, dtype=K.dtype, device=K.device)
    ix = torch.zeros(n_atoms, dtype=torch.long, device=K.device)

    # zeroth round: special case, since there are size-zero tensors.
    cands = torch.multinomial(diag, num_samples=block_size, replacement=True)
    H = K[cands, cands]
    L_new, accepted = rejection(H, max_accept=n_atoms)
    cands = cands[accepted]
    k = cands.shape[0]
    L[:k, :k] = L_new
    ix[:k] = cands

    # update the residual diagonal chunk by chunk, never forming the full kernel
    for j in range(n_chunks):
        start = chunk_size * j
        cease = min(chunk_size * (j + 1), n)
        K_new = K[start:cease, cands]
        Gt = torch.linalg.solve_triangular(L_new, K_new.T, upper=False)
        diag[start:cease] -= (Gt ** 2).sum(dim=0)
        diag[start:cease] = torch.clip(diag[start:cease], min=0)
for it in range(1, n_rounds): | |
print(diag.sum().item(), tolerance * initial_trace.item()) | |
if diag.sum() <= tolerance * initial_trace: | |
print("Variance accounted for up to numerical precision, stopping.") | |
break | |
# sample candidates | |
cands = torch.multinomial(diag, num_samples=block_size, replacement=True) | |
K_old_new = K[ix[:k], cands] | |
B = torch.linalg.solve_triangular(L[:k, :k], K_old_new, upper=False).T | |
H = K[cands, cands] - B @ B.T | |
L_new, accepted = rejection(H, max_accept=n_atoms-k) | |
# assert torch.isfinite(L_new).all() | |
if len(accepted) == 0: | |
print("No more candidates could be added") | |
break | |
cands = cands[accepted] | |
B = B[accepted] | |
m = cands.shape[0] | |
# update cholesky factor and diagonal | |
K_old_new = torch.linalg.solve_triangular(L[:k, :k], B, upper=False, left=False).T | |
L[k:k+m, :k] = B | |
L[k:k+m, k:k+m] = L_new | |
# print(torch.linalg.norm(L[:k+m, :k+m] - torch.linalg.cholesky(K[ix[:k+m], ix[:k+m]]))) | |
for j in range(n // chunk_size): | |
start = chunk_size * j | |
cease = min(chunk_size * (j+1), n) | |
K_old = K[start:cease, ix[:k]] | |
K_new = K[start:cease, cands] - K_old @ K_old_new | |
Gt = torch.linalg.solve_triangular(L_new, K_new.T, upper=False) | |
diag[start:cease] -= (Gt**2).sum(dim=0) | |
diag[start:cease] = torch.clip(diag[start:cease], min=0) | |
ix[k:k+m] = cands | |
k += m | |
return ix[:k], L[:k, :k] | |
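
# Usage sketch (an addition; `nystrom_error` is a hypothetical helper): the
# returned (ix, L) define the Nyström approximation K_hat = F @ F.T with
# F = K[:, ix] @ inv(L).T. This materializes an n x n residual, so it is only
# meant as a diagnostic for small n.
def nystrom_error(K, ix, L):
    n = K.shape[0]
    rows = torch.arange(n, device=L.device)
    # F = K[:, ix] @ inv(L).T, via a triangular solve against K[ix, :]
    F = torch.linalg.solve_triangular(L, K[rows, ix].T, upper=False).T
    resid = K[rows, rows] - F @ F.T
    return (torch.trace(resid) / torch.trace(K[rows, rows])).item()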
def main():
    X = torch.randn(100, 8)
    K = LazyKernelMatrix(X, length_scale=3)
    ix, L = rpcholesky(K, block_size=10, n_rounds=15, n_atoms=89)
    print(len(ix))
    assert len(torch.unique(ix)) == len(ix)


if __name__ == '__main__':
    main()
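
# Scaling note (an assumption based on reading the code, not a benchmark):
# the kernel is only ever evaluated in (chunk_size x block_size) and
# (chunk_size x k) slices, so peak extra memory per pass is O(chunk_size * k)
# on top of the (n_atoms x n_atoms) factor; e.g. something like
#     X = torch.randn(50_000, 32)
#     K = LazyKernelMatrix(X, length_scale=3)
#     ix, L = rpcholesky(K, block_size=200, n_rounds=20)
# should run without ever forming the 50_000 x 50_000 kernel.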