Last active
December 3, 2024 12:23
-
-
Save jonsedar/a2c355df5a8c888768dbec7e1fe1f7a6 to your computer and use it in GitHub Desktop.
Rough & ready MRE to demo a shape / dims issue seen when using MvNormal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 900_MRE_MVN_issue\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Setup" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false, | |
"jupyter": { | |
"outputs_hidden": false | |
}, | |
"slideshow": { | |
"slide_type": "skip" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import pymc as pm\n", | |
"import pytensor.tensor as pt\n", | |
"from pymc.testing import assert_no_rvs\n", | |
"RNG = np.random.default_rng(seed=42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[3.68664592, 0.96080471],\n", | |
" [5.75719971, 6.9626818 ],\n", | |
" [0.38634088, 0.73920536],\n", | |
" [3.08897834, 1.98130835],\n", | |
" [2.67299306, 1.15830308]])" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"N = 50\n", | |
"marg = RNG.lognormal(mean=np.exp(0), sigma=1, size=(N, 2))\n", | |
"marg[:5]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"COORDS = dict(\n", | |
" m_nm=[\"m0\", \"m1\"],\n", | |
" mhat_nm=[\"m0hat\", \"m1hat\"],\n", | |
" u_nm=[\"u0\", \"u1\"],\n", | |
" c_nm=[\"c0\", \"c1\"],\n", | |
" chat_nm=[\"c0hat\", \"c1hat\"],\n", | |
" oid = np.array([f'o{i}' for i in range(N)])\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Issue seen in `ModelA`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "ValueError", | |
"evalue": "pot_chat has 1 dims but 2 dim labels were provided.", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[8], line 38\u001b[0m\n\u001b[1;32m 32\u001b[0m c_d \u001b[38;5;241m=\u001b[39m pm\u001b[38;5;241m.\u001b[39mMvNormal\u001b[38;5;241m.\u001b[39mdist(mu\u001b[38;5;241m=\u001b[39mpt\u001b[38;5;241m.\u001b[39mzeros(\u001b[38;5;241m2\u001b[39m), chol\u001b[38;5;241m=\u001b[39mchol, shape\u001b[38;5;241m=\u001b[39m(N, \u001b[38;5;241m2\u001b[39m))\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# 6. Evidence transformed C against Latent Copula using Potential\u001b[39;00m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# because pymc TypeError: Variables that depend on other nodes\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;66;03m# cannot be used for observed data (c)\u001b[39;00m\n\u001b[0;32m---> 38\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mpm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPotential\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpot_chat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlogp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mc_d\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwarn_rvs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdims\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moid\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mchat_nm\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge/envs/oreum_copula/lib/python3.11/site-packages/pymc/model/core.py:2403\u001b[0m, in \u001b[0;36mPotential\u001b[0;34m(name, var, model, dims)\u001b[0m\n\u001b[1;32m 2401\u001b[0m var\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mname_for(name)\n\u001b[1;32m 2402\u001b[0m model\u001b[38;5;241m.\u001b[39mpotentials\u001b[38;5;241m.\u001b[39mappend(var)\n\u001b[0;32m-> 2403\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_named_variable\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvar\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdims\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2405\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpymc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mprinting\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m str_for_potential_or_deterministic\n\u001b[1;32m 2407\u001b[0m var\u001b[38;5;241m.\u001b[39mstr_repr \u001b[38;5;241m=\u001b[39m types\u001b[38;5;241m.\u001b[39mMethodType(\n\u001b[1;32m 2408\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(str_for_potential_or_deterministic, dist_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPotential\u001b[39m\u001b[38;5;124m\"\u001b[39m), var\n\u001b[1;32m 2409\u001b[0m )\n", | |
"File \u001b[0;32m~/miniforge/envs/oreum_copula/lib/python3.11/site-packages/pymc/model/core.py:1537\u001b[0m, in \u001b[0;36mModel.add_named_variable\u001b[0;34m(self, var, dims)\u001b[0m\n\u001b[1;32m 1535\u001b[0m \u001b[38;5;66;03m# This check implicitly states that only vars with .ndim attribute can have dims\u001b[39;00m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m var\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(dims):\n\u001b[0;32m-> 1537\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvar\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m has \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvar\u001b[38;5;241m.\u001b[39mndim\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m dims but \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(dims)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m dim labels were provided.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1539\u001b[0m )\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnamed_vars_to_dims[var\u001b[38;5;241m.\u001b[39mname] \u001b[38;5;241m=\u001b[39m dims\n\u001b[1;32m 1542\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnamed_vars[var\u001b[38;5;241m.\u001b[39mname] \u001b[38;5;241m=\u001b[39m var\n", | |
"\u001b[0;31mValueError\u001b[0m: pot_chat has 1 dims but 2 dim labels were provided." | |
] | |
} | |
], | |
"source": [ | |
"with pm.Model(coords=COORDS) as mdla:\n", | |
" \n", | |
" # 0. Create (Mutable)Data containers for obs\n", | |
" m = pm.Data(\"m\", marg, dims=(\"oid\", \"m_nm\"))\n", | |
" \n", | |
" # 1. Define marginals and Evidence against observed\n", | |
" # Here we need to establish a LogNormal dist that we subsequently \n", | |
" # use to transform the observations. So we have to \"evidence\" the \n", | |
" # created dist object using a Potential (and minimise logp). \n", | |
" m_d = pm.LogNormal.dist(mu=0, sigma=1, shape=(N, 2))\n", | |
" _ = pm.Potential(\"pot_mhat\", pm.logp(m_d, m), dims=(\"oid\", \"mhat_nm\"))\n", | |
"\n", | |
" # 2. Transformation path pt1: Observed -> Uniform via Marginal CDF\n", | |
" u = pm.Deterministic(\"u\", pt.exp(pm.logcdf(m_d, m)), dims=(\"oid\", \"u_nm\"))\n", | |
"\n", | |
" # 3. Transformation path pt2: Uniform -> Normal via Normal InvCDF\n", | |
" n_d = pm.Normal.dist(mu=0., sigma=1., shape=(N, 2))\n", | |
" c = pm.Deterministic(\"c\", pm.icdf(n_d, u), dims=(\"oid\", \"c_nm\"))\n", | |
"\n", | |
" # 4. Create Latent Copula dist using a 2D MvNormal\n", | |
" sd = pm.InverseGamma.dist(alpha=5.0, beta=4.0)\n", | |
" chol, corr_, stds_ = pm.LKJCholeskyCov(\"lkjcc\", n=2, eta=2, sd_dist=sd)\n", | |
" c_d = pm.MvNormal.dist(mu=pt.zeros(2), chol=chol, shape=(N, 2))\n", | |
"\n", | |
" # 5. Evidence transformed C against Latent Copula using Potential\n", | |
" # because pymc TypeError: Variables that depend on other nodes\n", | |
" # cannot be used for observed data (c)\n", | |
" _ = pm.Potential(\"pot_chat\", pm.logp(c_d, c, warn_rvs=True), dims=(\"oid\", \"chat_nm\"))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "NameError", | |
"evalue": "name 'mdla' is not defined", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m display(pm\u001b[38;5;241m.\u001b[39mmodel_to_graphviz(\u001b[43mmdla\u001b[49m, formatting\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplain\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 2\u001b[0m display(\u001b[38;5;28mdict\u001b[39m(unobserved\u001b[38;5;241m=\u001b[39mmdla\u001b[38;5;241m.\u001b[39munobserved_RVs, observed\u001b[38;5;241m=\u001b[39mmdla\u001b[38;5;241m.\u001b[39mobserved_RVs))\n\u001b[1;32m 3\u001b[0m assert_no_rvs(mdla\u001b[38;5;241m.\u001b[39mlogp())\n", | |
"\u001b[0;31mNameError\u001b[0m: name 'mdla' is not defined" | |
] | |
} | |
], | |
"source": [ | |
"display(pm.model_to_graphviz(mdla, formatting=\"plain\"))\n", | |
"display(dict(unobserved=mdla.unobserved_RVs, observed=mdla.observed_RVs))\n", | |
"assert_no_rvs(mdla.logp())\n", | |
"mdla.debug(fn=\"random\", verbose=True)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"\n", | |
"---" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Issue not seen in `ModelB` (included so you can see the same architecture working)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"with pm.Model(coords=COORDS) as mdlb:\n", | |
" # NOTE: exactly the same as mdla except for section [#4] below:\n", | |
" \n", | |
" # 0. Create (Mutable)Data containers for obs\n", | |
" m = pm.Data(\"m\", marg, dims=(\"oid\", \"m_nm\"))\n", | |
" \n", | |
" # 1. Define marginals and Evidence against observed\n", | |
" # Here we need to establish a LogNormal dist that we subsequently \n", | |
" # use to transform the observations. So we have to \"evidence\" the \n", | |
" # created dist object using a Potential (and minimise logp). \n", | |
" m_d = pm.LogNormal.dist(mu=0, sigma=1, shape=(N, 2))\n", | |
" _ = pm.Potential(\"pot_mhat\", pm.logp(m_d, m), dims=(\"oid\", \"mhat_nm\"))\n", | |
"\n", | |
" # 2. Transformation path pt1: Observed -> Uniform via Marginal CDF\n", | |
" u = pm.Deterministic(\"u\", pt.exp(pm.logcdf(m_d, m)), dims=(\"oid\", \"u_nm\"))\n", | |
"\n", | |
" # 3. Transformation path pt2: Uniform -> Normal via Normal InvCDF\n", | |
" n_d = pm.Normal.dist(mu=0., sigma=1., shape=(N, 2))\n", | |
" c = pm.Deterministic(\"c\", pm.icdf(n_d, u), dims=(\"oid\", \"c_nm\"))\n", | |
"\n", | |
" # 4. Create Latent Copula dist using a 2D Normal\n", | |
" sd = pm.InverseGamma.dist(alpha=5.0, beta=4.0)\n", | |
" chol, corr_, stds_ = pm.LKJCholeskyCov(\"lkjcc\", n=2, eta=2, sd_dist=sd)\n", | |
" c_d = pm.Normal.dist(mu=pt.zeros(2), sigma=pt.diag(corr_[::-1]), shape=(N, 2))\n", | |
" \n", | |
" # 5. Evidence transformed C against Latent Copula using Potential\n", | |
" # because pymc TypeError: Variables that depend on other nodes\n", | |
" # cannot be used for observed data (c)\n", | |
" _ = pm.Potential(\"pot_chat\", pm.logp(c_d, c, warn_rvs=True), dims=(\"oid\", \"chat_nm\"))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 12.0.0 (0)\n", | |
" -->\n", | |
"<!-- Pages: 1 -->\n", | |
"<svg width=\"531pt\" height=\"501pt\"\n", | |
" viewBox=\"0.00 0.00 531.00 501.03\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 497.03)\">\n", | |
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-497.03 527,-497.03 527,4 -4,4\"/>\n", | |
"<g id=\"clust1\" class=\"cluster\">\n", | |
"<title>clusteroid (50) x m_nm (2)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M105,-387.03C105,-387.03 210,-387.03 210,-387.03 216,-387.03 222,-393.03 222,-399.03 222,-399.03 222,-473.03 222,-473.03 222,-479.03 216,-485.03 210,-485.03 210,-485.03 105,-485.03 105,-485.03 99,-485.03 93,-479.03 93,-473.03 93,-473.03 93,-399.03 93,-399.03 93,-393.03 99,-387.03 105,-387.03\"/>\n", | |
"<text text-anchor=\"middle\" x=\"157.38\" y=\"-394.23\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x m_nm (2)</text>\n", | |
"</g>\n", | |
"<g id=\"clust2\" class=\"cluster\">\n", | |
"<title>clusteroid (50) x mhat_nm (2)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M20,-250.52C20,-250.52 141,-250.52 141,-250.52 147,-250.52 153,-256.52 153,-262.52 153,-262.52 153,-367.03 153,-367.03 153,-373.03 147,-379.03 141,-379.03 141,-379.03 20,-379.03 20,-379.03 14,-379.03 8,-373.03 8,-367.03 8,-367.03 8,-262.52 8,-262.52 8,-256.52 14,-250.52 20,-250.52\"/>\n", | |
"<text text-anchor=\"middle\" x=\"80.12\" y=\"-257.72\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x mhat_nm (2)</text>\n", | |
"</g>\n", | |
"<g id=\"clust3\" class=\"cluster\">\n", | |
"<title>clusteroid (50) x u_nm (2)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M173,-265.78C173,-265.78 273,-265.78 273,-265.78 279,-265.78 285,-271.78 285,-277.78 285,-277.78 285,-351.78 285,-351.78 285,-357.78 279,-363.78 273,-363.78 273,-363.78 173,-363.78 173,-363.78 167,-363.78 161,-357.78 161,-351.78 161,-351.78 161,-277.78 161,-277.78 161,-271.78 167,-265.78 173,-265.78\"/>\n", | |
"<text text-anchor=\"middle\" x=\"222.62\" y=\"-272.98\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x u_nm (2)</text>\n", | |
"</g>\n", | |
"<g id=\"clust4\" class=\"cluster\">\n", | |
"<title>clusteroid (50) x c_nm (2)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M155,-144.52C155,-144.52 255,-144.52 255,-144.52 261,-144.52 267,-150.52 267,-156.52 267,-156.52 267,-230.52 267,-230.52 267,-236.52 261,-242.52 255,-242.52 255,-242.52 155,-242.52 155,-242.52 149,-242.52 143,-236.52 143,-230.52 143,-230.52 143,-156.52 143,-156.52 143,-150.52 149,-144.52 155,-144.52\"/>\n", | |
"<text text-anchor=\"middle\" x=\"205\" y=\"-151.72\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x c_nm (2)</text>\n", | |
"</g>\n", | |
"<g id=\"clust5\" class=\"cluster\">\n", | |
"<title>cluster3</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M343,-253.87C343,-253.87 503,-253.87 503,-253.87 509,-253.87 515,-259.87 515,-265.87 515,-265.87 515,-363.68 515,-363.68 515,-369.68 509,-375.68 503,-375.68 503,-375.68 343,-375.68 343,-375.68 337,-375.68 331,-369.68 331,-363.68 331,-363.68 331,-265.87 331,-265.87 331,-259.87 337,-253.87 343,-253.87\"/>\n", | |
"<text text-anchor=\"middle\" x=\"503.62\" y=\"-261.07\" font-family=\"Times,serif\" font-size=\"14.00\">3</text>\n", | |
"</g>\n", | |
"<g id=\"clust6\" class=\"cluster\">\n", | |
"<title>cluster2 x 2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M287,-144.52C287,-144.52 369,-144.52 369,-144.52 375,-144.52 381,-150.52 381,-156.52 381,-156.52 381,-230.52 381,-230.52 381,-236.52 375,-242.52 369,-242.52 369,-242.52 287,-242.52 287,-242.52 281,-242.52 275,-236.52 275,-230.52 275,-230.52 275,-156.52 275,-156.52 275,-150.52 281,-144.52 287,-144.52\"/>\n", | |
"<text text-anchor=\"middle\" x=\"359.12\" y=\"-151.72\" font-family=\"Times,serif\" font-size=\"14.00\">2 x 2</text>\n", | |
"</g>\n", | |
"<g id=\"clust7\" class=\"cluster\">\n", | |
"<title>cluster2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M401,-144.52C401,-144.52 483,-144.52 483,-144.52 489,-144.52 495,-150.52 495,-156.52 495,-156.52 495,-230.52 495,-230.52 495,-236.52 489,-242.52 483,-242.52 483,-242.52 401,-242.52 401,-242.52 395,-242.52 389,-236.52 389,-230.52 389,-230.52 389,-156.52 389,-156.52 389,-150.52 395,-144.52 401,-144.52\"/>\n", | |
"<text text-anchor=\"middle\" x=\"483.62\" y=\"-151.72\" font-family=\"Times,serif\" font-size=\"14.00\">2</text>\n", | |
"</g>\n", | |
"<g id=\"clust8\" class=\"cluster\">\n", | |
"<title>clusteroid (50) x chat_nm (2)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M213,-8C213,-8 329,-8 329,-8 335,-8 341,-14 341,-20 341,-20 341,-124.52 341,-124.52 341,-130.52 335,-136.52 329,-136.52 329,-136.52 213,-136.52 213,-136.52 207,-136.52 201,-130.52 201,-124.52 201,-124.52 201,-20 201,-20 201,-14 207,-8 213,-8\"/>\n", | |
"<text text-anchor=\"middle\" x=\"270.75\" y=\"-15.2\" font-family=\"Times,serif\" font-size=\"14.00\">oid (50) x chat_nm (2)</text>\n", | |
"</g>\n", | |
"<!-- m -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>m</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M168,-477.03C168,-477.03 138,-477.03 138,-477.03 132,-477.03 126,-471.03 126,-465.03 126,-465.03 126,-431.53 126,-431.53 126,-425.53 132,-419.53 138,-419.53 138,-419.53 168,-419.53 168,-419.53 174,-419.53 180,-425.53 180,-431.53 180,-431.53 180,-465.03 180,-465.03 180,-471.03 174,-477.03 168,-477.03\"/>\n", | |
"<text text-anchor=\"middle\" x=\"153\" y=\"-459.73\" font-family=\"Times,serif\" font-size=\"14.00\">m</text>\n", | |
"<text text-anchor=\"middle\" x=\"153\" y=\"-443.23\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"153\" y=\"-426.73\" font-family=\"Times,serif\" font-size=\"14.00\">Data</text>\n", | |
"</g>\n", | |
"<!-- pot_mhat -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>pot_mhat</title>\n", | |
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"144.85,-308.8 144.85,-345.25 114.48,-371.03 71.52,-371.03 41.15,-345.25 41.15,-308.8 71.52,-283.02 114.48,-283.02 144.85,-308.8\"/>\n", | |
"<text text-anchor=\"middle\" x=\"93\" y=\"-338.48\" font-family=\"Times,serif\" font-size=\"14.00\">pot_mhat</text>\n", | |
"<text text-anchor=\"middle\" x=\"93\" y=\"-321.98\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"93\" y=\"-305.48\" font-family=\"Times,serif\" font-size=\"14.00\">Potential</text>\n", | |
"</g>\n", | |
"<!-- m->pot_mhat -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>m->pot_mhat</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M138.93,-419.32C133.19,-407.91 126.35,-394.32 119.73,-381.16\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"123.03,-379.93 115.41,-372.57 116.78,-383.08 123.03,-379.93\"/>\n", | |
"</g>\n", | |
"<!-- u -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>u</title>\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"259.12,-355.78 168.88,-355.78 168.88,-298.28 259.12,-298.28 259.12,-355.78\"/>\n", | |
"<text text-anchor=\"middle\" x=\"214\" y=\"-338.48\" font-family=\"Times,serif\" font-size=\"14.00\">u</text>\n", | |
"<text text-anchor=\"middle\" x=\"214\" y=\"-321.98\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"214\" y=\"-305.48\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n", | |
"</g>\n", | |
"<!-- m->u -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>m->u</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M167.3,-419.32C175.43,-403.44 185.71,-383.33 194.56,-366.03\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"197.45,-368.07 198.89,-357.57 191.22,-364.88 197.45,-368.07\"/>\n", | |
"</g>\n", | |
"<!-- c -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>c</title>\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"259.12,-234.52 168.88,-234.52 168.88,-177.02 259.12,-177.02 259.12,-234.52\"/>\n", | |
"<text text-anchor=\"middle\" x=\"214\" y=\"-217.22\" font-family=\"Times,serif\" font-size=\"14.00\">c</text>\n", | |
"<text text-anchor=\"middle\" x=\"214\" y=\"-200.72\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"214\" y=\"-184.22\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n", | |
"</g>\n", | |
"<!-- u->c -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>u->c</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M214,-298.06C214,-282.63 214,-263.2 214,-246.24\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"217.5,-246.48 214,-236.48 210.5,-246.48 217.5,-246.48\"/>\n", | |
"</g>\n", | |
"<!-- pot_chat -->\n", | |
"<g id=\"node8\" class=\"node\">\n", | |
"<title>pot_chat</title>\n", | |
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"319.98,-66.28 319.98,-102.74 291.29,-128.52 250.71,-128.52 222.02,-102.74 222.02,-66.28 250.71,-40.5 291.29,-40.5 319.98,-66.28\"/>\n", | |
"<text text-anchor=\"middle\" x=\"271\" y=\"-95.96\" font-family=\"Times,serif\" font-size=\"14.00\">pot_chat</text>\n", | |
"<text text-anchor=\"middle\" x=\"271\" y=\"-79.46\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"271\" y=\"-62.96\" font-family=\"Times,serif\" font-size=\"14.00\">Potential</text>\n", | |
"</g>\n", | |
"<!-- c->pot_chat -->\n", | |
"<g id=\"edge6\" class=\"edge\">\n", | |
"<title>c->pot_chat</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M227.36,-176.81C232.76,-165.51 239.19,-152.07 245.42,-139.03\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"248.55,-140.6 249.7,-130.07 242.23,-137.58 248.55,-140.6\"/>\n", | |
"</g>\n", | |
"<!-- lkjcc -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>lkjcc</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"423\" cy=\"-327.03\" rx=\"84.5\" ry=\"40.66\"/>\n", | |
"<text text-anchor=\"middle\" x=\"423\" y=\"-338.48\" font-family=\"Times,serif\" font-size=\"14.00\">lkjcc</text>\n", | |
"<text text-anchor=\"middle\" x=\"423\" y=\"-321.98\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"423\" y=\"-305.48\" font-family=\"Times,serif\" font-size=\"14.00\">_LKJCholeskyCov</text>\n", | |
"</g>\n", | |
"<!-- lkjcc_corr -->\n", | |
"<g id=\"node6\" class=\"node\">\n", | |
"<title>lkjcc_corr</title>\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"373.12,-234.52 282.88,-234.52 282.88,-177.02 373.12,-177.02 373.12,-234.52\"/>\n", | |
"<text text-anchor=\"middle\" x=\"328\" y=\"-217.22\" font-family=\"Times,serif\" font-size=\"14.00\">lkjcc_corr</text>\n", | |
"<text text-anchor=\"middle\" x=\"328\" y=\"-200.72\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"328\" y=\"-184.22\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n", | |
"</g>\n", | |
"<!-- lkjcc->lkjcc_corr -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>lkjcc->lkjcc_corr</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M393.2,-288.62C381.76,-274.25 368.71,-257.87 357.31,-243.56\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"360.37,-241.78 351.4,-236.14 354.89,-246.14 360.37,-241.78\"/>\n", | |
"</g>\n", | |
"<!-- lkjcc_stds -->\n", | |
"<g id=\"node7\" class=\"node\">\n", | |
"<title>lkjcc_stds</title>\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"487.12,-234.52 396.88,-234.52 396.88,-177.02 487.12,-177.02 487.12,-234.52\"/>\n", | |
"<text text-anchor=\"middle\" x=\"442\" y=\"-217.22\" font-family=\"Times,serif\" font-size=\"14.00\">lkjcc_stds</text>\n", | |
"<text text-anchor=\"middle\" x=\"442\" y=\"-200.72\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"442\" y=\"-184.22\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n", | |
"</g>\n", | |
"<!-- lkjcc->lkjcc_stds -->\n", | |
"<g id=\"edge5\" class=\"edge\">\n", | |
"<title>lkjcc->lkjcc_stds</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M429.33,-286.31C431.38,-273.45 433.65,-259.19 435.7,-246.33\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"439.15,-246.92 437.26,-236.5 432.23,-245.82 439.15,-246.92\"/>\n", | |
"</g>\n", | |
"<!-- lkjcc_corr->pot_chat -->\n", | |
"<g id=\"edge7\" class=\"edge\">\n", | |
"<title>lkjcc_corr->pot_chat</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M314.64,-176.81C309.24,-165.51 302.81,-152.07 296.58,-139.03\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"299.77,-137.58 292.3,-130.07 293.45,-140.6 299.77,-137.58\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x14b308990>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"{'unobserved': [lkjcc ~ _lkjcholeskycov(2, 2, InverseGamma(5, 4)),\n", | |
" u ~ Deterministic(f()),\n", | |
" c ~ Deterministic(f()),\n", | |
" lkjcc_corr ~ Deterministic(f(lkjcc)),\n", | |
" lkjcc_stds ~ Deterministic(f(lkjcc))],\n", | |
" 'observed': []}" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"point={'lkjcc_cholesky-cov-packed__': array([0., 0., 0.])}\n", | |
"\n", | |
"No problems found\n" | |
] | |
} | |
], | |
"source": [ | |
"display(pm.model_to_graphviz(mdlb, formatting=\"plain\"))\n", | |
"display(dict(unobserved=mdlb.unobserved_RVs, observed=mdlb.observed_RVs))\n", | |
"assert_no_rvs(mdlb.logp())\n", | |
"mdlb.debug(fn=\"random\", verbose=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The watermark extension is already loaded. To reload it, use:\n", | |
" %reload_ext watermark\n", | |
"Last updated: Tue Dec 03 2024\n", | |
"\n", | |
"Python implementation: CPython\n", | |
"Python version : 3.11.10\n", | |
"IPython version : 8.29.0\n", | |
"\n", | |
"pymc : 5.16.2\n", | |
"numpy : 1.26.4\n", | |
"pytensor: 2.25.5\n", | |
"\n", | |
"Watermark: 2.5.0\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"%load_ext watermark\n", | |
"%watermark -n -u -v -iv -w" | |
] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "oreum_copula", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.10" | |
}, | |
"widgets": { | |
"state": { | |
"2478528dd27c4eed9219f055bf907997": { | |
"views": [ | |
{ | |
"cell_index": 67 | |
} | |
] | |
}, | |
"c789864f037443ada1d6ad40c44524f4": { | |
"views": [ | |
{ | |
"cell_index": 68 | |
} | |
] | |
} | |
}, | |
"version": "1.2.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment