NMF Optimization as Dual Simplex formulation
""" | |
(Currently incorrect) attempt to implement NMF-Sinkhorn optimization outlined in: | |
Non-negative matrix factorization and deconvolution as dual simplex problem | |
Denis Kleverov, Ekaterina Aladyeva, Alexey Serdyukov, Maxim N. Artyomov | |
https://www.biorxiv.org/content/10.1101/2024.04.09.588652v1.full | |
R implementation by authors: https://github.com/artyomovlab/dualsimplex | |
""" | |
import numpy as np
from scipy import linalg
from scipy.optimize import nnls
from typing import Tuple, Optional, NamedTuple
class OptimizationState(NamedTuple):
    """Tracks optimization state including matrices and errors."""
    X: np.ndarray        # shape (K×N)
    Omega: np.ndarray    # shape (M×K)
    D: np.ndarray        # shape (K×K), diagonal
    S: np.ndarray        # shape (M×K)
    R: np.ndarray        # shape (N×K)
    Sigma: np.ndarray    # shape (K×K)
    deconv_error: float = float("inf")
class SinkhornNMFOptimizer:
    """
    NMF optimizer in Sinkhorn-transformed space.

    - Focus on correctness of gradient shapes, hinge penalty, and D-updates.
    - Pins:
        * The first row of Omega => row vector of 1/sqrt(M).
        * The first column of X  => column vector of 1/sqrt(N).
    """
    def __init__(
        self,
        n_components: int = 3,
        max_iter: int = 1000,
        tol: float = 1e-4,
        lambda_param: float = 1.0,
        beta_param: float = 1.0,
        mu: float = 0.001,
        nu: float = 0.001,
        sinkhorn_max_iter: int = 20,
        sinkhorn_start_check: int = 5,
        sinkhorn_check_every: int = 3,
        sinkhorn_tol: float = 1.490116e-8,
    ):
        self.K = n_components
        self.max_iter = max_iter
        self.tol = tol
        self.lambda_param = lambda_param
        self.beta_param = beta_param
        self.mu = mu
        self.nu = nu
        self.sinkhorn_max_iter = sinkhorn_max_iter
        self.sinkhorn_start_check = sinkhorn_start_check
        self.sinkhorn_check_every = sinkhorn_check_every
        self.sinkhorn_tol = sinkhorn_tol
        self.S = None  # will store shape (M×K)
        self.R = None  # will store shape (N×K)
    # ---------------------------------------------------------------
    # 1) Sinkhorn normalization
    # ---------------------------------------------------------------
    def efficient_sinkhorn(self, V: np.ndarray) -> np.ndarray:
        """
        Performs approximate row/column normalization (Sinkhorn) so that
        rows and columns sum to ~1.
        """
        V_curr = V.astype(float, copy=True)  # float copy so in-place division is safe
        M, N = V_curr.shape
        target_col_sums = np.ones(N)
        for i in range(self.sinkhorn_max_iter):
            # Row normalization
            row_sums = np.sum(V_curr, axis=1, keepdims=True)
            V_curr /= np.maximum(row_sums, 1e-12)
            # Column normalization
            col_sums = np.sum(V_curr, axis=0, keepdims=True)
            V_curr /= np.maximum(col_sums, 1e-12)
            # Check whether the column sums have converged
            if (
                i + 1 >= self.sinkhorn_start_check
                and (i + 1 - self.sinkhorn_start_check) % self.sinkhorn_check_every == 0
            ):
                if np.allclose(col_sums.flatten(), target_col_sums, atol=self.sinkhorn_tol):
                    break
        return V_curr
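
    # (Added sketch, not part of the original gist.) A minimal diagnostic for the
    # Sinkhorn output: the final column step drives the column sums to ~1, while
    # for M != N the row sums can only be approximately 1, so this reports both
    # deviations rather than asserting either is zero.
    def sinkhorn_residuals(self, V_ss: np.ndarray) -> Tuple[float, float]:
        """Max absolute deviation of (row sums, column sums) of V_ss from 1."""
        row_dev = float(np.max(np.abs(V_ss.sum(axis=1) - 1.0)))
        col_dev = float(np.max(np.abs(V_ss.sum(axis=0) - 1.0)))
        return row_dev, col_dev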
    # ---------------------------------------------------------------
    # 2) SVD with enforced top singular vectors
    # ---------------------------------------------------------------
    def compute_svd(self, V_ss: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute truncated SVD => V_ss ~ U Sigma Vh, then define:
            S     = U[:, :K]       (M×K)
            Sigma = diag(s[:K])    (K×K)
            R     = Vh[:K, :].T    (N×K)
        Then enforce:
            S[:, 0] = 1/sqrt(M)    # first left singular vector
            R[:, 0] = 1/sqrt(N)    # first right singular vector
        """
        M, N = V_ss.shape
        U, s, Vh = linalg.svd(V_ss, full_matrices=False)
        S = U[:, :self.K]            # (M×K)
        Sigma = np.diag(s[:self.K])  # (K×K)
        Rfull = Vh[:self.K, :]       # (K×N)
        R = Rfull.T                  # (N×K)
        # Force the first left singular vector to be 1/sqrt(M)
        S[:, 0] = 1.0 / np.sqrt(M)
        # Force the first right singular vector to be 1/sqrt(N)
        R[:, 0] = 1.0 / np.sqrt(N)
        return S, Sigma, R
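
    # (Added note.) Rationale for the pinning above, as I read the paper: after
    # Sinkhorn scaling the row/column sums of V_ss are approximately constant, so
    # the leading left/right singular vectors are close to constant vectors.
    # Overwriting S[:, 0] and R[:, 0] with 1/sqrt(M) and 1/sqrt(N) fixes that
    # direction exactly, which is what allows the first row of Omega and the
    # first column of X to be pinned to the same constants during optimization.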
    # ---------------------------------------------------------------
    # 3) Hinge derivatives
    # ---------------------------------------------------------------
    def hinge_deriv_for_X(self, XR: np.ndarray, R: np.ndarray) -> np.ndarray:
        """
        Hinge penalty: for each negative entry in (X @ R), add the derivative w.r.t. X.
        - X is (K×N), R is (N×K), so (X @ R) is (K×K).
        - If (X @ R)[l, j] < 0 => grad_X[l, :] += -R[:, j].
        """
        K, K2 = XR.shape  # XR is (K×K)
        N = R.shape[0]
        grad = np.zeros((K, N), dtype=XR.dtype)
        for l in range(K):
            for j in range(K2):
                if XR[l, j] < 0:
                    grad[l, :] += -R[:, j]
        return grad

    def hinge_deriv_for_Omega(self, STOmega: np.ndarray, S: np.ndarray) -> np.ndarray:
        """
        Hinge penalty: for each negative entry in (S^T @ Omega), add the derivative w.r.t. Omega.
        - S is (M×K), so S^T is (K×M). Omega is (M×K).
        - (S^T @ Omega) => shape (K×K).
        - If (S^T @ Omega)[k1, k2] < 0 => grad_Omega[:, k2] += -S[:, k1].
        """
        K, K2 = STOmega.shape  # (K×K)
        M = S.shape[0]
        grad = np.zeros((M, K2), dtype=STOmega.dtype)
        for k1 in range(K):
            for k2 in range(K2):
                if STOmega[k1, k2] < 0:
                    grad[:, k2] += -S[:, k1]
        return grad
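
    # (Added sketch, not part of the original gist.) The two loops above can be
    # vectorized: with a (K×K) indicator of negative entries, the accumulated
    # sums of -R[:, j] / -S[:, k1] become single matrix products. Kept separate
    # so the loop versions above remain the reference implementation.
    def hinge_deriv_for_X_vectorized(self, XR: np.ndarray, R: np.ndarray) -> np.ndarray:
        neg_mask = (XR < 0).astype(R.dtype)  # (K×K) indicator of negative entries
        return -(neg_mask @ R.T)             # (K×N), equals hinge_deriv_for_X(XR, R)

    def hinge_deriv_for_Omega_vectorized(self, STOmega: np.ndarray, S: np.ndarray) -> np.ndarray:
        neg_mask = (STOmega < 0).astype(S.dtype)  # (K×K)
        return -(S @ neg_mask)                    # (M×K), equals hinge_deriv_for_Omega(STOmega, S)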
    # ---------------------------------------------------------------
    # 4) Compute full gradients w.r.t. X and Omega
    # ---------------------------------------------------------------
    def compute_gradients(
        self,
        V_ss: np.ndarray,  # shape (M×N)
        state: OptimizationState
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        - Residual   = V_ss - (Omega @ D @ X), shape (M×N).
        - grad_X     = -2 * D @ Omega^T @ Residual + lambda * hingeDeriv(X @ R)
        - grad_Omega = -2 * Residual @ X^T @ D     + beta   * hingeDeriv(S^T @ Omega)
        """
        D_diag = np.diag(np.diag(state.D))      # shape (K×K)
        Recon = state.Omega @ D_diag @ state.X  # shape (M×N)
        Residual = V_ss - Recon                 # shape (M×N)

        # Gradient w.r.t. X => shape (K×N)
        # -2 * D @ (Omega^T @ Residual)
        tmp_OM = state.Omega.T @ Residual       # (K×N)
        grad_X = -2.0 * (D_diag @ tmp_OM)       # (K×N)
        # Add hinge for negativity in (X @ R)
        XR = state.X @ state.R                  # (K×K)
        hinge_X = self.hinge_deriv_for_X(XR, state.R)  # (K×N)
        grad_X += self.lambda_param * hinge_X
        # Pin the first column of X => zero out the gradient for that column
        grad_X[:, 0] = 0.0

        # Gradient w.r.t. Omega => shape (M×K)
        # -2 * (Residual @ X^T @ D)
        tmp_RX = Residual @ state.X.T           # (M×K)
        grad_Omega = -2.0 * (tmp_RX @ D_diag)   # (M×K)
        # Add hinge for negativity in (S^T @ Omega)
        STOmega = state.S.T @ state.Omega       # (K×K)
        hinge_Omega = self.hinge_deriv_for_Omega(STOmega, state.S)  # (M×K)
        grad_Omega += self.beta_param * hinge_Omega
        # Pin the first row of Omega => zero out the gradient for that row
        grad_Omega[0, :] = 0.0

        return grad_X, grad_Omega
    # ---------------------------------------------------------------
    # 5) Update D by NNLS in the projected (K×K) space
    # ---------------------------------------------------------------
    def update_D_matrix(
        self,
        V_proj: np.ndarray,
        X: np.ndarray,
        Omega: np.ndarray,
        M: int
    ) -> np.ndarray:
        """
        Solve NNLS in the (K×K)-projected space:
            min_{d >= 0} || vec(V_proj) - Q(X_proj, Omega_proj) d ||^2
        subject to trace(D) = M.  (D is diagonal, size K×K.)
        """
        K = self.K
        vec_target = V_proj.flatten()  # shape (K^2,)
        # Omega_proj = S^T @ Omega => shape (K×K)
        # X_proj     = X @ R       => shape (K×K)
        Omega_proj = self.S.T @ Omega
        X_proj = X @ self.R
        # Build Q => shape (K^2 × K) from Kronecker products
        Q_mat = np.zeros((K * K, K), dtype=V_proj.dtype)
        for c in range(K):
            # c-th column is vec( outer(Omega_proj[:, c], X_proj[c, :]) ),
            # i.e. np.kron(Omega_proj[:, c], X_proj[c, :])
            col_vec = np.kron(Omega_proj[:, c], X_proj[c, :])
            Q_mat[:, c] = col_vec
        d_vec, _ = nnls(Q_mat, vec_target)
        # Scale so that trace(D) = M
        d_sum = np.sum(d_vec)
        if d_sum > 1e-12:
            d_vec *= (M / d_sum)
        return np.diag(d_vec)
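
    # (Added note.) Why the Kronecker construction above works: with row-major
    # flattening (NumPy's default) and a diagonal D = diag(d),
    #     vec(Omega_proj @ D @ X_proj)
    #         = vec( sum_c d[c] * outer(Omega_proj[:, c], X_proj[c, :]) )
    #         = sum_c d[c] * kron(Omega_proj[:, c], X_proj[c, :]),
    # so stacking those Kronecker vectors as the columns of Q_mat turns the
    # diagonal fit into the linear NNLS problem min_{d >= 0} ||vec(V_proj) - Q_mat d||^2.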
    # ---------------------------------------------------------------
    # 6) Adaptive learning rates for gradient updates
    # ---------------------------------------------------------------
    def compute_adaptive_learning_rates(
        self,
        grad_X: np.ndarray,
        grad_Omega: np.ndarray
    ) -> Tuple[float, float]:
        """
        Simple scheme: chi = mu / ||grad_X||, phi = nu / ||grad_Omega||.
        """
        gx_norm = np.linalg.norm(grad_X, ord=2)
        go_norm = np.linalg.norm(grad_Omega, ord=2)
        chi = self.mu / (gx_norm + 1e-12)
        phi = self.nu / (go_norm + 1e-12)
        return chi, phi
    # ---------------------------------------------------------------
    # 7) Full fit
    # ---------------------------------------------------------------
    def fit_transform(
        self, V: np.ndarray, random_state: Optional[int] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Main entry point: factor V ~ Omega @ D @ X, with shapes:
            - V:     (M×N)
            - Omega: (M×K)
            - D:     (K×K), diagonal
            - X:     (K×N)
        """
        if random_state is not None:
            np.random.seed(random_state)
        M, N = V.shape

        # 7a) Sinkhorn
        self.V_ss = self.efficient_sinkhorn(V)

        # 7b) SVD for the projected space
        S, Sigma, R = self.compute_svd(self.V_ss)
        self.S = S  # (M×K)
        self.R = R  # (N×K)

        # 7c) Initialize Omega, X, D
        Omega_init = np.random.rand(M, self.K) * 1e-2
        X_init = np.random.rand(self.K, N) * 1e-2
        # Pin the first row of Omega => row vector of 1/sqrt(M)
        Omega_init[0, :] = 1.0 / np.sqrt(M)
        # Pin the first column of X => column vector of 1/sqrt(N)
        X_init[:, 0] = 1.0 / np.sqrt(N)
        D_init = np.eye(self.K)
        D_init *= (M / self.K)

        state = OptimizationState(
            X=X_init,
            Omega=Omega_init,
            D=D_init,
            S=S,
            R=R,
            Sigma=Sigma,
            deconv_error=float("inf")
        )
        prev_error = float("inf")

        # V_proj = S^T @ V_ss @ R => shape (K×K)
        V_proj = S.T @ self.V_ss @ R

        for iteration in range(self.max_iter):
            # -- compute gradients
            grad_X, grad_Omega = self.compute_gradients(self.V_ss, state)
            # -- adaptive learning rates
            chi, phi = self.compute_adaptive_learning_rates(grad_X, grad_Omega)
            # -- gradient steps
            X_new = state.X - chi * grad_X
            np.maximum(X_new, 0.0, out=X_new)  # enforce non-negativity
            # re-pin the first column
            X_new[:, 0] = 1.0 / np.sqrt(N)
            Omega_new = state.Omega - phi * grad_Omega
            np.maximum(Omega_new, 0.0, out=Omega_new)
            # re-pin the first row
            Omega_new[0, :] = 1.0 / np.sqrt(M)
            # -- update D via NNLS in the (K×K) subspace
            D_new = self.update_D_matrix(V_proj, X_new, Omega_new, M)
            # -- measure error in the projected space
            Recon_full = Omega_new @ D_new @ X_new       # (M×N)
            Recon_proj = self.S.T @ Recon_full @ self.R  # (K×K)
            deconv_error = np.linalg.norm(V_proj - Recon_proj, 'fro') ** 2
            if abs(deconv_error - prev_error) < self.tol:
                break
            state = state._replace(
                X=X_new,
                Omega=Omega_new,
                D=D_new,
                deconv_error=deconv_error
            )
            prev_error = deconv_error
        return state.Omega, state.X, state.D
def run_nmf_optimization(
    V: np.ndarray,
    n_components: int = 3,
    **kwargs
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convenience wrapper.
    V: shape (M×N)
    Returns Omega (M×K), X (K×N), D (K×K).
    """
    optimizer = SinkhornNMFOptimizer(n_components=n_components, **kwargs)
    return optimizer.fit_transform(V)
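
# (Added sketch, not part of the original gist.) A small post-hoc check of the
# quantities the hinge penalties and the NNLS rescaling are supposed to control:
# X @ R and S^T @ Omega should end up (approximately) non-negative, and
# trace(D) should be ~M. The helper name and return format are illustrative only.
def check_dual_simplex_constraints(
    optimizer: SinkhornNMFOptimizer,
    Omega: np.ndarray,
    X: np.ndarray,
    D: np.ndarray,
) -> dict:
    X_proj = X @ optimizer.R            # (K×K), pushed toward >= 0 by lambda_param
    Omega_proj = optimizer.S.T @ Omega  # (K×K), pushed toward >= 0 by beta_param
    return {
        "min(X @ R)": float(X_proj.min()),
        "min(S^T @ Omega)": float(Omega_proj.min()),
        "trace(D)": float(np.trace(D)),
    }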
# --------------------- Example usage ---------------------
if __name__ == "__main__":
    np.random.seed(42)
    M, N = 100, 50
    V = np.abs(np.random.randn(M, N))

    optimizer = SinkhornNMFOptimizer(
        n_components=3,
        lambda_param=1.0,
        beta_param=1.0,
        mu=0.001,
        nu=0.001,
        max_iter=200
    )
    W, H, D = optimizer.fit_transform(V)  # W = Omega (M×K), H = X (K×N)

    V_ss_reconstructed = W @ D @ H
    err = np.sum((optimizer.V_ss - V_ss_reconstructed) ** 2)
    print("Final V_ss reconstruction error:", err)
    print("trace(D) =", np.trace(D), "(should be ~ M =", M, ")")
    print("Min(W) =", W.min(), "Min(H) =", H.min())