Skip to content

Instantly share code, notes, and snippets.

@jonsedar
Last active December 3, 2024 12:23
Show Gist options
  • Save jonsedar/a2c355df5a8c888768dbec7e1fe1f7a6 to your computer and use it in GitHub Desktop.
Save jonsedar/a2c355df5a8c888768dbec7e1fe1f7a6 to your computer and use it in GitHub Desktop.
Rough & ready MRE to demo shape / dims issue seen when using MVNormal
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 900_MRE_MVN_issue\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"from pymc.testing import assert_no_rvs\n",
"RNG = np.random.default_rng(seed=42)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[3.68664592, 0.96080471],\n",
" [5.75719971, 6.9626818 ],\n",
" [0.38634088, 0.73920536],\n",
" [3.08897834, 1.98130835],\n",
" [2.67299306, 1.15830308]])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N = 50\n",
"marg = RNG.lognormal(mean=np.exp(0), sigma=1, size=(N, 2))\n",
"marg[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"COORDS = dict(\n",
" m_nm=[\"m0\", \"m1\"],\n",
" mhat_nm=[\"m0hat\", \"m1hat\"],\n",
" u_nm=[\"u0\", \"u1\"],\n",
" c_nm=[\"c0\", \"c1\"],\n",
" chat_nm=[\"c0hat\", \"c1hat\"],\n",
" oid = np.array([f'o{i}' for i in range(N)])\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Issue seen in `ModelA`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "pot_chat has 1 dims but 2 dim labels were provided.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 38\u001b[0m\n\u001b[1;32m 32\u001b[0m c_d \u001b[38;5;241m=\u001b[39m pm\u001b[38;5;241m.\u001b[39mMvNormal\u001b[38;5;241m.\u001b[39mdist(mu\u001b[38;5;241m=\u001b[39mpt\u001b[38;5;241m.\u001b[39mzeros(\u001b[38;5;241m2\u001b[39m), chol\u001b[38;5;241m=\u001b[39mchol, shape\u001b[38;5;241m=\u001b[39m(N, \u001b[38;5;241m2\u001b[39m))\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# 6. Evidence transformed C against Latent Copula using Potential\u001b[39;00m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# because pymc TypeError: Variables that depend on other nodes\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;66;03m# cannot be used for observed data (c)\u001b[39;00m\n\u001b[0;32m---> 38\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mpm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPotential\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpot_chat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlogp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mc_d\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwarn_rvs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdims\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moid\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mchat_nm\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge/envs/oreum_copula/lib/python3.11/site-packages/pymc/model/core.py:2403\u001b[0m, in \u001b[0;36mPotential\u001b[0;34m(name, var, model, dims)\u001b[0m\n\u001b[1;32m 2401\u001b[0m var\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mname_for(name)\n\u001b[1;32m 2402\u001b[0m model\u001b[38;5;241m.\u001b[39mpotentials\u001b[38;5;241m.\u001b[39mappend(var)\n\u001b[0;32m-> 2403\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_named_variable\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvar\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdims\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2405\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpymc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mprinting\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m str_for_potential_or_deterministic\n\u001b[1;32m 2407\u001b[0m var\u001b[38;5;241m.\u001b[39mstr_repr \u001b[38;5;241m=\u001b[39m types\u001b[38;5;241m.\u001b[39mMethodType(\n\u001b[1;32m 2408\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(str_for_potential_or_deterministic, dist_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPotential\u001b[39m\u001b[38;5;124m\"\u001b[39m), var\n\u001b[1;32m 2409\u001b[0m )\n",
"File \u001b[0;32m~/miniforge/envs/oreum_copula/lib/python3.11/site-packages/pymc/model/core.py:1537\u001b[0m, in \u001b[0;36mModel.add_named_variable\u001b[0;34m(self, var, dims)\u001b[0m\n\u001b[1;32m 1535\u001b[0m \u001b[38;5;66;03m# This check implicitly states that only vars with .ndim attribute can have dims\u001b[39;00m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m var\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(dims):\n\u001b[0;32m-> 1537\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvar\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m has \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvar\u001b[38;5;241m.\u001b[39mndim\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m dims but \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(dims)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m dim labels were provided.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1539\u001b[0m )\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnamed_vars_to_dims[var\u001b[38;5;241m.\u001b[39mname] \u001b[38;5;241m=\u001b[39m dims\n\u001b[1;32m 1542\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnamed_vars[var\u001b[38;5;241m.\u001b[39mname] \u001b[38;5;241m=\u001b[39m var\n",
"\u001b[0;31mValueError\u001b[0m: pot_chat has 1 dims but 2 dim labels were provided."
]
}
],
"source": [
"with pm.Model(coords=COORDS) as mdla:\n",
" \n",
" # 0. Create (Mutable)Data containers for obs\n",
" m = pm.Data(\"m\", marg, dims=(\"oid\", \"m_nm\"))\n",
" \n",
" # 1. Define marginals and Evidence against observed\n",
" # Here we need to establish a LogNormal dist that we subsequently \n",
" # use to transform the observations. So we have to \"evidence\" the \n",
" # created dist object using a Potential (and minimise logp). \n",
" m_d = pm.LogNormal.dist(mu=0, sigma=1, shape=(N, 2))\n",
" _ = pm.Potential(\"pot_mhat\", pm.logp(m_d, m), dims=(\"oid\", \"mhat_nm\"))\n",
"\n",
" # 2. Transformation path pt1: Observed -> Uniform via Marginal CDF\n",
" u = pm.Deterministic(\"u\", pt.exp(pm.logcdf(m_d, m)), dims=(\"oid\", \"u_nm\"))\n",
"\n",
" # 3. Transformation path pt2: Uniform -> Normal via Normal InvCDF\n",
" n_d = pm.Normal.dist(mu=0., sigma=1., shape=(N, 2))\n",
" c = pm.Deterministic(\"c\", pm.icdf(n_d, u), dims=(\"oid\", \"c_nm\"))\n",
"\n",
" # 4. Create Latent Copula dist using a 2D MvNormal\n",
" sd = pm.InverseGamma.dist(alpha=5.0, beta=4.0)\n",
" chol, corr_, stds_ = pm.LKJCholeskyCov(\"lkjcc\", n=2, eta=2, sd_dist=sd)\n",
" c_d = pm.MvNormal.dist(mu=pt.zeros(2), chol=chol, shape=(N, 2))\n",
"\n",
" # 5. Evidence transformed C against Latent Copula using Potential\n",
" # because pymc TypeError: Variables that depend on other nodes\n",
" # cannot be used for observed data (c)\n",
" _ = pm.Potential(\"pot_chat\", pm.logp(c_d, c, warn_rvs=True), dims=(\"oid\", \"chat_nm\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'mdla' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m display(pm\u001b[38;5;241m.\u001b[39mmodel_to_graphviz(\u001b[43mmdla\u001b[49m, formatting\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplain\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 2\u001b[0m display(\u001b[38;5;28mdict\u001b[39m(unobserved\u001b[38;5;241m=\u001b[39mmdla\u001b[38;5;241m.\u001b[39munobserved_RVs, observed\u001b[38;5;241m=\u001b[39mmdla\u001b[38;5;241m.\u001b[39mobserved_RVs))\n\u001b[1;32m 3\u001b[0m assert_no_rvs(mdla\u001b[38;5;241m.\u001b[39mlogp())\n",
"\u001b[0;31mNameError\u001b[0m: name 'mdla' is not defined"
]
}
],
"source": [
"display(pm.model_to_graphviz(mdla, formatting=\"plain\"))\n",
"display(dict(unobserved=mdla.unobserved_RVs, observed=mdla.observed_RVs))\n",
"assert_no_rvs(mdla.logp())\n",
"mdla.debug(fn=\"random\", verbose=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Issue not seen in `ModelB` (included so you can see the same architecture working)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with pm.Model(coords=COORDS) as mdlb:\n",
" # NOTE: exactly the same as mdla except for section [#4] below:\n",
" \n",
" # 0. Create (Mutable)Data containers for obs\n",
" m = pm.Data(\"m\", marg, dims=(\"oid\", \"m_nm\"))\n",
" \n",
" # 1. Define marginals and Evidence against observed\n",
" # Here we need to establish a LogNormal dist that we subsequently \n",
" # use to transform the observations. So we have to \"evidence\" the \n",
" # created dist object using a Potential (and minimise logp). \n",
" m_d = pm.LogNormal.dist(mu=0, sigma=1, shape=(N, 2))\n",
" _ = pm.Potential(\"pot_mhat\", pm.logp(m_d, m), dims=(\"oid\", \"mhat_nm\"))\n",
"\n",
" # 2. Transformation path pt1: Observed -> Uniform via Marginal CDF\n",
" u = pm.Deterministic(\"u\", pt.exp(pm.logcdf(m_d, m)), dims=(\"oid\", \"u_nm\"))\n",
"\n",
" # 3. Transformation path pt2: Uniform -> Normal via Normal InvCDF\n",
" n_d = pm.Normal.dist(mu=0., sigma=1., shape=(N, 2))\n",
" c = pm.Deterministic(\"c\", pm.icdf(n_d, u), dims=(\"oid\", \"c_nm\"))\n",
"\n",
" # 4. Create Latent Copula dist using a 2D Normal\n",
" sd = pm.InverseGamma.dist(alpha=5.0, beta=4.0)\n",
" chol, corr_, stds_ = pm.LKJCholeskyCov(\"lkjcc\", n=2, eta=2, sd_dist=sd)\n",
" c_d = pm.Normal.dist(mu=pt.zeros(2), sigma=pt.diag(corr_[::-1]), shape=(N, 2))\n",
" \n",
" # 5. Evidence transformed C against Latent Copula using Potential\n",
" # because pymc TypeError: Variables that depend on other nodes\n",
" # cannot be used for observed data (c)\n",
" _ = pm.Potential(\"pot_chat\", pm.logp(c_d, c, warn_rvs=True), dims=(\"oid\", \"chat_nm\"))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 12.0.0 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"531pt\" height=\"501pt\"\n",
" viewBox=\"0.00 0.00 531.00 501.03\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 497.03)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-497.03 527,-497.03 527,4 -4,4\"/>\n",
"<g id=\"clust1\" class=\"cluster\">\n",
"<title>clusteroid (50) x m_nm (2)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M105,-387.03C105,-387.03 210,-387.03 210,-387.03 216,-387.03 222,-393.03 222,-399.03 222,-399.03 222,-473.03 222,-473.03 222,-479.03 216,-485.03 210,-485.03 210,-485.03 105,-485.03 105,-485.03 99,-485.03 93,-479.03 93,-473.03 93,-473.03 93,-399.03 93,-399.03 93,-393.03 99,-387.03 105,-387.03\"/>\n",
"<text text-anchor=\"middle\" x=\"157.38\" y=\"-394.23\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x m_nm (2)</text>\n",
"</g>\n",
"<g id=\"clust2\" class=\"cluster\">\n",
"<title>clusteroid (50) x mhat_nm (2)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M20,-250.52C20,-250.52 141,-250.52 141,-250.52 147,-250.52 153,-256.52 153,-262.52 153,-262.52 153,-367.03 153,-367.03 153,-373.03 147,-379.03 141,-379.03 141,-379.03 20,-379.03 20,-379.03 14,-379.03 8,-373.03 8,-367.03 8,-367.03 8,-262.52 8,-262.52 8,-256.52 14,-250.52 20,-250.52\"/>\n",
"<text text-anchor=\"middle\" x=\"80.12\" y=\"-257.72\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x mhat_nm (2)</text>\n",
"</g>\n",
"<g id=\"clust3\" class=\"cluster\">\n",
"<title>clusteroid (50) x u_nm (2)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M173,-265.78C173,-265.78 273,-265.78 273,-265.78 279,-265.78 285,-271.78 285,-277.78 285,-277.78 285,-351.78 285,-351.78 285,-357.78 279,-363.78 273,-363.78 273,-363.78 173,-363.78 173,-363.78 167,-363.78 161,-357.78 161,-351.78 161,-351.78 161,-277.78 161,-277.78 161,-271.78 167,-265.78 173,-265.78\"/>\n",
"<text text-anchor=\"middle\" x=\"222.62\" y=\"-272.98\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x u_nm (2)</text>\n",
"</g>\n",
"<g id=\"clust4\" class=\"cluster\">\n",
"<title>clusteroid (50) x c_nm (2)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M155,-144.52C155,-144.52 255,-144.52 255,-144.52 261,-144.52 267,-150.52 267,-156.52 267,-156.52 267,-230.52 267,-230.52 267,-236.52 261,-242.52 255,-242.52 255,-242.52 155,-242.52 155,-242.52 149,-242.52 143,-236.52 143,-230.52 143,-230.52 143,-156.52 143,-156.52 143,-150.52 149,-144.52 155,-144.52\"/>\n",
"<text text-anchor=\"middle\" x=\"205\" y=\"-151.72\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x c_nm (2)</text>\n",
"</g>\n",
"<g id=\"clust5\" class=\"cluster\">\n",
"<title>cluster3</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M343,-253.87C343,-253.87 503,-253.87 503,-253.87 509,-253.87 515,-259.87 515,-265.87 515,-265.87 515,-363.68 515,-363.68 515,-369.68 509,-375.68 503,-375.68 503,-375.68 343,-375.68 343,-375.68 337,-375.68 331,-369.68 331,-363.68 331,-363.68 331,-265.87 331,-265.87 331,-259.87 337,-253.87 343,-253.87\"/>\n",
"<text text-anchor=\"middle\" x=\"503.62\" y=\"-261.07\" font-family=\"Times,serif\" font-size=\"14.00\">3</text>\n",
"</g>\n",
"<g id=\"clust6\" class=\"cluster\">\n",
"<title>cluster2 x 2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M287,-144.52C287,-144.52 369,-144.52 369,-144.52 375,-144.52 381,-150.52 381,-156.52 381,-156.52 381,-230.52 381,-230.52 381,-236.52 375,-242.52 369,-242.52 369,-242.52 287,-242.52 287,-242.52 281,-242.52 275,-236.52 275,-230.52 275,-230.52 275,-156.52 275,-156.52 275,-150.52 281,-144.52 287,-144.52\"/>\n",
"<text text-anchor=\"middle\" x=\"359.12\" y=\"-151.72\" font-family=\"Times,serif\" font-size=\"14.00\">2 x 2</text>\n",
"</g>\n",
"<g id=\"clust7\" class=\"cluster\">\n",
"<title>cluster2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M401,-144.52C401,-144.52 483,-144.52 483,-144.52 489,-144.52 495,-150.52 495,-156.52 495,-156.52 495,-230.52 495,-230.52 495,-236.52 489,-242.52 483,-242.52 483,-242.52 401,-242.52 401,-242.52 395,-242.52 389,-236.52 389,-230.52 389,-230.52 389,-156.52 389,-156.52 389,-150.52 395,-144.52 401,-144.52\"/>\n",
"<text text-anchor=\"middle\" x=\"483.62\" y=\"-151.72\" font-family=\"Times,serif\" font-size=\"14.00\">2</text>\n",
"</g>\n",
"<g id=\"clust8\" class=\"cluster\">\n",
"<title>clusteroid (50) x chat_nm (2)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M213,-8C213,-8 329,-8 329,-8 335,-8 341,-14 341,-20 341,-20 341,-124.52 341,-124.52 341,-130.52 335,-136.52 329,-136.52 329,-136.52 213,-136.52 213,-136.52 207,-136.52 201,-130.52 201,-124.52 201,-124.52 201,-20 201,-20 201,-14 207,-8 213,-8\"/>\n",
"<text text-anchor=\"middle\" x=\"270.75\" y=\"-15.2\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x chat_nm (2)</text>\n",
"</g>\n",
"<!-- m -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>m</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M168,-477.03C168,-477.03 138,-477.03 138,-477.03 132,-477.03 126,-471.03 126,-465.03 126,-465.03 126,-431.53 126,-431.53 126,-425.53 132,-419.53 138,-419.53 138,-419.53 168,-419.53 168,-419.53 174,-419.53 180,-425.53 180,-431.53 180,-431.53 180,-465.03 180,-465.03 180,-471.03 174,-477.03 168,-477.03\"/>\n",
"<text text-anchor=\"middle\" x=\"153\" y=\"-459.73\" font-family=\"Times,serif\" font-size=\"14.00\">m</text>\n",
"<text text-anchor=\"middle\" x=\"153\" y=\"-443.23\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"153\" y=\"-426.73\" font-family=\"Times,serif\" font-size=\"14.00\">Data</text>\n",
"</g>\n",
"<!-- pot_mhat -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>pot_mhat</title>\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"144.85,-308.8 144.85,-345.25 114.48,-371.03 71.52,-371.03 41.15,-345.25 41.15,-308.8 71.52,-283.02 114.48,-283.02 144.85,-308.8\"/>\n",
"<text text-anchor=\"middle\" x=\"93\" y=\"-338.48\" font-family=\"Times,serif\" font-size=\"14.00\">pot_mhat</text>\n",
"<text text-anchor=\"middle\" x=\"93\" y=\"-321.98\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"93\" y=\"-305.48\" font-family=\"Times,serif\" font-size=\"14.00\">Potential</text>\n",
"</g>\n",
"<!-- m&#45;&gt;pot_mhat -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>m&#45;&gt;pot_mhat</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M138.93,-419.32C133.19,-407.91 126.35,-394.32 119.73,-381.16\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"123.03,-379.93 115.41,-372.57 116.78,-383.08 123.03,-379.93\"/>\n",
"</g>\n",
"<!-- u -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>u</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"259.12,-355.78 168.88,-355.78 168.88,-298.28 259.12,-298.28 259.12,-355.78\"/>\n",
"<text text-anchor=\"middle\" x=\"214\" y=\"-338.48\" font-family=\"Times,serif\" font-size=\"14.00\">u</text>\n",
"<text text-anchor=\"middle\" x=\"214\" y=\"-321.98\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"214\" y=\"-305.48\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n",
"</g>\n",
"<!-- m&#45;&gt;u -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>m&#45;&gt;u</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M167.3,-419.32C175.43,-403.44 185.71,-383.33 194.56,-366.03\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"197.45,-368.07 198.89,-357.57 191.22,-364.88 197.45,-368.07\"/>\n",
"</g>\n",
"<!-- c -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>c</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"259.12,-234.52 168.88,-234.52 168.88,-177.02 259.12,-177.02 259.12,-234.52\"/>\n",
"<text text-anchor=\"middle\" x=\"214\" y=\"-217.22\" font-family=\"Times,serif\" font-size=\"14.00\">c</text>\n",
"<text text-anchor=\"middle\" x=\"214\" y=\"-200.72\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"214\" y=\"-184.22\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n",
"</g>\n",
"<!-- u&#45;&gt;c -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>u&#45;&gt;c</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M214,-298.06C214,-282.63 214,-263.2 214,-246.24\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"217.5,-246.48 214,-236.48 210.5,-246.48 217.5,-246.48\"/>\n",
"</g>\n",
"<!-- pot_chat -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>pot_chat</title>\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"319.98,-66.28 319.98,-102.74 291.29,-128.52 250.71,-128.52 222.02,-102.74 222.02,-66.28 250.71,-40.5 291.29,-40.5 319.98,-66.28\"/>\n",
"<text text-anchor=\"middle\" x=\"271\" y=\"-95.96\" font-family=\"Times,serif\" font-size=\"14.00\">pot_chat</text>\n",
"<text text-anchor=\"middle\" x=\"271\" y=\"-79.46\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"271\" y=\"-62.96\" font-family=\"Times,serif\" font-size=\"14.00\">Potential</text>\n",
"</g>\n",
"<!-- c&#45;&gt;pot_chat -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>c&#45;&gt;pot_chat</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M227.36,-176.81C232.76,-165.51 239.19,-152.07 245.42,-139.03\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"248.55,-140.6 249.7,-130.07 242.23,-137.58 248.55,-140.6\"/>\n",
"</g>\n",
"<!-- lkjcc -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>lkjcc</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"423\" cy=\"-327.03\" rx=\"84.5\" ry=\"40.66\"/>\n",
"<text text-anchor=\"middle\" x=\"423\" y=\"-338.48\" font-family=\"Times,serif\" font-size=\"14.00\">lkjcc</text>\n",
"<text text-anchor=\"middle\" x=\"423\" y=\"-321.98\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"423\" y=\"-305.48\" font-family=\"Times,serif\" font-size=\"14.00\">_LKJCholeskyCov</text>\n",
"</g>\n",
"<!-- lkjcc_corr -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>lkjcc_corr</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"373.12,-234.52 282.88,-234.52 282.88,-177.02 373.12,-177.02 373.12,-234.52\"/>\n",
"<text text-anchor=\"middle\" x=\"328\" y=\"-217.22\" font-family=\"Times,serif\" font-size=\"14.00\">lkjcc_corr</text>\n",
"<text text-anchor=\"middle\" x=\"328\" y=\"-200.72\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"328\" y=\"-184.22\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n",
"</g>\n",
"<!-- lkjcc&#45;&gt;lkjcc_corr -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>lkjcc&#45;&gt;lkjcc_corr</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M393.2,-288.62C381.76,-274.25 368.71,-257.87 357.31,-243.56\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"360.37,-241.78 351.4,-236.14 354.89,-246.14 360.37,-241.78\"/>\n",
"</g>\n",
"<!-- lkjcc_stds -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>lkjcc_stds</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"487.12,-234.52 396.88,-234.52 396.88,-177.02 487.12,-177.02 487.12,-234.52\"/>\n",
"<text text-anchor=\"middle\" x=\"442\" y=\"-217.22\" font-family=\"Times,serif\" font-size=\"14.00\">lkjcc_stds</text>\n",
"<text text-anchor=\"middle\" x=\"442\" y=\"-200.72\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"442\" y=\"-184.22\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n",
"</g>\n",
"<!-- lkjcc&#45;&gt;lkjcc_stds -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>lkjcc&#45;&gt;lkjcc_stds</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M429.33,-286.31C431.38,-273.45 433.65,-259.19 435.7,-246.33\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"439.15,-246.92 437.26,-236.5 432.23,-245.82 439.15,-246.92\"/>\n",
"</g>\n",
"<!-- lkjcc_corr&#45;&gt;pot_chat -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>lkjcc_corr&#45;&gt;pot_chat</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M314.64,-176.81C309.24,-165.51 302.81,-152.07 296.58,-139.03\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"299.77,-137.58 292.3,-130.07 293.45,-140.6 299.77,-137.58\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x14b308990>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'unobserved': [lkjcc ~ _lkjcholeskycov(2, 2, InverseGamma(5, 4)),\n",
" u ~ Deterministic(f()),\n",
" c ~ Deterministic(f()),\n",
" lkjcc_corr ~ Deterministic(f(lkjcc)),\n",
" lkjcc_stds ~ Deterministic(f(lkjcc))],\n",
" 'observed': []}"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"point={'lkjcc_cholesky-cov-packed__': array([0., 0., 0.])}\n",
"\n",
"No problems found\n"
]
}
],
"source": [
"display(pm.model_to_graphviz(mdlb, formatting=\"plain\"))\n",
"display(dict(unobserved=mdlb.unobserved_RVs, observed=mdlb.observed_RVs))\n",
"assert_no_rvs(mdlb.logp())\n",
"mdlb.debug(fn=\"random\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The watermark extension is already loaded. To reload it, use:\n",
" %reload_ext watermark\n",
"Last updated: Tue Dec 03 2024\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.11.10\n",
"IPython version : 8.29.0\n",
"\n",
"pymc : 5.16.2\n",
"numpy : 1.26.4\n",
"pytensor: 2.25.5\n",
"\n",
"Watermark: 2.5.0\n",
"\n"
]
}
],
"source": [
"%load_ext watermark\n",
"%watermark -n -u -v -iv -w"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "oreum_copula",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
},
"widgets": {
"state": {
"2478528dd27c4eed9219f055bf907997": {
"views": [
{
"cell_index": 67
}
]
},
"c789864f037443ada1d6ad40c44524f4": {
"views": [
{
"cell_index": 68
}
]
}
},
"version": "1.2.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment