Skip to content

Instantly share code, notes, and snippets.

@jeanmonet
Created January 7, 2025 19:46
Show Gist options
  • Save jeanmonet/12317cfc5bbde3c30d27e0d0e2c643c0 to your computer and use it in GitHub Desktop.
Save jeanmonet/12317cfc5bbde3c30d27e0d0e2c643c0 to your computer and use it in GitHub Desktop.
NMF Optimization as Dual Simplex formulation
"""
(Currently incorrect) attempt to implement NMF-Sinkhorn optimization outlined in:
Non-negative matrix factorization and deconvolution as dual simplex problem
Denis Kleverov, Ekaterina Aladyeva, Alexey Serdyukov, Maxim N. Artyomov
https://www.biorxiv.org/content/10.1101/2024.04.09.588652v1.full
R implementation by authors: https://github.com/artyomovlab/dualsimplex
"""
import numpy as np
from scipy import linalg
from scipy.optimize import nnls
from typing import Tuple, Optional, NamedTuple
class OptimizationState(NamedTuple):
"""Tracks optimization state including matrices and errors."""
X: np.ndarray # shape (K×N)
Omega: np.ndarray # shape (M×K)
D: np.ndarray # shape (K×K), diagonal
S: np.ndarray # shape (M×K)
R: np.ndarray # shape (N×K)
Sigma: np.ndarray # shape (K×K)
deconv_error: float = float("inf")
class SinkhornNMFOptimizer:
"""
NMF optimizer in Sinkhorn-transformed space.
- Focus on correctness of gradient shapes, hinge penalty, and D-updates.
- Pins:
* The first row of Omega => row-vector of 1/sqrt(M).
* The first column of X => column-vector of 1/sqrt(N).
"""
def __init__(
self,
n_components: int = 3,
max_iter: int = 1000,
tol: float = 1e-4,
lambda_param: float = 1.0,
beta_param: float = 1.0,
mu: float = 0.001,
nu: float = 0.001,
sinkhorn_max_iter: int = 20,
sinkhorn_start_check: int = 5,
sinkhorn_check_every: int = 3,
sinkhorn_tol: float = 1.490116e-8,
):
self.K = n_components
self.max_iter = max_iter
self.tol = tol
self.lambda_param = lambda_param
self.beta_param = beta_param
self.mu = mu
self.nu = nu
self.sinkhorn_max_iter = sinkhorn_max_iter
self.sinkhorn_start_check = sinkhorn_start_check
self.sinkhorn_check_every = sinkhorn_check_every
self.sinkhorn_tol = sinkhorn_tol
self.S = None # will store shape (M×K)
self.R = None # will store shape (N×K)
# ---------------------------------------------------------------
# 1) Sinkhorn normalization
# ---------------------------------------------------------------
def efficient_sinkhorn(self, V: np.ndarray) -> np.ndarray:
"""
Performs approximate row/column normalization (Sinkhorn) so that
rows and columns sum to ~1.
"""
V_curr = V.copy()
M, N = V_curr.shape
target_col_sums = np.ones(N)
for i in range(self.sinkhorn_max_iter):
# Row normalization
row_sums = np.sum(V_curr, axis=1, keepdims=True)
V_curr /= np.maximum(row_sums, 1e-12)
# Column normalization
col_sums = np.sum(V_curr, axis=0, keepdims=True)
V_curr /= np.maximum(col_sums, 1e-12)
# Check columns if they've converged
if (
i + 1 >= self.sinkhorn_start_check
and (i + 1 - self.sinkhorn_start_check) % self.sinkhorn_check_every == 0
):
if np.allclose(col_sums.flatten(), target_col_sums, atol=self.sinkhorn_tol):
break
return V_curr
# ---------------------------------------------------------------
# 2) SVD with enforced top singular vectors
# ---------------------------------------------------------------
def compute_svd(self, V_ss: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Compute truncated SVD => V_ss ~ U Sigma Vh, then define:
S = U[:, :K] (M×K)
Sigma = diag(s[:K]) (K×K)
R = Vh[:K, :].T => shape (N×K)
Then enforce:
S[:, 0] = 1/sqrt(M) # first left singular vector
R[:, 0] = 1/sqrt(N) # first right singular vector
"""
M, N = V_ss.shape
U, s, Vh = linalg.svd(V_ss, full_matrices=False)
S = U[:, :self.K] # (M×K)
Sigma = np.diag(s[:self.K]) # (K×K)
Rfull = Vh[:self.K, :] # (K×N)
R = Rfull.T # (N×K)
# Force the first left singular vector to be 1/sqrt(M)
S[:, 0] = 1.0 / np.sqrt(M)
# Force the first right singular vector to be 1/sqrt(N)
R[:, 0] = 1.0 / np.sqrt(N)
return S, Sigma, R
# ---------------------------------------------------------------
# 3) Hinge derivatives
# ---------------------------------------------------------------
def hinge_deriv_for_X(self, XR: np.ndarray, R: np.ndarray) -> np.ndarray:
"""
Hinge penalty: for each negative entry in (X @ R), add derivative w.r.t. X.
- X is (K×N), R is (N×K), so (X @ R) is (K×K).
- If (X @ R)[l, j] < 0 => grad_X[l, :] += -R[:, j].
"""
K, K2 = XR.shape # XR is (K×K)
N = R.shape[0]
grad = np.zeros((K, N), dtype=XR.dtype)
for l in range(K):
for j in range(K2):
if XR[l, j] < 0:
grad[l, :] += -R[:, j]
return grad
def hinge_deriv_for_Omega(self, STOmega: np.ndarray, S: np.ndarray) -> np.ndarray:
"""
Hinge penalty: for each negative entry in (S^T @ Omega), add derivative w.r.t. Omega.
- S is (M×K), so S^T is (K×M). Omega is (M×K).
- (S^T @ Omega) => shape (K×K).
- If (S^T @ Omega)[k1, k2] < 0 => grad_Omega[:, k2] += -S[:, k1].
"""
K, K2 = STOmega.shape # (K×K)
M = S.shape[0]
grad = np.zeros((M, K2), dtype=STOmega.dtype)
for k1 in range(K):
for k2 in range(K2):
if STOmega[k1, k2] < 0:
grad[:, k2] += -S[:, k1]
return grad
# ---------------------------------------------------------------
# 4) Compute full gradients w.r.t X and Omega
# ---------------------------------------------------------------
def compute_gradients(
self,
V_ss: np.ndarray, # shape (M×N)
state: OptimizationState
) -> Tuple[np.ndarray, np.ndarray]:
"""
- Residual = V_ss - (Omega * D * X), shape (M×N).
- grad_X = -2 * D * Omega^T * Residual + lambda * hingeDeriv( X@R )
- grad_Omega = -2 * Residual * X^T * D + beta * hingeDeriv( S^T@Omega )
"""
D_diag = np.diag(np.diag(state.D)) # shape (K×K)
Recon = state.Omega @ D_diag @ state.X # shape (M×N)
Residual = V_ss - Recon # shape (M×N)
# Gradient wrt X => shape (K×N)
# -2 * D * (Omega^T @ Residual)
tmp_OM = state.Omega.T @ Residual # (K×N)
grad_X = -2.0 * (D_diag @ tmp_OM) # (K×N)
# Add hinge for negativity in (X @ R)
XR = state.X @ state.R # (K×K)
hinge_X = self.hinge_deriv_for_X(XR, state.R) # (K×N)
grad_X += self.lambda_param * hinge_X
# Pin first column of X => zero out gradient for that column
grad_X[:, 0] = 0.0
# Gradient wrt Omega => shape (M×K)
# -2 * (Residual @ X^T @ D)
tmp_RX = Residual @ state.X.T # (M×K)
grad_Omega = -2.0 * (tmp_RX @ D_diag) # (M×K)
# Add hinge for negativity in (S^T @ Omega)
STOmega = state.S.T @ state.Omega # (K×K)
hinge_Omega = self.hinge_deriv_for_Omega(STOmega, state.S) # (M×K)
grad_Omega += self.beta_param * hinge_Omega
# Pin first row of Omega => zero out gradient for that row
grad_Omega[0, :] = 0.0
return grad_X, grad_Omega
# ---------------------------------------------------------------
# 5) Update D by NNLS in projected (K×K) space
# ---------------------------------------------------------------
def update_D_matrix(
self,
V_proj: np.ndarray,
X: np.ndarray,
Omega: np.ndarray,
M: int
) -> np.ndarray:
"""
Solve NNLS in the (K×K)-projected space:
min_{d>=0} || vec(V_proj) - Q(X_proj, Omega_proj)*d ||^2
subject to trace(D)=M. (D is diagonal, size K×K)
"""
K = self.K
vec_target = V_proj.flatten() # shape (K^2,)
# Omega_proj = S^T @ Omega => shape (K×K)
# X_proj = X @ R => shape (K×K)
Omega_proj = self.S.T @ Omega
X_proj = X @ self.R
# Build Q => shape (K^2 × K) from Kronecker products
Q_mat = np.zeros((K*K, K), dtype=V_proj.dtype)
for c in range(K):
# c-th column is vec( Omega_proj[:,c] * X_proj[c,:]^T )
# i.e. np.kron( Omega_proj[:, c], X_proj[c, :] )
col_vec = np.kron(Omega_proj[:, c], X_proj[c, :])
Q_mat[:, c] = col_vec
d_vec, _ = nnls(Q_mat, vec_target)
# Scale so trace(D)=M
d_sum = np.sum(d_vec)
if d_sum > 1e-12:
d_vec *= (M / d_sum)
return np.diag(d_vec)
# ---------------------------------------------------------------
# 6) Adaptive learning rates for gradient updates
# ---------------------------------------------------------------
def compute_adaptive_learning_rates(
self,
grad_X: np.ndarray,
grad_Omega: np.ndarray
) -> Tuple[float, float]:
"""
Simple scheme: chi = mu / ||grad_X||, phi = nu / ||grad_Omega||.
"""
gx_norm = np.linalg.norm(grad_X, ord=2)
go_norm = np.linalg.norm(grad_Omega, ord=2)
chi = self.mu / (gx_norm + 1e-12)
phi = self.nu / (go_norm + 1e-12)
return chi, phi
# ---------------------------------------------------------------
# 7) Full fit
# ---------------------------------------------------------------
def fit_transform(self, V: np.ndarray, random_state: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Main entry: factor V ~ Omega * D * X, shapes:
- V: (M×N)
- Omega: (M×K)
- D: (K×K), diagonal
- X: (K×N)
"""
if random_state is not None:
np.random.seed(random_state)
M, N = V.shape
# 7a) Sinkhorn
self.V_ss = self.efficient_sinkhorn(V)
# 7b) SVD for the projected space
S, Sigma, R = self.compute_svd(self.V_ss)
self.S = S # (M×K)
self.R = R # (N×K)
# 7c) Initialize Omega, X, D
Omega_init = np.random.rand(M, self.K) * 1e-2
X_init = np.random.rand(self.K, N) * 1e-2
# Pin the first row of Omega => row-vector of 1/sqrt(M)
Omega_init[0, :] = 1.0 / np.sqrt(M)
# Pin the first column of X => column-vector of 1/sqrt(N)
X_init[:, 0] = 1.0 / np.sqrt(N)
D_init = np.eye(self.K)
D_init *= (M / self.K)
state = OptimizationState(
X=X_init,
Omega=Omega_init,
D=D_init,
S=S,
R=R,
Sigma=Sigma,
deconv_error=float("inf")
)
prev_error = float("inf")
# V_proj = S^T @ V_ss @ R => shape (K×K)
V_proj = S.T @ self.V_ss @ R
for iteration in range(self.max_iter):
# -- compute gradients
grad_X, grad_Omega = self.compute_gradients(self.V_ss, state)
# -- adaptive LR
chi, phi = self.compute_adaptive_learning_rates(grad_X, grad_Omega)
# -- gradient steps
X_new = state.X - chi * grad_X
np.maximum(X_new, 0.0, out=X_new) # enforce non-negativity
# re-pin first column
X_new[:, 0] = 1.0 / np.sqrt(N)
Omega_new = state.Omega - phi * grad_Omega
np.maximum(Omega_new, 0.0, out=Omega_new)
# re-pin first row
Omega_new[0, :] = 1.0 / np.sqrt(M)
# -- update D via NNLS in (K×K) subspace
D_new = self.update_D_matrix(V_proj, X_new, Omega_new, M)
# -- measure error in projected space
Recon_kxk = Omega_new @ D_new @ X_new # (M×N)
Recon_proj = self.S.T @ Recon_kxk @ self.R # (K×K)
deconv_error = np.linalg.norm(V_proj - Recon_proj, 'fro') ** 2
if abs(deconv_error - prev_error) < self.tol:
break
state = state._replace(
X=X_new,
Omega=Omega_new,
D=D_new,
deconv_error=deconv_error
)
prev_error = deconv_error
return state.Omega, state.X, state.D
def run_nmf_optimization(
V: np.ndarray,
n_components: int = 3,
**kwargs
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Convenience wrapper.
V: shape (M×N)
Returns Omega(M×K), X(K×N), D(K×K).
"""
optimizer = SinkhornNMFOptimizer(n_components=n_components, **kwargs)
return optimizer.fit_transform(V)
# --------------------- Example usage ---------------------
if __name__ == "__main__":
np.random.seed(42)
M, N = 100, 50
V = np.abs(np.random.randn(M, N))
optimizer = SinkhornNMFOptimizer(
n_components=3,
lambda_param=1.0,
beta_param=1.0,
mu=0.001,
nu=0.001,
max_iter=200
)
W, H, D = optimizer.fit_transform(V)
V_ss_reconstructed = W @ D @ H
err = np.sum((optimizer.V_ss - V_ss_reconstructed)**2)
print("Final V_ss reconstruction error:", err)
print("trace(D) =", np.trace(D), "(should be ~ M =", M, ")")
print("Min(W) =", W.min(), "Min(H) =", H.min())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment