Created
September 20, 2024 15:03
Notes on information theory
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The Fisher information matrix is the expectation value of the variance of the log-likelihood function, which is often just the expectation of the Hessian of the log-likelihood function (i.e. the second derivatives with respect to the model parameters). That's just a boring definition! For a data set $D$ and model parameters $\\theta$, \n", | |
"$$\n", | |
"\\mathcal{F}_{jk} = -\\mathbb{E}\\left[\\frac{\\partial^2}{\\partial\\theta_j \\, \\partial\\theta_k}\\ln\\mathcal{L}(D \\,|\\, \\theta)\\right]\n", | |
"$$\n", | |
"where $\\mathbb{E}\\left[x\\right]$ is the expectation value of $x$ evaluated at the true parameter values.\n", | |
"Conceptually, it is a matrix that tells us how much information the data $D$ contain about a given parameter, or how much covariance there is in our knowledge of a combination of parameters. \n", | |
"\n", | |
"I think of it as a result of a Taylor expansion of the likelihood function around the maximum likelihood parameter values. It is related to the curvature of the likelihood function around its maximum (because second derivatives are involves). It is not a fundamental quantity that is universally interpretable or meaningful. For example, if a model is a bad representation of data, or if the likelihood surface is extremely structured or multi-modal (i.e. cases where a Taylor expansion to low order is a bad representation of a function), then the Fisher information does not really summarize any useful properties of the data or model. \n", | |
"\n", | |
"# Example: Fitting a line to data\n", | |
"\n", | |
"As an example of how to compute this thing, let's use the ever-loved example of fitting a straight line to data $y$ at positions $x$ with uncertainties $\\sigma_y$ only on the data $y$. Let's generate some fake data to work with:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib as mpl\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ndata = 8\n", | |
"rng = np.random.default_rng(42)\n", | |
"true_pars = rng.normal(0, 2.0, size=2)\n", | |
"\n", | |
"x = np.sort(rng.uniform(0, 10, ndata))\n", | |
"yerr = 10 ** rng.uniform(-1, 0, ndata)\n", | |
"y = rng.normal(true_pars[0] * x + true_pars[1], yerr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[Text(0.5, 0, 'x'), Text(0, 0.5, 'y')]" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 432x432 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"image/png": { | |
"height": 440, | |
"width": 440 | |
} | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"fig, ax = plt.subplots()\n", | |
"ax.errorbar(x, y, yerr, marker=\"o\", ls=\"none\")\n", | |
"ax.set(xlabel=\"x\", ylabel=\"y\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"In this case, our model for the data is:\n", | |
"$$\n", | |
"f(x \\,;\\, \\theta) = a \\, x + b \\\\\n", | |
"\\theta = (a, b)\n", | |
"$$\n", | |
"and we will assume that the data are generated from some process such that the uncertainties $\\sigma_y$ are Gaussian (and I'll drop the subscript $y$ because I'm lazy). Our likelihood function for a single data point $(x_i, y_i, \\sigma_i)$ is then:\n", | |
"$$\n", | |
"L(y_i \\,;\\, \\theta) = \\mathcal{N}(y_i \\,|\\, f(x \\,;\\, \\theta), \\sigma_i)\n", | |
"$$\n", | |
"($x_i$ and $\\sigma_i$ aren't really data because we assume we know them perfectly, but they are associated with a given data point $y$ as metadata). \n", | |
"$\\mathcal{N}(w \\,|\\, \\mu, \\sigma)$ is the Normal distribution over $w$ with mean $\\mu$ and standard deviation $\\sigma$.\n", | |
"\n", | |
"In this model, the data points are independent, so the total likelihood $\\mathcal{L}$ is just the product of the likelihoods for each individual data point:\n", | |
"$$\n", | |
"\\begin{align}\n", | |
"\\mathcal{L} &= \\prod_i L(y_i \\,;\\, \\theta)\\\\ \n", | |
"&= \\prod_i \\mathcal{N}(y_i \\,|\\, f_i, \\sigma_i)\n", | |
"\\end{align}\n", | |
"$$\n", | |
"where, as a reminder, $f_i$ is the model predicted value of $y$ at a given location $x_i$ -- you could think of it as the model's prediction for the true value of $y_i$.\n", | |
"In the log, this product is just a sum:\n", | |
"$$\n", | |
"\\ln\\mathcal{L} = \\sum_i \\ln\\mathcal{N}(y_i \\,|\\, f_i, \\sigma_i)\n", | |
"$$\n", | |
"\n", | |
"So now we need to take second derivatives of this thing. The derivative here can be moved inside of the sum, so \n", | |
"$$\n", | |
"\\frac{\\partial^2\\ln\\mathcal{L}}{\\partial \\theta_j \\partial \\theta_k} = \\sum_i \\frac{\\partial^2}{\\partial \\theta_j \\partial \\theta_k}\\ln\\mathcal{N}(y_i \\,|\\, f_i, \\sigma_i)\n", | |
"$$\n", | |
"\n", | |
"To see how this works in this one-dimensional case, let's expand the expression for the log-Normal distribution:\n", | |
"$$\n", | |
"\\ln\\mathcal{N}(y_i \\,|\\, f_i, \\sigma_i) = -\\frac{1}{2} \\left[ \\frac{(y_i - f_i)^2}{\\sigma_i^2} + \\ln(2\\pi\\,\\sigma_i^2) \\right]\n", | |
"$$\n", | |
"(this may seem like a weird way to write it but trust me). Now, we need to take derivatives of this expression with respect to the model parameters $\\theta$. One thing to note is that the second term $\\ln(2\\pi\\,\\sigma^2)$ does not depend on the parameters (it is only a function of the data, i.e. the uncertainties), so the derivatives of this term are zero. The first term is a quadratic function of $y_i$ and $f_i$, but $y_i$ is data and $f_i$ is a function of the model parameters, so the only terms that survive the derivative are those that depend on $f_i$ (i.e. the model predictions):\n", | |
"$$\n", | |
"\\frac{\\partial^2}{\\partial \\theta_j \\partial \\theta_k} \n", | |
" \\ln\\mathcal{N}(y_i \\,|\\, f_i, \\sigma_i) = \n", | |
" \\frac{1}{\\sigma_i^2} \\frac{\\mathrm{d} f_i}{\\mathrm{d}\\theta_j} \n", | |
" \\frac{\\mathrm{d} f_i}{\\mathrm{d}\\theta_k}\n", | |
"$$" | |
] | |
}, | |
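{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As a sketch of the intermediate steps (writing $r_i = y_i - f_i$ for the residual, a symbol introduced here only for convenience), the chain rule gives the first and second derivatives of the log-likelihood of one data point:\n", | |
"$$\n", | |
"\\begin{align}\n", | |
"\\frac{\\partial}{\\partial \\theta_k} \\ln\\mathcal{N}(y_i \\,|\\, f_i, \\sigma_i) &= \\frac{r_i}{\\sigma_i^2} \\frac{\\partial f_i}{\\partial \\theta_k} \\\\\n", | |
"\\frac{\\partial^2}{\\partial \\theta_j \\partial \\theta_k} \\ln\\mathcal{N}(y_i \\,|\\, f_i, \\sigma_i) &= -\\frac{1}{\\sigma_i^2} \\frac{\\partial f_i}{\\partial \\theta_j} \\frac{\\partial f_i}{\\partial \\theta_k} + \\frac{r_i}{\\sigma_i^2} \\frac{\\partial^2 f_i}{\\partial \\theta_j \\partial \\theta_k}\n", | |
"\\end{align}\n", | |
"$$\n", | |
"For the straight line, $\\partial f_i / \\partial a = x_i$ and $\\partial f_i / \\partial b = 1$, and every second derivative of $f_i$ is zero, so only the first term remains. Its negative is the contribution of one data point to $\\mathcal{F}$." | |
] | |
}, | |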
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Our model $f$ only has two parameters, so we can write this matrix out explicitly, for one data point:\n", | |
"$$\n", | |
"\\begin{align}\n", | |
"&= \\frac{1}{\\sigma_i^2} \\, \\begin{pmatrix}\n", | |
"\\frac{\\mathrm{d} f_i}{\\mathrm{d}a} \\frac{\\mathrm{d} f_i}{\\mathrm{d}a} & \\frac{\\mathrm{d} f_i}{\\mathrm{d}a} \\frac{\\mathrm{d} f_i}{\\mathrm{d}b} \\\\\n", | |
"\\frac{\\mathrm{d} f_i}{\\mathrm{d}b} \\frac{\\mathrm{d} f_i}{\\mathrm{d}a} & \\frac{\\mathrm{d} f_i}{\\mathrm{d}b} \\frac{\\mathrm{d} f_i}{\\mathrm{d}b}\n", | |
"\\end{pmatrix} \\\\\n", | |
"&= \\frac{1}{\\sigma_i^2} \\, \\begin{pmatrix}\n", | |
"x_i^2 & x_i \\\\\n", | |
"x_i & 1 \n", | |
"\\end{pmatrix}\n", | |
"\\end{align}\n", | |
"$$\n", | |
"\n", | |
"And so the Fisher information matrix is just the sum of these matrices over all data points:\n", | |
"$$\n", | |
"\\mathcal{F} = \\sum_i \\frac{1}{\\sigma_i^2} \\, \\begin{pmatrix}\n", | |
"x_i^2 & x_i \\\\\n", | |
"x_i & 1\n", | |
"\\end{pmatrix}\n", | |
"$$" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This is an interesting (but simple) result: the Fisher information here does not even contain the model parameter values! That is because our model $f$ is _linear_ in the model parameters and our likelihood is Gaussian, so the curvature is a constant. Let's look at how to compute this given our toy data above." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(2, 2, 8)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# This is the matrix component of the Fisher information from above\n", | |
"M = np.array([[x**2, x], [x, np.ones_like(x)]])\n", | |
"M.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(2, 2)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"F = np.sum(1 / yerr**2 * M, axis=-1)\n", | |
"F.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[10822.60583843, 1226.64589812],\n", | |
" [ 1226.64589812, 157.35632654]])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"F" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"OK so we've computed the Fisher information! ...So what? It turns out that the inverse of the Fisher information matrix is a useful matrix itself: it is the covariance matrix of the maximum likelihood parameter estimates. This is a useful property because it tells us the expected precision of our parameter estimates given the data we have. If the Fisher information is large, then the parameter uncertainties (i.e. the inverses) are small, so the parameters are well-constrained by the data. Let's look at the inverse of the Fisher information matrix for our toy data above:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0.00079335, -0.00618439],\n", | |
" [-0.00618439, 0.05456446]])" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Finv = np.linalg.inv(F)\n", | |
"Finv" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"If we ignore the covariances (i.e. the off-diagonal terms), this tells us the expected errors we would get on the parameter values if we fit the data with a straight line. The diagonal terms are the variances of the parameter estimates, and the square roots of these are the expected standard deviations of the parameter estimates:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0.02816638, 0.23359036])" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.sqrt(np.diag(Finv))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's now do a maximum likelihood fit using least-squares algebra and compare the results to the Fisher information matrix. With least-squares:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Design matrix:\n", | |
"A = np.stack((x, np.ones_like(x)), axis=-1)\n", | |
"\n", | |
"# Data covariance matrix:\n", | |
"Cinv = np.diag(1 / yerr**2)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Here is the standard matrix algebra for computing the maximum likelihood parameter estimates for a linear model (`mle_pars` below). In this case, the covariance matrix of the parameter estimates, $P$, is the inverse of the Fisher information matrix! (BTW, as an aside, it's generally bad to construct the data covariance matrix this way and to do the matrix inverse explicitly here — a better way is to use, e.g., `numpy.linalg.leastsq` or `numpy.linalg.solve`)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0.58499692, -1.87999557])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"P = np.linalg.inv(A.T @ Cinv @ A)\n", | |
"mle_pars = P @ A.T @ Cinv @ y\n", | |
"mle_pars" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0.02816638, 0.23359036])" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Note that these are the same values as sqrt(diag(Finv))!\n", | |
"np.sqrt(np.diag(P))" | |
] | |
}, | |
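{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As a minimal sketch of the more stable route mentioned in the aside above (the names `Aw`, `yw`, `pars_lstsq`, `pars_solve`, and `F_ls` are introduced here for illustration, and the cell assumes the `A`, `y`, `yerr`, `F`, and `mle_pars` arrays defined earlier), we can divide each row of the design matrix and each data value by $\\sigma_i$ and solve the resulting ordinary least-squares problem. No dense inverse covariance matrix is built and no matrix is inverted explicitly, and $A_w^\\top A_w$ is the same matrix as $\\mathcal{F}$:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Whiten the problem: divide each row of A and each y value by its uncertainty\n", | |
"Aw = A / yerr[:, None]\n", | |
"yw = y / yerr\n", | |
"\n", | |
"# Weighted least squares via lstsq, without forming an explicit inverse\n", | |
"pars_lstsq, *_ = np.linalg.lstsq(Aw, yw, rcond=None)\n", | |
"\n", | |
"# Equivalent: solve the normal equations directly with np.linalg.solve\n", | |
"F_ls = Aw.T @ Aw\n", | |
"pars_solve = np.linalg.solve(F_ls, Aw.T @ yw)\n", | |
"\n", | |
"# Compare against the quantities computed in the earlier cells\n", | |
"np.allclose(pars_lstsq, mle_pars), np.allclose(pars_solve, mle_pars), np.allclose(F_ls, F)" | |
] | |
}, | |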
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# More generally: Multi-dimensional and nonlinear (but still Gaussian) models\n", | |
"\n", | |
"The example above is a special case where we can solve for the Fisher information analytically and our model is linear. We are often in more complex situations in astronomy. One example is our project (where our likelihood function involves orbit integrations; here the model is nonlinear and the data are multi-dimensional). Another example is solving the Kepler problem by fitting a two-body orbit to, e.g., radial velocity data (model is nonlinear but data are still one dimensional). \n", | |
"\n", | |
"In general, our data may be a vector $\\boldsymbol{y}$ and our model may be a function $\\boldsymbol{f}(\\boldsymbol{x} \\,;\\, \\boldsymbol{\\theta})$ where $\\boldsymbol{x}$ is a vector of independent variables and $\\boldsymbol{\\theta}$ is a vector of model parameters. If the likelihood is non-Gaussian, there's not much we can do analytically. But if the likelihood is Gaussian, even with a nonlinear model, we can simplify the Fisher information matrix calculation by using the chain rule for derivatives.\n", | |
"\n", | |
"In this case, our likelihood for one data point $\\boldsymbol{y}_i$ (note the vector $y$) is:\n", | |
"$$\n", | |
"L(\\boldsymbol{y}_i \\,;\\, \\boldsymbol{\\theta}) = \\mathcal{N}(\\boldsymbol{y}_i \\,|\\, \\boldsymbol{f}_i, \\boldsymbol{\\Sigma}_i)\n", | |
"$$\n", | |
"where now $\\boldsymbol{\\Sigma}_i$ is the covariance matrix of the data point $\\boldsymbol{y}_i$, and $\\boldsymbol{f}_i$ is a vector of model-predicted \"true\" values for the data. We can expand the log-Normal expression in the same way as we did for the one-dimensional case above, but now it involves some matrix algebra:\n", | |
"$$\n", | |
"\\ln L(\\boldsymbol{y}_i \\,;\\, \\boldsymbol{\\theta}) = \n", | |
" -\\frac{1}{2} \\left[ (\\boldsymbol{y}_i - \\boldsymbol{f}_i)^\\top \\boldsymbol{\\Sigma}_i^{-1} (\\boldsymbol{y}_i - \\boldsymbol{f}_i) + \\ln\\det\\boldsymbol{\\Sigma}_i + N \\ln(2\\pi) \\right]\n", | |
"$$\n", | |
"where $\\det$ is the determinant of a matrix and $N$ is the number of data points. We again want derivatives of this expression with respect to model parameters $\\boldsymbol{\\theta}$, so once again we can ignore the second and third terms because they are independent of the model parameters. The first term is a quadratic form in the data and model predictions, so the only terms that survive the derivative are those that depend on the model predictions:\n", | |
"$$\n", | |
"\\frac{\\partial^2}{\\partial \\theta_j \\partial \\theta_k} \n", | |
" \\ln L(\\boldsymbol{y}_i \\,;\\, \\boldsymbol{\\theta}) = \n", | |
" \\frac{1}{2} \\left[ \\frac{\\partial \\boldsymbol{f}_i}{\\partial \\theta_j} \\boldsymbol{\\Sigma}_i^{-1} \\frac{\\partial \\boldsymbol{f}_i}{\\partial \\theta_k} \\right]\n", | |
"$$\n", | |
"This is the quantity you might see in, for example, [Bonaca & Hogg 2018](https://ui.adsabs.harvard.edu/abs/2018ApJ...867..101B/abstract) or the draft by Sophia Lilleengen of our upcoming paper on stream information theory.\n", | |
"\n", | |
"Let's now do an example of this for a two-dimensional model with some nonlinear model parameters. For example, a sinusoidal model of two variables with unknown amplitudes and frequency:\n", | |
"$$\n", | |
"\\boldsymbol{f}(x \\,;\\, \\boldsymbol{\\theta}) = \n", | |
" \\begin{pmatrix}\n", | |
" a \\, \\sin(2\\pi \\, \\nu \\, x) \\\\\n", | |
" b \\, \\cos(2\\pi \\, \\nu \\, x)\n", | |
" \\end{pmatrix}\n", | |
"$$\n", | |
"\n", | |
"We again generate some toy data to play with (and yes I'm being naughty by using the same variable names as above):" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((16, 2, 2), (16, 2))" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ndata = 16\n", | |
"rng = np.random.default_rng(8675309)\n", | |
"true_pars = np.abs(rng.normal(0, 2.0, size=3))\n", | |
"\n", | |
"x = np.sort(rng.uniform(0, 2 / true_pars[2], ndata))\n", | |
"yerr = 10 ** rng.uniform(-2, -0.5, (ndata, 2)) * true_pars[:2]\n", | |
"\n", | |
"y = np.stack(\n", | |
" (\n", | |
" rng.normal(true_pars[0] * np.sin(2 * np.pi * true_pars[2] * x), yerr[:, 0]),\n", | |
" rng.normal(true_pars[1] * np.cos(2 * np.pi * true_pars[2] * x), yerr[:, 1]),\n", | |
" ),\n", | |
" axis=-1,\n", | |
")\n", | |
"\n", | |
"# The inverse covariance matrix of the data (which here is diagonal):\n", | |
"Cinv = np.stack([np.diag(1 / yerr[i] ** 2) for i in range(ndata)], axis=0)\n", | |
"Cinv.shape, y.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 720x360 with 2 Axes>" | |
] | |
}, | |
"metadata": { | |
"image/png": { | |
"height": 368, | |
"width": 728 | |
} | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", | |
"\n", | |
"for i in range(2):\n", | |
" axes[i].errorbar(x, y[..., i], yerr[..., i], marker=\"o\", ls=\"none\")\n", | |
" axes[i].set(xlabel=\"x\", ylabel=f\"$y_{i+1}$\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now let's implement this model in JAX so we can make use of the auto-differentiation capabilities. We'll again compute the Fisher information matrix and compare it to the maximum likelihood parameter estimates." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"\n", | |
"jax.config.update(\"jax_enable_x64\", True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@jax.jit\n", | |
"def model_f(x, pars):\n", | |
" return jnp.stack(\n", | |
" (\n", | |
" pars[0] * jnp.sin(2 * jnp.pi * pars[2] * x),\n", | |
" pars[1] * jnp.cos(2 * jnp.pi * pars[2] * x),\n", | |
" ),\n", | |
" axis=-1,\n", | |
" )\n", | |
"\n", | |
"\n", | |
"@jax.jit\n", | |
"def ln_likelihood_one_star(pars, x, y, Cinv):\n", | |
" y_model = model_f(x, pars)\n", | |
" dy = y - y_model\n", | |
" return -0.5 * dy.T @ Cinv @ dy\n", | |
"\n", | |
"\n", | |
"ln_likelihood_helper = jax.vmap(ln_likelihood_one_star, in_axes=(None, 0, 0, 0))\n", | |
"\n", | |
"\n", | |
"@jax.jit\n", | |
"def ln_likelihood(pars, x, y, Cinv):\n", | |
" return jnp.sum(ln_likelihood_helper(pars, x, y, Cinv))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's try evaluating our log-likelihood function at the true parameters:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array(-11.87854157, dtype=float64)" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ln_likelihood(true_pars, x, y, Cinv)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array([[ 4.25782094e+05, -0.00000000e+00, -7.83052999e+06],\n", | |
" [-0.00000000e+00, 6.54218186e+03, 6.83218150e+05],\n", | |
" [-7.83052999e+06, 6.83218150e+05, 6.17018283e+08]], dtype=float64)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"F = -jax.hessian(ln_likelihood)(true_pars, x, y, Cinv)\n", | |
"F" | |
] | |
}, | |
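{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The Hessian above is evaluated on one particular noisy data set, so it still includes the term that is proportional to the residuals. As a sketch (the names `jac_f`, `J`, and `F_expected` are introduced here for illustration), the expected Fisher matrix for a Gaussian likelihood can also be assembled from first derivatives of the model alone, as $\\sum_i J_i^\\top \\boldsymbol{\\Sigma}_i^{-1} J_i$ with $J_i = \\partial \\boldsymbol{f}_i / \\partial \\boldsymbol{\\theta}$. It should be similar to, but not identical to, `F`:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Jacobian of the model with respect to the parameters at a single x:\n", | |
"# for scalar x, model_f returns shape (2,), so this has shape (2, 3)\n", | |
"jac_f = jax.jacfwd(model_f, argnums=1)\n", | |
"\n", | |
"# Evaluate at every data location: J has shape (ndata, 2, 3)\n", | |
"J = jax.vmap(jac_f, in_axes=(0, None))(x, jnp.asarray(true_pars))\n", | |
"\n", | |
"# Sum over data points of J_i^T Cinv_i J_i, giving a (3, 3) matrix\n", | |
"F_expected = jnp.einsum(\"nij,nik,nkl->jl\", J, Cinv, J)\n", | |
"F_expected" | |
] | |
}, | |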
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"And once again, we can turn this into a prediction for the uncertainty on our model parameters:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([1.78625210e-03, 1.34166769e-02, 4.98967477e-05])" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Finv = np.linalg.inv(F)\n", | |
"Fisher_param_uncertainties = np.sqrt(np.diag(Finv))\n", | |
"Fisher_param_uncertainties" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"In this case, because our model is nonlinear, we can't use the least-squares expressions above to analytically compute the maximum likelihood parameter estimates and precisions. Instead, we will specify a prior and use MCMC to generate samples from the posterior distribution. We can then compare the covariance matrix of the samples to the inverse of the Fisher information matrix (or just the diagonal terms, the parameter uncertainties)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@jax.jit\n", | |
"def ln_prior(pars):\n", | |
" lp = 0.0\n", | |
"\n", | |
" # Gaussian priors on the amplitudes with mean=0, stddev=10:\n", | |
" lp += jax.scipy.stats.norm.logpdf(pars[0], 0, 10)\n", | |
" lp += jax.scipy.stats.norm.logpdf(pars[1], 0, 10)\n", | |
"\n", | |
" # Uniform prior on the frequency:\n", | |
" lp += jax.scipy.stats.uniform.logpdf(pars[2], 0, 1)\n", | |
" lp = jnp.where((pars[2] > 0) & (pars[2] < 1), lp, -jnp.inf)\n", | |
"\n", | |
" return lp\n", | |
"\n", | |
"\n", | |
"@jax.jit\n", | |
"def ln_posterior(pars, x, y, Cinv):\n", | |
" return ln_likelihood(pars, x, y, Cinv) + ln_prior(pars)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import blackjax\n", | |
"\n", | |
"rng_key, warmup_key, sample_key = jax.random.split(jax.random.PRNGKey(42), 3)\n", | |
"init_pars = true_pars\n", | |
"func = lambda pars: ln_posterior(pars, x, y, Cinv)\n", | |
"\n", | |
"warmup = blackjax.window_adaptation(blackjax.nuts, func)\n", | |
"(state, parameters), _ = warmup.run(warmup_key, init_pars, num_steps=1000)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def inference_loop(rng_key, kernel, initial_state, num_samples):\n", | |
" @jax.jit\n", | |
" def one_step(state, rng_key):\n", | |
" state, _ = kernel(rng_key, state)\n", | |
" return state, state\n", | |
"\n", | |
" keys = jax.random.split(rng_key, num_samples)\n", | |
" _, states = jax.lax.scan(one_step, initial_state, keys)\n", | |
"\n", | |
" return states" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"kernel = blackjax.nuts(func, **parameters).step\n", | |
"states = inference_loop(sample_key, kernel, state, 10_000)\n", | |
"\n", | |
"mcmc_samples = states.position" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"First, let's look at the parameter uncertainties as estimated by MCMC:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array([1.80400079e-03, 1.35137301e-02, 5.07659010e-05], dtype=float64)" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.std(mcmc_samples, axis=0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's compare that to our Fisher approach above:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([1.78625210e-03, 1.34166769e-02, 4.98967477e-05])" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Fisher_param_uncertainties" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"These are very similar, as we expect!\n", | |
"\n", | |
"We can also look at the full covariance matrix:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 3.25474432e-06, -4.52384312e-06, 4.79426144e-08],\n", | |
" [-4.52384312e-06, 1.82639165e-04, -2.74545006e-07],\n", | |
" [ 4.79426144e-08, -2.74545006e-07, 2.57743445e-09]])" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.cov(mcmc_samples.T)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 3.19069658e-06, -4.78172971e-06, 4.57876380e-08],\n", | |
" [-4.78172971e-06, 1.80007218e-04, -2.60004737e-07],\n", | |
" [ 4.57876380e-08, -2.60004737e-07, 2.48968543e-09]])" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Finv" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "py312", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |