Created
January 11, 2025 09:20
-
-
Save al6x/7808b1d0cd936689f361f5dcf5e3a751 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.mixture import GaussianMixture | |
from scipy.stats import norm | |
def fit_normal_mixture(*, n_components, values, random_state, n_init):
    """Fit a Gaussian mixture to ``values`` and return its parameters.

    The observations are arranged as a single-feature column, a
    ``GaussianMixture`` with diagonal covariances is fitted to them, and the
    fitted parameters are returned as plain Python lists.

    Args:
        n_components: Number of mixture components, passed to
            ``GaussianMixture``.
        values: Observations; reshaped into one column before fitting.
        random_state: Forwarded to ``GaussianMixture``.
        n_init: Forwarded to ``GaussianMixture``.

    Returns:
        A ``(weights, means, sigmas)`` tuple of flat lists; ``sigmas`` are
        the square roots of the fitted covariances.
    """
    column = np.array(values).reshape(-1, 1)  # one row per observation
    model = GaussianMixture(
        n_components,
        covariance_type='diag',
        random_state=random_state,
        n_init=n_init,
    )
    model.fit(column)

    # Collapse each fitted parameter array into a flat Python list.
    weights = model.weights_.flatten().tolist()
    means = model.means_.flatten().tolist()
    sigmas = np.sqrt(model.covariances_.flatten()).tolist()
    return weights, means, sigmas
def sample_normal_mixture(*, weights, means, sigmas, n):
    """Draw ``n`` samples from a mixture of normal distributions.

    A component index is drawn for every sample according to ``weights``;
    the sample itself is then drawn from the normal distribution with that
    component's mean and standard deviation. Randomness comes from NumPy's
    global ``np.random`` state.

    Args:
        weights: Mixture weights, one per component; must sum to 1.
        means: Component means, same length as ``weights``.
        sigmas: Component standard deviations, same length as ``weights``.
        n: Number of samples to draw.

    Returns:
        A NumPy array holding the ``n`` samples.

    Raises:
        ValueError: If ``weights``, ``means`` and ``sigmas`` differ in
            length, or if ``weights`` does not sum to 1.
    """
    # Without this check, surplus means/sigmas would be silently ignored and
    # missing ones would only surface as an IndexError during indexing.
    if not len(weights) == len(means) == len(sigmas):
        raise ValueError("weights, means and sigmas must have the same length")
    if not np.isclose(sum(weights), 1):
        raise ValueError("Weights must sum to 1")

    component_means = np.asarray(means)
    component_sigmas = np.asarray(sigmas)

    # Pick a component per sample, then look up that component's parameters.
    components = np.random.choice(len(weights), size=n, p=weights)
    return np.random.normal(
        loc=component_means[components],
        scale=component_sigmas[components],
    )
# Plotting function | |
def plot_mixture(weights, means, sigmas, label, color, ax, *, x_range=(-10, 10), n_points=1000):
    """Plot the density of a normal mixture on ``ax``.

    The density is the weighted sum of the component normal PDFs, evaluated
    on an evenly spaced grid.

    Args:
        weights: Mixture weights, one per component.
        means: Component means.
        sigmas: Component standard deviations.
        label: Legend label passed to ``ax.plot``.
        color: Line color passed to ``ax.plot``.
        ax: Object providing a ``plot(x, y, label=..., color=...)`` method.
        x_range: ``(low, high)`` bounds of the evaluation grid. Defaults to
            ``(-10, 10)``.
        n_points: Number of grid points. Defaults to 1000.
    """
    low, high = x_range
    x = np.linspace(low, high, n_points)
    y = np.zeros_like(x)
    # zip stops at the shortest of the three sequences.
    for weight, mean, sigma in zip(weights, means, sigmas):
        y += weight * norm.pdf(x, loc=mean, scale=sigma)
    ax.plot(x, y, label=label, color=color)
# Test | |
if __name__ == '__main__':
    # Reference mixture: two equally weighted components centred at zero
    # with different spreads.
    original_weights, original_means, original_sigmas = [0.5, 0.5], [0, 0], [1, 2]
    n_samples = 20000

    # Draw a sample from the reference mixture ...
    nmm_sample = sample_normal_mixture(
        weights=original_weights,
        means=original_means,
        sigmas=original_sigmas,
        n=n_samples,
    )

    # ... and fit a two-component mixture back to that sample.
    fitted_weights, fitted_means, fitted_sigmas = fit_normal_mixture(
        values=nmm_sample,
        n_components=2,
        n_init=10,
        random_state=0,
    )

    # Report both parameter sets for comparison.
    print("Original parameters:")
    print(f"Weights: {original_weights}, Means: {original_means}, Sigmas: {original_sigmas}")
    print("Fitted parameters:")
    print(f"Weights: {fitted_weights}, Means: {fitted_means}, Sigmas: {fitted_sigmas}")

    fig, ax = plt.subplots(figsize=(10, 6))

    # Density curves first (original, then fitted), histogram afterwards.
    curves = (
        (original_weights, original_means, original_sigmas, "Original Model", "blue"),
        (fitted_weights, fitted_means, fitted_sigmas, "Fitted Model", "red"),
    )
    for curve_weights, curve_means, curve_sigmas, curve_label, curve_color in curves:
        plot_mixture(
            curve_weights,
            curve_means,
            curve_sigmas,
            label=curve_label,
            color=curve_color,
            ax=ax,
        )

    ax.hist(
        nmm_sample,
        bins=100,
        density=True,
        alpha=0.5,
        color='gray',
        label='Sample Histogram',
    )

    ax.set_title("Original and Fitted Normal Mixture Models")
    ax.set_xlabel("Value")
    ax.set_ylabel("Density")
    ax.legend()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
And Julia