JAX program that computes the Legendre approximation of the entropy of a Gaussian mixture model (GMM), following Caleb Dahlke and Jason Pacheco, https://papers.nips.cc/paper_files/paper/2023/hash/ee860a9fa65a55a335754c557a5211de-Abstract-Conference.html
import jax
import jax.numpy as jnp
import functools


def gmm_max(logpi, mu, cov):
    """Return an upper bound on the GMM density.

    logpi: (K,)-array of the log-weights
    mu: (K, D)-array of the component means
    cov: (K, D, D)-array of the component covariances

    If p(x) = sum_i pi_i N(x|mu_i, cov_i) then
        max(p(x)) <= sum_i pi_i max(N(x|mu_i, cov_i))
                   = sum_i pi_i N(0|0, cov_i)
    """
    logpdfmaxs = jax.vmap(
        lambda m, c: jax.scipy.stats.multivariate_normal.logpdf(m, m, cov=c)
    )(mu, cov)
    res = jnp.sum(jnp.exp(logpdfmaxs + logpi))
    return res
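
# Illustrative sanity check (the mixture and grid below are arbitrary test
# values): the mixture density, evaluated anywhere, should stay below the
# gmm_max bound.
def _check_gmm_max():
    logpi = jnp.log(jnp.array([0.3, 0.7]))
    mu = jnp.array([[0.0, 0.0], [2.0, 1.0]])
    cov = jnp.stack([jnp.eye(2), 2.0 * jnp.eye(2)])
    bound = gmm_max(logpi, mu, cov)
    grid = jnp.linspace(-4.0, 6.0, 50)
    xs = jnp.stack(jnp.meshgrid(grid, grid), axis=-1).reshape(-1, 2)
    dens = sum(
        jnp.exp(
            logpi[i] + jax.scipy.stats.multivariate_normal.logpdf(xs, mu[i], cov[i])
        )
        for i in range(2)
    )
    assert jnp.all(dens <= bound)
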
@functools.partial(jax.jit, static_argnames=["K", "c"]) | |
def partitions_(k, K, c): | |
""" | |
PRECONDITION: c = combinations(k+K-1, k, exact=True) | |
TODO come up with a way to NOT have to pass c, while still allowing JIT | |
""" | |
def _inner(k, buffer, pos, carry): | |
obuffer, opos = carry | |
if pos == K - 1: | |
updated_buffer = buffer.at[pos].set(k) | |
return (obuffer.at[opos, :].set(updated_buffer), opos + 1) | |
else: | |
return jax.lax.fori_loop( | |
0, | |
k + 1, | |
lambda i, carry: _inner( | |
k - i, buffer.at[pos].set(i), pos + 1, carry=carry | |
), | |
carry, | |
) | |
buffer = jnp.zeros(K, dtype=jnp.int32) | |
carry = (jnp.zeros((c, K), dtype=jnp.int32), 0) | |
obuffer, opos = _inner(k=k, pos=0, buffer=buffer, carry=carry) | |
return obuffer | |
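
# Usage sketch: for k=2 and K=3 there are comb(4, 2) = 6 such tuples, and the
# recursion above emits them in lexicographic order of the leading entries:
#     partitions_(2, 3, 6)
#     -> [[0 0 2]
#         [0 1 1]
#         [0 2 0]
#         [1 0 1]
#         [1 1 0]
#         [2 0 0]]
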
def comb(n, k):
    """Compute n choose k via the gamma function"""
    return jnp.exp(
        jax.scipy.special.gammaln(n + 1)
        - jax.scipy.special.gammaln(k + 1)
        - jax.scipy.special.gammaln(n - k + 1)
    )


def logfactorial(n):
    return jax.scipy.special.gammaln(n + 1)


def log_multinomial_coeff(ks):
    """Multinomial coefficient computed via the log-gamma function"""
    m = sum(ks)
    return jax.scipy.special.gammaln(m + 1) - sum(
        jax.scipy.special.gammaln(k + 1) for k in ks
    )
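
# Sanity check: these gammaln-based helpers agree with exact integer
# combinatorics up to floating-point rounding, e.g.
#     comb(5, 2) ~= 10
#     jnp.exp(log_multinomial_coeff(jnp.array([2, 1, 1]))) ~= 4!/(2!1!1!) = 12
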
def expectation_of_powers(logpi, mu, cov, k, K, n_partitions):
    """Compute the closed-form expectation of powers of a GMM.

    Equation 11 in the cited paper.
    I would *really* like to not pass `K` and `n_partitions` as arguments,
    but I don't know how to do that with JIT.

    PRECONDITIONS:
        sum(exp(logpi)) == 1
        n_partitions == int(round(comb(k + K - 1, k)))
    """
    assert logpi.ndim == 1
    assert len(mu) == len(cov) == K
    inv_covs = [jnp.linalg.inv(c) for c in cov]
    # Iterate over j1, ..., jK >= 0 where j1 + ... + jK = k
    result = 0.0
    partitions = partitions_(k, K, n_partitions)
    for j in partitions:
        # Sum over i
        log_inner_summands = jnp.zeros(K)
        for i in range(K):
            cov_combined = jnp.linalg.inv(
                inv_covs[i] + sum(j[t] * inv_covs[t] for t in range(K))
            )
            mu_combined = cov_combined @ (
                inv_covs[i] @ mu[i] + sum(j[t] * inv_covs[t] @ mu[t] for t in range(K))
            )
            # Gaussian ratio
            log_N_ratio = jax.scipy.stats.multivariate_normal.logpdf(
                0, mean=mu[i], cov=cov[i]
            ) - jax.scipy.stats.multivariate_normal.logpdf(
                0, mean=mu_combined, cov=cov_combined
            )
            # Product term
            log_product_term = sum(
                j[t]
                * (
                    logpi[t]
                    + jax.scipy.stats.multivariate_normal.logpdf(
                        0, mean=mu[t], cov=cov[t]
                    )
                )
                for t in range(K)
            )
            # Weighted sum
            s = logpi[i] + log_N_ratio + log_product_term
            log_inner_summands = log_inner_summands.at[i].set(s)
        result += jnp.exp(
            log_multinomial_coeff(j) + jax.nn.logsumexp(log_inner_summands)
        )
    return result
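
# Monte-Carlo cross-check (illustrative sketch; the mixture below is an
# arbitrary test case): sampling x ~ p and averaging p(x)^k should approach
# the closed form above as the number of samples grows.
def _check_expectation_of_powers(k=2, n_samples=100_000, seed=0):
    logpi = jnp.log(jnp.array([0.5, 0.5]))
    mu = jnp.array([[0.0], [3.0]])
    cov = jnp.array([[[1.0]], [[1.0]]])
    K = 2
    key_z, key_x = jax.random.split(jax.random.PRNGKey(seed))
    z = jax.random.categorical(key_z, logpi, shape=(n_samples,))
    x = mu[z] + jax.random.normal(key_x, (n_samples, 1)) * jnp.sqrt(cov[z, :, 0])
    # log p(x) for each sample, via logsumexp over the components
    logp = jax.nn.logsumexp(
        logpi
        + jax.vmap(
            lambda m, c: jax.scipy.stats.multivariate_normal.logpdf(x, m, c)
        )(mu, cov).T,
        axis=1,
    )
    mc_estimate = jnp.mean(jnp.exp(k * logp))
    closed_form = expectation_of_powers(
        logpi, mu, cov, k, K, int(round(comb(k + K - 1, k)))
    )
    return mc_estimate, closed_form
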
@functools.partial(jax.jit, static_argnames=["order", "K", "combss"]) | |
def gmm_entropy_legendre_jax_(logpi, mu, cov, order, K, combss): | |
"""Entropy of a GMM using legendre polynomial approximation | |
Formula due to Dahlke and Pacheko | |
See equation 16 in the cited paper | |
PRECONDITIONS: | |
len(logpi) == K | |
combss = [combinations(k+K-1, k, exact=True) for k in range(order+1)] | |
""" | |
assert logpi.ndim == 1 | |
assert mu.ndim == 2 | |
d = mu.shape[1] | |
assert mu.shape == (K, d), f"mu.shape={mu.shape}, K={K}, d={d}" | |
assert cov.ndim == 3 | |
assert cov.shape == (K, d, d) | |
N = order # alias | |
a = ( | |
gmm_max(logpi, mu, cov) + 1e-3 | |
) # a cutoff that must be larger than the maximum density | |
loga = jnp.log(a) | |
Ep_px = [ | |
expectation_of_powers(logpi, mu, cov, k, K, combs) | |
for k, combs in zip(range(0, N + 1), combss) | |
] | |
summands = jnp.zeros(N + 1) | |
for n in range(0, N + 1): | |
coeffs_sum = jax.lax.fori_loop( | |
0, | |
n + 1, | |
lambda j, acc: acc | |
+ (-1) ** (n + j) | |
* ((j + 1) * loga - 1) | |
* jnp.exp( | |
logfactorial(n + j) - logfactorial(n - j) - 2 * logfactorial(j + 1) | |
), | |
0.0, | |
) | |
coeffs2_sum = jnp.sum( | |
jnp.array( | |
[ | |
(-1) ** (n + j) | |
* jnp.exp( | |
logfactorial(n + j) - logfactorial(n - j) - 2 * logfactorial(j) | |
) | |
* Ep_px[j] | |
/ a**j | |
for j in range(n + 1) | |
] | |
) | |
) | |
summands = summands.at[n].set((2 * n + 1) * coeffs_sum * coeffs2_sum) | |
return -jnp.sum(summands) | |
def gmm_entropy_legendre_jax(logpi, mu, cov, order):
    """Entropy of a GMM using a Legendre polynomial approximation.

    Formula due to Dahlke and Pacheco; see equation 16 in the cited paper
    (or the corresponding formula in the proof of theorem 4.5).
    """
    assert mu.ndim == 2
    K = mu.shape[0]
    D = mu.shape[1]
    assert logpi.shape == (K,), f"logpi.shape={logpi.shape}, K={K}"
    assert cov.shape == (K, D, D), f"cov.shape={cov.shape}, K={K}, D={D}"
    combss = tuple(int(round(comb(k + K - 1, k))) for k in range(order + 1))
    return gmm_entropy_legendre_jax_(logpi, mu, cov, order=order, K=K, combss=combss)
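
# Differentiation sketch (untested assumption): the pipeline is built from jax
# primitives with statically shaped loops, so jax.grad should be able to trace
# through the approximation, e.g. the sensitivity of the entropy to the means:
#     dH_dmu = jax.grad(gmm_entropy_legendre_jax, argnums=1)(logpi, mu, cov, order)
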
if __name__ == "__main__":
    order = 8

    def mvn_entropy_exact(cov):
        """The entropy of a single Gaussian with covariance matrix `cov`."""
        assert cov.ndim == 2
        k = cov.shape[0]
        assert cov.shape == (k, k)
        return (k / 2) * (1 + jnp.log(2 * jnp.pi)) + 0.5 * jnp.log(jnp.linalg.det(cov))

    # A single-component "mixture": the approximation should be close to the
    # exact Gaussian entropy.
    logpi = jnp.array([0.0])
    mu = jnp.array([[10.0, -20]])
    cov = jnp.array([[[2, 0], [0, 7]]])
    a = mvn_entropy_exact(cov[0])
    print(a)
    a = gmm_entropy_legendre_jax(logpi, mu, cov, order=order)
    print(a)

    # The same check in one dimension, with a standard normal.
    a = mvn_entropy_exact(jnp.array([[1.0]]))
    print(a)
    a = gmm_entropy_legendre_jax(
        jnp.array([0.0]), jnp.array([[0.0]]), jnp.array([[[1.0]]]), order=order
    )
    print(a)

    # Two identical components: the mixture is still a standard normal, so the
    # entropy should not change.
    a = gmm_entropy_legendre_jax(
        jnp.log(jnp.ones(2) / 2),
        jnp.array([[0.0], [0.0]]),
        jnp.array([[[1.0]], [[1.0]]]),
        order=order,
    )
    print(a)
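
    # Monte-Carlo cross-check (illustrative sketch with arbitrary test values):
    # for a genuinely bimodal mixture there is no exact formula to compare
    # against, so estimate the entropy as -E[log p(x)] from samples instead
    # and print both numbers side by side.
    logpi = jnp.log(jnp.ones(2) / 2)
    mu = jnp.array([[0.0], [4.0]])
    cov = jnp.array([[[1.0]], [[1.0]]])
    a = gmm_entropy_legendre_jax(logpi, mu, cov, order=order)
    key_z, key_x = jax.random.split(jax.random.PRNGKey(0))
    z = jax.random.categorical(key_z, logpi, shape=(50_000,))
    x = mu[z] + jax.random.normal(key_x, (50_000, 1)) * jnp.sqrt(cov[z, :, 0])
    logp = jax.nn.logsumexp(
        logpi
        + jax.vmap(
            lambda m, c: jax.scipy.stats.multivariate_normal.logpdf(x, m, c)
        )(mu, cov).T,
        axis=1,
    )
    print(a, -jnp.mean(logp))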