Last active
November 20, 2024 18:55
-
-
Save giladturok/5e579a57502cc07c0287d85b7bd0e574 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import jax | |
| import jax.numpy as jnp | |
| from typing import Callable | |
class ParameterTransform:
    """Base specification for a parameter transform.

    A transform maps values between an unconstrained real space and a
    constrained space. Subclasses implement all three methods below.
    """

    def forward(self, x):
        """Forward transform from unconstrained to constrained space."""
        raise NotImplementedError

    def inverse(self, x):
        """Inverse transform from constrained to unconstrained space."""
        raise NotImplementedError

    def log_jacobian(self, x):
        """Log Jacobian adjustment of ``forward`` evaluated at unconstrained ``x``."""
        raise NotImplementedError


# Backward-compatible alias: the base class was originally named
# ``ParameterTransforms`` while subclasses referenced ``ParameterTransform``.
ParameterTransforms = ParameterTransform


class PositiveTransform(ParameterTransform):
    """Positive constraint using the softplus transform ``log(1 + exp(x))``."""

    def forward(self, x):
        """Map unconstrained ``x`` to a positive value via softplus."""
        return jax.nn.softplus(x)

    def inverse(self, x):
        """Invert softplus, i.e. compute ``log(exp(x) - 1)`` for positive ``x``.

        Evaluated as ``x + log(1 - exp(-x))`` (using ``expm1``) so that large
        ``x`` does not overflow ``exp`` and small ``x`` keeps precision.
        """
        return x + jnp.log(-jnp.expm1(-x))

    def log_jacobian(self, x):
        """Elementwise log-derivative of softplus: ``log(sigmoid(x)) = -softplus(-x)``."""
        return -jax.nn.softplus(-x)
class Model:
    """Base model with parameter constraint methods.

    Subclasses declare constraints by overriding ``get_parameter_constraints``
    to return a mapping from parameter name to a transform object exposing
    ``forward``, ``inverse`` and ``log_jacobian``.
    """

    @classmethod
    def get_parameter_constraints(cls):
        """Return a ``{parameter_name: transform}`` mapping (empty by default)."""
        return {}

    @staticmethod
    def _path_entry_name(entry):
        """Return the plain key / attribute name / index wrapped by a path entry."""
        # JAX key-path entries are wrapper objects (DictKey.key, GetAttrKey.name,
        # SequenceKey.idx, FlattenedIndexKey.key); constraints are looked up by
        # the plain value, so unwrap it here.
        for attr in ("key", "name", "idx"):
            if hasattr(entry, attr):
                return getattr(entry, attr)
        return entry

    @classmethod
    def _apply_tree_transforms(cls, generic_params, apply_fn):
        """Apply ``apply_fn(name, leaf)`` to every leaf of ``generic_params``.

        ``name`` is the innermost key of the leaf's path (e.g. ``"scale"`` for
        ``{"scale": x}``), or ``None`` when the pytree is a single bare leaf.
        """
        def process_leaf_fn(path, leaf):
            name = cls._path_entry_name(path[-1]) if path else None
            return apply_fn(name, leaf)

        return jax.tree_util.tree_map_with_path(process_leaf_fn, generic_params)

    @classmethod
    def constrain_parameters(cls, params_unc):
        """Apply parameter constraints and compute the total log Jacobian.

        Args:
            params_unc: pytree of unconstrained parameters.

        Returns:
            ``(constrained_params, log_jacobian)`` where ``log_jacobian`` is the
            sum over all constrained leaves (and all their elements) of each
            transform's ``log_jacobian``.
        """
        constraint_map = cls.get_parameter_constraints()

        def constrain_leaf(key, leaf):
            """Apply the constraint registered for ``key``, if any."""
            constraint = constraint_map.get(key)
            return leaf if constraint is None else constraint.forward(leaf)

        def compute_leaf_jacobian(key, leaf):
            """Log Jacobian of the constraint for ``key`` (zero if unconstrained)."""
            constraint = constraint_map.get(key)
            if constraint is None:
                return jnp.zeros(())
            return constraint.log_jacobian(leaf)

        constrained_params = cls._apply_tree_transforms(params_unc, constrain_leaf)
        log_jacobian_tree = cls._apply_tree_transforms(params_unc, compute_leaf_jacobian)
        # Leaves can have different shapes, so reduce each to a scalar first
        # instead of stacking them into one (ragged) array.
        log_jacobian = sum(
            (jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(log_jacobian_tree)),
            jnp.zeros(()),
        )
        return constrained_params, log_jacobian

    @classmethod
    def unconstrain_parameters(cls, params):
        """Map constrained ``params`` back to unconstrained space."""
        constraint_map = cls.get_parameter_constraints()

        def unconstrain_leaf(key, leaf):
            """Invert the constraint registered for ``key``, if any."""
            constraint = constraint_map.get(key)
            return leaf if constraint is None else constraint.inverse(leaf)

        return cls._apply_tree_transforms(params, unconstrain_leaf)

    @classmethod
    def log_density_unconstrained(cls, data, params_unc):
        """Compute log density in unconstrained parameter space.

        This is the constrained-space ``log_density`` plus the log Jacobian of
        the constraining transforms.
        """
        params, log_jac = cls.constrain_parameters(params_unc)
        log_density = cls.log_density(data, params)
        return log_density + log_jac

    @classmethod
    def log_density(cls, data, params):
        """Compute log density in constrained parameter space."""
        raise NotImplementedError

    @classmethod
    def num_unconstrained_parameters(cls, data):
        """Return the number of unconstrained parameters (subclass hook)."""
        raise NotImplementedError("Subclasses should implement this method.")

    @classmethod
    def generated_quantities(cls, data):
        """Return derived quantities (none by default)."""
        return { }

    @classmethod
    def initial_draw(cls, data, rng):
        """Draw a standard-normal vector of unconstrained parameters.

        Args:
            data: passed through to ``num_unconstrained_parameters``.
            rng: a JAX PRNG key.
        """
        num_params = cls.num_unconstrained_parameters(data)
        # jax.random.normal needs a shape sequence; a bare int is rejected.
        if isinstance(num_params, (tuple, list)):
            shape = tuple(num_params)
        else:
            shape = (num_params,)
        params_unconstrained = jax.random.normal(rng, shape=shape)
        return params_unconstrained
