@steinelu
Created August 18, 2025 12:10
Python function calculating the Gauss-Wasserstein distance (the 2-Wasserstein distance between two Gaussian distributions)
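For reference, the function evaluates the standard closed form of the squared 2-Wasserstein distance between two Gaussians N(mu1, Sigma1) and N(mu2, Sigma2), then takes its square root. In the code, the first term is `squared_diff_mean` and the trace is `trace_term`:

```latex
W_2^2\big(\mathcal{N}(\mu_1,\Sigma_1),\,\mathcal{N}(\mu_2,\Sigma_2)\big)
  = \lVert \mu_1 - \mu_2 \rVert_2^2
  + \operatorname{Tr}\!\Big(\Sigma_1 + \Sigma_2 - 2\big(\Sigma_1^{1/2}\,\Sigma_2\,\Sigma_1^{1/2}\big)^{1/2}\Big)
```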
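```python
import numpy as np
import scipy.linalg


def wasserstein_distance_gaussian_analytical(G1, G2):
    """2-Wasserstein distance between Gaussians G1 = (mu1, sigma1) and G2 = (mu2, sigma2)."""
    mu1, sigma1 = map(np.asarray, G1)
    mu2, sigma2 = map(np.asarray, G2)
    diff_mean = mu1 - mu2
    squared_diff_mean = np.dot(diff_mean, diff_mean)  # squared Euclidean norm ||mu1 - mu2||^2
    # for valid (PSD) covariances, sqrtm may still return complex arrays with negligible
    # imaginary parts; keep only the real part
    sigma1_sqrt = np.real(scipy.linalg.sqrtm(sigma1))
    sigma_product = np.dot(sigma1_sqrt, np.dot(sigma2, sigma1_sqrt))
    sigma_product_sqrt = np.real(scipy.linalg.sqrtm(sigma_product))
    trace_term = np.trace(sigma1 + sigma2 - 2 * sigma_product_sqrt)
    wasserstein_distance_sq = squared_diff_mean + trace_term
    # round-off can push the sum slightly below zero; clamp before taking the square root
    wasserstein_distance = np.sqrt(max(wasserstein_distance_sq, 0.0))
    return wasserstein_distance  # W2 itself, not its square
```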
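A quick sanity check, assuming NumPy and SciPy are available: when both covariances are diagonal they commute, so the trace term reduces to sum_i (sqrt(a_i) - sqrt(b_i))^2, where a_i and b_i are the diagonal variances. The sketch below repeats the function in compact form so it runs on its own. The helper `diagonal_gaussian_w2` and the example numbers are hypothetical, chosen only for illustration.

```python
import numpy as np
import scipy.linalg


def wasserstein_distance_gaussian_analytical(G1, G2):
    # same computation as the gist function, repeated so this block is self-contained
    mu1, sigma1 = map(np.asarray, G1)
    mu2, sigma2 = map(np.asarray, G2)
    diff_mean = mu1 - mu2
    sigma1_sqrt = np.real(scipy.linalg.sqrtm(sigma1))
    cross = np.real(scipy.linalg.sqrtm(sigma1_sqrt @ sigma2 @ sigma1_sqrt))
    w2_sq = diff_mean @ diff_mean + np.trace(sigma1 + sigma2 - 2 * cross)
    return np.sqrt(max(w2_sq, 0.0))


def diagonal_gaussian_w2(mu1, var1, mu2, var2):
    # hypothetical helper: closed form when both covariances are diagonal,
    # W2^2 = ||mu1 - mu2||^2 + sum_i (sqrt(var1_i) - sqrt(var2_i))^2
    mu1, var1, mu2, var2 = map(np.asarray, (mu1, var1, mu2, var2))
    return np.sqrt(np.sum((mu1 - mu2) ** 2) + np.sum((np.sqrt(var1) - np.sqrt(var2)) ** 2))


# illustrative inputs: means and diagonal variances of two 2-D Gaussians
mu1, var1 = np.array([0.0, 1.0]), np.array([1.0, 4.0])
mu2, var2 = np.array([3.0, 1.0]), np.array([4.0, 9.0])

general = wasserstein_distance_gaussian_analytical((mu1, np.diag(var1)), (mu2, np.diag(var2)))
diagonal = diagonal_gaussian_w2(mu1, var1, mu2, var2)
print(general, diagonal)  # both ≈ 3.3166, i.e. sqrt(11)
```

With these inputs, both routes give sqrt(9 + 2) = sqrt(11). Comparing against the diagonal case is a cheap way to catch mistakes such as dropping the factor 2 or taking the square root of the wrong matrix product, but it does not exercise covariances that fail to commute.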