Skip to content

Instantly share code, notes, and snippets.

@al6x
Created January 11, 2025 09:20
Show Gist options
  • Save al6x/7808b1d0cd936689f361f5dcf5e3a751 to your computer and use it in GitHub Desktop.
Save al6x/7808b1d0cd936689f361f5dcf5e3a751 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from scipy.stats import norm
def fit_normal_mixture(*, n_components, values, random_state, n_init):
    """Fit a univariate Gaussian mixture model to ``values``.

    Args:
        n_components: Number of mixture components to fit.
        values: One-dimensional sample (any array-like); it is reshaped into a
            single feature column before fitting.
        random_state: Seed / RNG forwarded to scikit-learn's GaussianMixture.
        n_init: Number of EM restarts; scikit-learn keeps the best one.

    Returns:
        A ``(weights, means, sigmas)`` tuple of plain Python lists with one
        entry per component, in the order chosen by scikit-learn.
    """
    # scikit-learn expects a 2-D (n_samples, n_features) matrix.
    column = np.array(values).reshape(-1, 1)
    model = GaussianMixture(
        n_components,
        covariance_type='diag',
        random_state=random_state,
        n_init=n_init,
    )
    model.fit(column)
    # With 'diag' covariances and a single feature, covariances_ holds one
    # variance per component, so its square root is the standard deviation.
    sigmas = np.sqrt(model.covariances_.ravel()).tolist()
    means = model.means_.ravel().tolist()
    weights = model.weights_.ravel().tolist()
    return weights, means, sigmas
def sample_normal_mixture(*, weights, means, sigmas, n, rng=None):
    """Draw ``n`` samples from a univariate normal mixture.

    Args:
        weights: Mixing weights, one per component; must sum to 1 (within
            ``np.isclose`` tolerance).
        means: Component means, same length as ``weights``.
        sigmas: Component standard deviations, same length as ``weights``.
        n: Number of samples to draw.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used for
            the draws. When omitted, NumPy's global ``np.random`` state is
            used, exactly as before.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``n``.

    Raises:
        ValueError: If the weights do not sum to 1, or if ``weights``,
            ``means`` and ``sigmas`` are not 1-D sequences of equal length.
    """
    p = np.asarray(weights, dtype=float)
    mu = np.asarray(means, dtype=float)
    sd = np.asarray(sigmas, dtype=float)

    if not np.isclose(p.sum(), 1):
        raise ValueError("Weights must sum to 1")
    if p.ndim != 1 or not (p.shape == mu.shape == sd.shape):
        raise ValueError("weights, means and sigmas must be 1-D sequences of equal length")

    # np.isclose accepts sums off by ~1e-5, but choice() rejects anything off
    # by more than ~1.5e-8; renormalise so accepted weights never crash below.
    p = p / p.sum()

    gen = np.random if rng is None else rng
    # Pick a component per sample, then draw from that component's normal.
    components = gen.choice(len(p), size=n, p=p)
    return gen.normal(loc=mu[components], scale=sd[components])
# Plotting function
def plot_mixture(weights, means, sigmas, label, color, ax, *, x_range=(-10, 10), num=1000):
    """Plot the density of a univariate normal mixture onto ``ax``.

    Args:
        weights: Mixing weights, one per component.
        means: Component means, same length as ``weights``.
        sigmas: Component standard deviations, same length as ``weights``.
        label: Legend label for the curve.
        color: Line color passed to ``ax.plot``.
        ax: A matplotlib-style axes object exposing ``plot``.
        x_range: ``(start, stop)`` of the evaluation grid; defaults to the
            previously hard-coded ``(-10, 10)``.
        num: Number of grid points; defaults to the previous 1000.

    Returns:
        Whatever ``ax.plot`` returns (ignored by existing callers).

    Raises:
        ValueError: If ``weights``, ``means`` and ``sigmas`` differ in length
            (``zip`` would otherwise silently drop components).
    """
    if not len(weights) == len(means) == len(sigmas):
        raise ValueError("weights, means and sigmas must have the same length")
    start, stop = x_range
    x = np.linspace(start, stop, num)
    # Mixture density is the weight-scaled sum of the component pdfs.
    y = np.zeros_like(x)
    for weight, mean, sigma in zip(weights, means, sigmas):
        y += weight * norm.pdf(x, loc=mean, scale=sigma)
    return ax.plot(x, y, label=label, color=color)
# Test
if __name__ == '__main__':
    # sample_normal_mixture draws from NumPy's global RNG. Seed it so the
    # sample -- and therefore the fitted parameters -- are reproducible,
    # matching the random_state=0 already passed to the fit below.
    np.random.seed(0)

    # Ground-truth mixture: two equally weighted components with the same
    # mean but different standard deviations.
    original_weights = [0.5, 0.5]
    original_means = [0, 0]
    original_sigmas = [1, 2]
    n_samples = 20000
    nmm_sample = sample_normal_mixture(
        weights=original_weights, means=original_means, sigmas=original_sigmas, n=n_samples
    )

    # Fit a 2-component mixture back to the sample. The fitted component
    # order is chosen by the fitter and may not match the original order.
    fitted_weights, fitted_means, fitted_sigmas = fit_normal_mixture(
        n_components=2, values=nmm_sample, random_state=0, n_init=10
    )

    print("Original parameters:")
    print(f"Weights: {original_weights}, Means: {original_means}, Sigmas: {original_sigmas}")
    print("Fitted parameters:")
    print(f"Weights: {fitted_weights}, Means: {fitted_means}, Sigmas: {fitted_sigmas}")

    # Overlay both densities on a normalised histogram of the sample.
    fig, ax = plt.subplots(figsize=(10, 6))
    plot_mixture(original_weights, original_means, original_sigmas, label="Original Model", color="blue", ax=ax)
    plot_mixture(fitted_weights, fitted_means, fitted_sigmas, label="Fitted Model", color="red", ax=ax)
    ax.hist(nmm_sample, bins=100, density=True, alpha=0.5, color='gray', label='Sample Histogram')
    ax.set_title("Original and Fitted Normal Mixture Models")
    ax.set_xlabel("Value")
    ax.set_ylabel("Density")
    ax.legend()
    plt.show()
@al6x
Copy link
Author

al6x commented Jan 11, 2025

And the same experiment in Julia:

using Distributions
using GaussianMixtures

# Ground truth: equal-weight mixture of Normal(0.0, 1.0) and Normal(0.0, 2.0)
# (in Distributions.jl the second Normal argument is the standard deviation).
nmm = MixtureModel(Normal[
  Normal(0.0, 1.0),
  Normal(0.0, 2.0)
],
  [0.5, 0.5]
)

# Draw 1000 samples (the Python version uses 20000). No RNG seed is set here,
# so the sample and the fit change from run to run.
sample = rand(nmm, 1000)

# Fit a 2-component GMM, initialised with k-means.
# NOTE(review): k-means separates points by location; with both true
# components centred at 0 this initialisation may separate them poorly —
# confirm the fitted parameters against the ground truth above.
m = GMM(2, sample; method=:kmeans)
print(m)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment