Skip to content

Instantly share code, notes, and snippets.

@giladturok
Last active November 20, 2024 18:55
Show Gist options
  • Select an option

  • Save giladturok/5e579a57502cc07c0287d85b7bd0e574 to your computer and use it in GitHub Desktop.

Select an option

Save giladturok/5e579a57502cc07c0287d85b7bd0e574 to your computer and use it in GitHub Desktop.
import jax
import jax.numpy as jnp
from typing import Callable
class ParameterTransforms:
    """Base parameter transform specification.

    A transform maps between an unconstrained space (where samplers and
    optimizers work) and a constrained parameter space. Subclasses override
    all three methods; the base versions raise ``NotImplementedError``.
    """

    def forward(self, x):
        """Forward transform from unconstrained to constrained space."""
        raise NotImplementedError

    def inverse(self, x):
        """Inverse transform from constrained to unconstrained space."""
        raise NotImplementedError

    def log_jacobian(self, x):
        """Compute log Jacobian adjustment for forward transform."""
        raise NotImplementedError


# Singular alias: subclasses in this module were written against
# ``ParameterTransform``, which was otherwise undefined (NameError at import).
ParameterTransform = ParameterTransforms
class PositiveTransform(ParameterTransforms):
    """Positive constraint using the softplus transformation.

    forward:      y = softplus(x) = log(1 + exp(x)), mapping the reals to (0, inf).
    inverse:      x = log(exp(y) - 1), evaluated in an overflow-safe form.
    log_jacobian: log(dy/dx) = log(sigmoid(x)) = -softplus(-x), elementwise.
    """

    def forward(self, x):
        """Map unconstrained ``x`` to a positive value via softplus."""
        return jax.nn.softplus(x)

    def inverse(self, x):
        """Map positive ``x`` back to unconstrained space (inverse softplus)."""
        # log(exp(x) - 1) == x + log(1 - exp(-x)). The right-hand form never
        # evaluates exp(x), so it does not overflow to inf for large x, and
        # expm1 keeps precision for x near 0 where exp(x) - 1 cancels.
        return x + jnp.log(-jnp.expm1(-x))

    def log_jacobian(self, x):
        """Elementwise log-derivative of ``forward`` at unconstrained ``x``."""
        # d softplus(x)/dx = sigmoid(x) and log(sigmoid(x)) = -softplus(-x).
        return -jax.nn.softplus(-x)
class Model:
    """Base model with parameter constraint methods.

    Subclasses declare per-parameter transforms in ``get_parameter_constraints``
    (a mapping from parameter name to an object exposing ``forward``,
    ``inverse`` and ``log_jacobian``) and implement ``log_density``.
    """

    @classmethod
    def get_parameter_constraints(cls):
        """Return parameter constraints for the model.

        Returns:
            dict mapping a parameter name to its transform. The base model
            declares no constrained parameters.
        """
        return {}

    @staticmethod
    def _path_entry_name(entry):
        """Return the plain key/attribute name/index held by a tree path entry."""
        # tree_map_with_path yields key objects (DictKey.key, GetAttrKey.name,
        # SequenceKey.idx), not bare strings, so unwrap them before looking
        # the name up in the constraint dict.
        for attr in ("key", "name", "idx"):
            if hasattr(entry, attr):
                return getattr(entry, attr)
        return entry

    @classmethod
    def _apply_tree_transforms(cls, generic_params, apply_fn):
        """Apply ``apply_fn(name, leaf)`` to every leaf of a parameter pytree.

        Args:
            generic_params: Pytree of parameters.
            apply_fn: Callable receiving the leaf's name and the leaf value. The
                name is the last entry of the leaf's tree path, unwrapped to a
                plain key. It is ``None`` when the path is empty (bare leaf).

        Returns:
            Pytree with the structure of ``generic_params`` holding the results.
        """
        def process_leaf_fn(path, leaf):
            name = cls._path_entry_name(path[-1]) if path else None
            return apply_fn(name, leaf)

        return jax.tree_util.tree_map_with_path(process_leaf_fn, generic_params)

    @classmethod
    def constrain_parameters(cls, params_unc):
        """Apply parameter constraints and compute log Jacobian.

        Args:
            params_unc: Pytree of unconstrained parameters.

        Returns:
            Tuple ``(constrained_params, log_jacobian)``:
            - ``constrained_params`` has the structure of ``params_unc``.
            - ``log_jacobian`` is a scalar: the sum over all constrained leaves
              of their elementwise log-Jacobians. Unconstrained leaves
              contribute 0.
        """
        constraint_map = cls.get_parameter_constraints()

        def constrain_leaf(key, leaf):
            """Apply the leaf's forward transform, if it has one."""
            constraint = constraint_map.get(key)
            return constraint.forward(leaf) if constraint is not None else leaf

        def compute_leaf_jacobian(key, leaf):
            """Return the leaf's log Jacobian, or 0.0 if it is unconstrained."""
            constraint = constraint_map.get(key)
            return constraint.log_jacobian(leaf) if constraint is not None else 0.0

        constrained_params = cls._apply_tree_transforms(params_unc, constrain_leaf)
        log_jacobian_tree = cls._apply_tree_transforms(params_unc, compute_leaf_jacobian)
        # Leaves can have different shapes, so reduce each one before adding;
        # jnp.sum over a ragged list of arrays cannot be stacked and fails.
        log_jacobian = sum(
            (jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(log_jacobian_tree)),
            jnp.zeros(()),
        )
        return constrained_params, log_jacobian

    @classmethod
    def unconstrain_parameters(cls, params):
        """Remove parameter constraints by applying each leaf's inverse transform.

        Args:
            params: Pytree of constrained parameters.

        Returns:
            Pytree of unconstrained parameters with the same structure.
        """
        constraint_map = cls.get_parameter_constraints()

        def unconstrain_leaf(key, leaf):
            constraint = constraint_map.get(key)
            return constraint.inverse(leaf) if constraint is not None else leaf

        return cls._apply_tree_transforms(params, unconstrain_leaf)

    @classmethod
    def log_density_unconstrained(cls, data, params_unc):
        """Compute log density in unconstrained parameter space.

        Equals ``log_density(data, constrained_params)`` plus the log Jacobian
        of the constraining transform.
        """
        params, log_jac = cls.constrain_parameters(params_unc)
        log_density = cls.log_density(data, params)
        return log_density + log_jac

    @classmethod
    def log_density(cls, data, params):
        """Compute log density in constrained parameter space."""
        raise NotImplementedError

    @classmethod
    def num_unconstrained_parameters(cls, data):
        """Return the number (or shape) of unconstrained parameters."""
        raise NotImplementedError("Subclasses should implement this method.")

    @classmethod
    def generated_quantities(cls, data):
        """Return derived quantities; the base model produces none."""
        return {}

    @classmethod
    def initial_draw(cls, data, rng):
        """Draw an initial unconstrained point from a standard normal.

        Args:
            data: Forwarded to ``num_unconstrained_parameters``.
            rng: JAX PRNG key.

        Returns:
            Standard-normal array.
            - If ``num_unconstrained_parameters`` returns an int ``D``, the
              shape is ``(D,)``.
            - If it returns a tuple or list, that sequence is used as the shape.

        NOTE(review): this returns a flat array while ``constrain_parameters``
        matches constraints by pytree key names; confirm how callers map one
        onto the other.
        """
        D = cls.num_unconstrained_parameters(data)
        # jax.random.normal requires a shape sequence, not a bare int.
        shape = tuple(D) if isinstance(D, (tuple, list)) else (int(D),)
        return jax.random.normal(rng, shape=shape)
class GaussianModel(Model):
    """Gaussian distribution model with a positivity-constrained ``scale``."""

    @classmethod
    def get_parameter_constraints(cls):
        """Return Gaussian model parameter constraints.

        ``scale`` is mapped through ``PositiveTransform``. ``mean`` has no
        constraint.
        """
        return {"scale": PositiveTransform()}

    @classmethod
    def log_density(cls, data, params):
        """Compute log density for Gaussian distribution.

        Returns ``-0.5 * sum((mean / scale)**2 + log(2*pi*scale**2))``, i.e.
        the sum over elements of log N(mean_i | 0, scale_i**2).

        NOTE(review): ``data`` is not used, so this is the density of ``mean``
        under a zero-centred Gaussian rather than a likelihood of observed
        data. Confirm that is intended.
        """
        mean = params['mean']
        scale = params['scale']
        return -0.5 * jnp.sum((mean / scale)**2 + jnp.log(2 * jnp.pi * scale**2))
def main():
    """Evaluate the example Gaussian model at a fixed unconstrained point.

    Returns:
        The value of ``GaussianModel.log_density_unconstrained`` with ``data``
        passed as ``None``.
    """
    # Raw values in unconstrained space; the model applies its declared
    # constraints (if any) before evaluating the density.
    unconstrained = {
        'mean': jnp.array([0.1, -0.2]),
        'scale': jnp.array([1.0, 2.0]),
    }
    return GaussianModel.log_density_unconstrained(None, unconstrained)
@giladturok
Copy link
Author

giladturok commented Nov 20, 2024

MVP for densejax, a Python package that supports inference for arbitrary densities.

This implementation uses JAX's tree utilities to handle parameter constraints cleanly. Each parameter constraint (e.g. positive, lower-bounded) is a transform that implements `forward`, `inverse`, and `log_jacobian` methods.

Details

Parameter constraints are recursively applied to the Pytree of parameters with the _apply_tree_transforms method. The _apply_tree_transforms is used in the constrain_parameters and unconstrain_parameters methods.

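At its core, _apply_tree_transforms visits the path and value of each leaf in a Pytree of parameters. The last entry of the path identifies the variable. For a dict this entry is a key object whose `.key` attribute holds the name (e.g. scale), so it must be unwrapped before it can be compared with a string. The value is a JAX array (e.g. [1.2]).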

The key idea ( 💡 ) is to check whether a parameter's path (name) matches any of the constraints specified in get_parameter_constraints.

More precisely: we check whether a parameter name (e.g. scale) appears in the get_parameter_constraints dictionary (e.g. {"scale": PositiveTransform()}). If a match is found, we apply the specified method (e.g. forward, inverse, or log_jacobian) of the given transform (e.g. PositiveTransform).

This approach is advantageous because (1) the code is compact yet readable and (2) we do not need to track the constraint a parameter may have (and how to apply it) beyond specifying the get_parameter_constraints dictionary.

Possible improvements:

  1. Make get_parameter_constraints a class attribute or a static method? Not sure what the best approach is here.
  2. Clarify terminology for transforms vs constraints

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment