@el-hult · Created August 20, 2024 11:08
JAX program that computes the Legendre approximation of the entropy of a Gaussian mixture model, following Caleb Dahlke and Jason Pacheco, https://papers.nips.cc/paper_files/paper/2023/hash/ee860a9fa65a55a335754c557a5211de-Abstract-Conference.html
import jax
import jax.numpy as jnp
import functools
def gmm_max(logpi, mu, cov):
    """Return an upper bound on the GMM density.

    logpi: (K,)-array of the log-weights
    mu: (K, D)-array of the component means
    cov: (K, D, D)-array of the component covariances

    If p(x) = sum_i pi_i N(x|mu_i, cov_i), then
    max(p(x)) <= sum_i pi_i max(N(x|mu_i, cov_i)) = sum_i pi_i N(0|0, cov_i),
    since each component attains its maximum at its own mean.
    """
    # Each component's log-density evaluated at its own mean, i.e. its peak.
    logpdfmaxs = jax.vmap(
        lambda m, c: jax.scipy.stats.multivariate_normal.logpdf(m, m, cov=c)
    )(mu, cov)
    res = jnp.sum(jnp.exp(logpdfmaxs + logpi))
    return res
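
# Quick sanity check (an illustrative addition, not part of the original gist):
# for a single standard 1-D Gaussian the bound equals the peak density
# 1/sqrt(2*pi) ~= 0.3989, so a grid evaluation of the density never exceeds it.
if __name__ == "__main__":
    _bound = gmm_max(jnp.array([0.0]), jnp.array([[0.0]]), jnp.array([[[1.0]]]))
    _xs = jnp.linspace(-4.0, 4.0, 101)
    _dens = jax.scipy.stats.norm.pdf(_xs)  # the mixture here is just N(0, 1)
    assert bool(jnp.all(_dens <= _bound + 1e-7)), (_dens.max(), _bound)
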
@functools.partial(jax.jit, static_argnames=["K", "c"])
def partitions_(k, K, c):
    """Enumerate all weak compositions of k into K nonnegative parts.

    Returns a (c, K)-array whose rows are all tuples (j_1, ..., j_K) with
    j_t >= 0 and j_1 + ... + j_K = k.

    PRECONDITION: c == scipy.special.comb(k + K - 1, k, exact=True),
    the number of such compositions ("stars and bars").
    TODO: come up with a way to NOT have to pass c, while still allowing JIT
    """

    def _inner(k, buffer, pos, carry):
        obuffer, opos = carry
        if pos == K - 1:
            # The last coordinate is forced: it takes whatever of k remains.
            updated_buffer = buffer.at[pos].set(k)
            return (obuffer.at[opos, :].set(updated_buffer), opos + 1)
        else:
            # Try every value 0..k at this position and recurse on the rest.
            return jax.lax.fori_loop(
                0,
                k + 1,
                lambda i, carry: _inner(
                    k - i, buffer.at[pos].set(i), pos + 1, carry=carry
                ),
                carry,
            )

    buffer = jnp.zeros(K, dtype=jnp.int32)
    carry = (jnp.zeros((c, K), dtype=jnp.int32), 0)
    obuffer, opos = _inner(k=k, pos=0, buffer=buffer, carry=carry)
    return obuffer
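
# Example (an illustrative addition): the weak compositions of k=2 into K=3
# parts; there are comb(2+3-1, 2) = 6 of them, each row summing to 2.
if __name__ == "__main__":
    import math

    _c = math.comb(2 + 3 - 1, 2)  # == 6
    _parts = partitions_(2, 3, _c)
    assert _parts.shape == (_c, 3)
    assert bool(jnp.all(_parts.sum(axis=1) == 2))
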
def comb(n, k):
    """Compute n choose k via the log-gamma function."""
    return jnp.exp(
        jax.scipy.special.gammaln(n + 1)
        - jax.scipy.special.gammaln(k + 1)
        - jax.scipy.special.gammaln(n - k + 1)
    )
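
# Spot check (illustrative addition): the log-gamma route reproduces the exact
# binomial coefficient up to float rounding, e.g. comb(5, 2) ~= 10.
if __name__ == "__main__":
    import math

    assert int(round(float(comb(5, 2)))) == math.comb(5, 2) == 10
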
def logfactorial(n):
    """log(n!) via the log-gamma function."""
    return jax.scipy.special.gammaln(n + 1)
def log_multinomial_coeff(ks):
    """Multinomial coefficient (sum(ks))! / prod(k_i!) on the log scale, via log-gamma."""
    m = sum(ks)
    return jax.scipy.special.gammaln(m + 1) - sum(
        jax.scipy.special.gammaln(k + 1) for k in ks
    )
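
# Spot check (illustrative addition): log_multinomial_coeff([1, 2]) should be
# log(3! / (1! * 2!)) = log(3).
if __name__ == "__main__":
    assert bool(
        jnp.isclose(log_multinomial_coeff(jnp.array([1, 2])), jnp.log(3.0))
    )
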
def expectation_of_powers(logpi, mu, cov, k, K, n_partitions):
    """Compute the closed-form expectation of powers of a GMM, E_p[p(x)^k].

    Equation 11 in the cited paper.
    I would *really* like to not pass `K` and `n_partitions` as arguments,
    but I don't know how to do that with JIT.

    PRECONDITIONS:
        sum(exp(logpi)) == 1
        n_partitions == int(round(comb(k + K - 1, k)))
    """
    assert logpi.ndim == 1
    assert len(mu) == len(cov) == K
    inv_covs = [jnp.linalg.inv(c) for c in cov]
    # Iterate over j1, ..., jK >= 0 where j1 + ... + jK = k
    result = 0.0
    partitions = partitions_(k, K, n_partitions)
    for j in partitions:
        # Sum over i
        log_inner_summands = jnp.zeros(K)
        for i in range(K):
            # Precision-weighted combination of component i with the multiset j
            cov_combined = jnp.linalg.inv(
                inv_covs[i] + sum(j[t] * inv_covs[t] for t in range(K))
            )
            mu_combined = cov_combined @ (
                inv_covs[i] @ mu[i] + sum(j[t] * inv_covs[t] @ mu[t] for t in range(K))
            )
            # Gaussian ratio
            log_N_ratio = jax.scipy.stats.multivariate_normal.logpdf(
                0, mean=mu[i], cov=cov[i]
            ) - jax.scipy.stats.multivariate_normal.logpdf(
                0, mean=mu_combined, cov=cov_combined
            )
            # Product term
            log_product_term = sum(
                j[t]
                * (
                    logpi[t]
                    + jax.scipy.stats.multivariate_normal.logpdf(
                        0, mean=mu[t], cov=cov[t]
                    )
                )
                for t in range(K)
            )
            # Weighted sum
            s = logpi[i] + log_N_ratio + log_product_term
            log_inner_summands = log_inner_summands.at[i].set(s)
        result += jnp.exp(
            log_multinomial_coeff(j) + jax.nn.logsumexp(log_inner_summands)
        )
    return result
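
# Cross-check (an illustrative addition, not in the original gist): for a
# single Gaussian p = N(mu, C), E_p[p(x)] = integral of p(x)^2 dx = N(0|0, 2C),
# so the k=1 call should reproduce that closed form.
if __name__ == "__main__":
    _mu = jnp.array([[1.0, -2.0]])
    _cov = jnp.array([[[2.0, 0.3], [0.3, 1.0]]])
    _got = expectation_of_powers(jnp.array([0.0]), _mu, _cov, k=1, K=1, n_partitions=1)
    _want = jnp.exp(
        jax.scipy.stats.multivariate_normal.logpdf(
            jnp.zeros(2), jnp.zeros(2), 2 * _cov[0]
        )
    )
    assert bool(jnp.isclose(_got, _want, rtol=1e-4)), (_got, _want)
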
@functools.partial(jax.jit, static_argnames=["order", "K", "combss"])
def gmm_entropy_legendre_jax_(logpi, mu, cov, order, K, combss):
"""Entropy of a GMM using legendre polynomial approximation
Formula due to Dahlke and Pacheko
See equation 16 in the cited paper
PRECONDITIONS:
len(logpi) == K
combss = [combinations(k+K-1, k, exact=True) for k in range(order+1)]
"""
assert logpi.ndim == 1
assert mu.ndim == 2
d = mu.shape[1]
assert mu.shape == (K, d), f"mu.shape={mu.shape}, K={K}, d={d}"
assert cov.ndim == 3
assert cov.shape == (K, d, d)
N = order # alias
a = (
gmm_max(logpi, mu, cov) + 1e-3
) # a cutoff that must be larger than the maximum density
loga = jnp.log(a)
Ep_px = [
expectation_of_powers(logpi, mu, cov, k, K, combs)
for k, combs in zip(range(0, N + 1), combss)
]
summands = jnp.zeros(N + 1)
for n in range(0, N + 1):
coeffs_sum = jax.lax.fori_loop(
0,
n + 1,
lambda j, acc: acc
+ (-1) ** (n + j)
* ((j + 1) * loga - 1)
* jnp.exp(
logfactorial(n + j) - logfactorial(n - j) - 2 * logfactorial(j + 1)
),
0.0,
)
coeffs2_sum = jnp.sum(
jnp.array(
[
(-1) ** (n + j)
* jnp.exp(
logfactorial(n + j) - logfactorial(n - j) - 2 * logfactorial(j)
)
* Ep_px[j]
/ a**j
for j in range(n + 1)
]
)
)
summands = summands.at[n].set((2 * n + 1) * coeffs_sum * coeffs2_sum)
return -jnp.sum(summands)
def gmm_entropy_legendre_jax(logpi, mu, cov, order):
    """Entropy of a GMM using a Legendre polynomial approximation.

    Formula due to Dahlke and Pacheco.
    See equation 16 in the cited paper (or the corresponding formula in the
    proof of theorem 4.5).
    """
    assert mu.ndim == 2
    K = mu.shape[0]
    D = mu.shape[1]
    assert logpi.shape == (K,), f"logpi.shape={logpi.shape}, K={K}"
    assert cov.shape == (K, D, D), f"cov.shape={cov.shape}, K={K}, D={D}"
    # The partition counts must be static (hashable) for JIT, so they are
    # computed here as a tuple of Python ints, outside the jitted kernel.
    combss = tuple(int(round(comb(k + K - 1, k))) for k in range(order + 1))
    return gmm_entropy_legendre_jax_(logpi, mu, cov, order=order, K=K, combss=combss)
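
# Optional cross-check (an illustrative sketch, not part of the original gist):
# a naive Monte Carlo estimate of the entropy, -mean(log p(x)) over samples
# x ~ p, should roughly agree with the Legendre approximation. The helper name
# gmm_entropy_monte_carlo is mine, not the paper's.
def gmm_entropy_monte_carlo(key, logpi, mu, cov, n_samples=10_000):
    """Naive Monte Carlo entropy estimate of a GMM."""
    K, D = mu.shape
    key_z, key_e = jax.random.split(key)
    # Sample component indices, then x = mu_z + chol(cov_z) @ eps.
    z = jax.random.categorical(key_z, logpi, shape=(n_samples,))
    eps = jax.random.normal(key_e, (n_samples, D))
    chols = jnp.linalg.cholesky(cov)  # (K, D, D)
    x = mu[z] + jnp.einsum("nij,nj->ni", chols[z], eps)
    # log p(x) as a logsumexp over the weighted component log-densities.
    comp_logpdf = jax.vmap(
        lambda m, c: jax.scipy.stats.multivariate_normal.logpdf(x, m, c),
        out_axes=1,
    )(mu, cov)  # (n_samples, K)
    return -jnp.mean(jax.nn.logsumexp(comp_logpdf + logpi, axis=1))


if __name__ == "__main__":
    _h_mc = gmm_entropy_monte_carlo(
        jax.random.PRNGKey(0),
        jnp.log(jnp.ones(2) / 2),
        jnp.array([[0.0], [3.0]]),
        jnp.array([[[1.0]], [[1.0]]]),
    )
    print("Monte Carlo entropy estimate:", _h_mc)
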
if __name__ == "__main__":
order = 8
def mvn_entropy_exact(cov):
"""The entropy of a single Gaussian with covariance matrix `cov`."""
assert cov.ndim == 2
k = cov.shape[0]
assert cov.shape == (k, k)
return (k / 2) * (1 + jnp.log(2 * jnp.pi)) + 0.5 * jnp.log(jnp.linalg.det(cov))
logpi = jnp.array([0.0])
mu = jnp.array([[10.0, -20]])
cov = jnp.array([[[2, 0], [0, 7]]])
a = mvn_entropy_exact(cov[0])
print(a)
a = gmm_entropy_legendre_jax(logpi, mu, cov, order=order)
print(a)
a = mvn_entropy_exact(jnp.array([[1.0]]))
print(a)
a = gmm_entropy_legendre_jax(
jnp.array([0.0]), jnp.array([[0.0]]), jnp.array([[[1.0]]]), order=order
)
print(a)
a = gmm_entropy_legendre_jax(
jnp.log(jnp.ones(2) / 2),
jnp.array([[0.0], [0.0]]),
jnp.array([[[1.0]], [[1.0]]]),
order=order,
)
print(a)