Created April 26, 2020 16:44
Save gwgundersen/087da1ac4e2bad5daf8192b4d8f6a3cf to your computer and use it in GitHub Desktop.
Visualizing a multivariate Gaussian distribution
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Visualize a bivariate Gaussian as a 3-D surface, with its filled contour
# plot projected onto a plane beneath it. (Because I always forget how to
# do this.)
#
# Credit: https://scipython.com/blog/visualizing-the-bivariate-gaussian-distribution/
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
# Only needed on Matplotlib < 3.2, where importing it registers the '3d'
# projection; unused but harmless on newer versions.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from scipy.stats import multivariate_normal

# Our 2-dimensional distribution will be over variables X and Y,
# sampled on an N x N grid.
N = 60
X = np.linspace(-3, 3, N)
Y = np.linspace(-3, 4, N)
X, Y = np.meshgrid(X, Y)

# Mean vector and covariance matrix (symmetric and positive definite).
mu = np.array([0., 1.])
Sigma = np.array([[1., -0.5], [-0.5, 1.5]])

# Pack X and Y into a single (N, N, 2) array so that pos[i, j] is the (x, y)
# point of grid cell (i, j); scipy's pdf treats the last axis as the
# components of each point.
pos = np.stack((X, Y), axis=-1)

# The distribution on the variables X, Y packed into pos.
F = multivariate_normal(mu, Sigma)
Z = F.pdf(pos)

# Height of the plane the filled contour is projected onto. It sits below
# z = 0 so the projection does not intersect the surface.
z_offset = -0.15

# Create a surface plot and projected filled contour plot under it.
# Figure.gca() no longer accepts keyword arguments (removed in Matplotlib
# 3.6), so the 3-D axes are requested through add_subplot instead.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, rstride=3, cstride=3, linewidth=1, antialiased=True,
                cmap=cm.viridis)
ax.contourf(X, Y, Z, zdir='z', offset=z_offset, cmap=cm.viridis)

# Adjust the limits, ticks and view angle.
ax.set_zlim(z_offset, 0.2)
ax.set_zticks(np.linspace(0, 0.2, 5))
ax.view_init(27, -21)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.