class GaussianModel(Model):
    """Gaussian distribution model whose ``scale`` is constrained positive."""

    @classmethod
    def get_parameter_constraints(cls):
        """Return Gaussian model parameter constraints (``scale`` > 0 via softplus)."""
        return {"scale": PositiveTransform()}

    @classmethod
    def log_density(cls, data, params):
        """Compute the summed Gaussian log density for ``params``.

        Evaluates ``-0.5 * sum((mean / scale)**2 + log(2*pi*scale**2))``.

        NOTE(review): ``data`` is not used; this is the log density of the
        point 0 under N(mean, scale**2), summed over elements. Confirm whether
        observations were meant to replace 0 here.
        """
        mean = params["mean"]
        scale = params["scale"]
        standardized = mean / scale
        return -0.5 * jnp.sum(standardized**2 + jnp.log(2 * jnp.pi * scale**2))
def main():
    """Evaluate ``GaussianModel``'s unconstrained log density on a fixed example.

    Returns:
        The value of ``GaussianModel.log_density_unconstrained`` (log density
        plus log-Jacobian term) for a hard-coded parameter dictionary.
    """
    example_params = {
        "mean": jnp.array([0.1, -0.2]),
        # Given in unconstrained space; the model applies its constraints.
        "scale": jnp.array([1.0, 2.0]),
    }
    return GaussianModel.log_density_unconstrained(None, example_params)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
MVP for `densejax`, a Python package that supports inference for arbitrary densities. This implementation uses JAX's tree utilities to cleanly handle parameter constraints. Each parameter constraint (e.g. positive, lower-bound) is a transform that implements `forward`, `inverse`, and `log_jacobian` methods.

Details
Parameter constraints are recursively applied to the pytree of parameters by the `_apply_tree_transforms` method, which is used in the `constrain_parameters` and `unconstrain_parameters` methods.

At its core,
`_apply_tree_transforms` traverses the path and value of each parameter in a pytree of parameters. The path tells us the name of a variable (e.g. `scale`), and its value is a JAX array (e.g. `[1.2]`).

The key idea (💡) is to check whether a parameter's path (name) matches any of the constraints specified in `get_parameter_constraints`.

More precisely, we check whether a parameter name (e.g. `scale`) is a key of the `get_parameter_constraints` dictionary (e.g. `{"scale": PositiveTransform()}`). If a match is found, we apply the specified method (e.g. `forward`, `inverse`, or `log_jacobian`) of the given transform (e.g. `PositiveTransform`).

This approach is advantageous because (1) the code is compact yet readable, and (2) we do not need to track the constraint a parameter may have (and how to apply it) beyond specifying the
`get_parameter_constraints` dictionary.

Possible improvements:
- Should `get_parameter_constraints` be a class attribute or a static method? Not sure what the best option is here...