Created
November 3, 2017 00:21
-
-
Save sbarratt/42da6a307e27c15ae00e783cd13cd36e to your computer and use it in GitHub Desktop.
K-means script that works with NaN entries.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Author: Shane Barratt
Email: [email protected]
K-means script that works with NaN entries.
""" | |
import numpy as np | |
import IPython as ipy | |
import matplotlib.pyplot as plt | |
def kmeans(data, k, max_iterations):
    """Cluster the rows of `data` into `k` groups with Lloyd-style k-means.

    Args:
        data: 2-D array with one datapoint per row; entries may be NaN.
        k: number of clusters.
        max_iterations: iteration cap handed to `_should_stop`.

    Returns:
        A tuple `(centroids, labels)`: the final centroids and the label
        assigned to every row of `data` under those centroids.
    """
    # Start from random centroids with one column per feature.
    centroids = _get_random_centroids(data.shape[1], k, data)
    previous = None
    completed = 0

    # Alternate assignment and update steps until `_should_stop` says so.
    while True:
        if _should_stop(previous, centroids, completed, max_iterations):
            break
        previous = centroids
        completed += 1
        # Assignment step: label every row by its nearest centroid.
        labels = _get_labels(data, centroids)
        # Update step: recompute each centroid from the rows assigned to it.
        centroids = _get_centroids(data, labels, k)

    # Labels are recomputed so they correspond to the returned centroids.
    return centroids, _get_labels(data, centroids)
def _distance(x, centroids):
    """Return the Euclidean distance from `x` to each row of `centroids`.

    Squared differences that are NaN (e.g. because a coordinate of `x` is
    missing) are treated as zero by `np.nansum`, so they do not contribute
    to the distance.
    """
    offsets = x - centroids
    squared = offsets * offsets
    return np.sqrt(np.nansum(squared, axis=1))
def _mean(data):
    """Return the column-wise mean of `data`, skipping NaN entries.

    Any NaN left in the result is replaced with 0 by `np.nan_to_num`.
    """
    column_means = np.nanmean(data, axis=0)
    return np.nan_to_num(column_means)
def _get_random_centroids(num_features, k, data):
    """Return `k` initial centroids drawn from a standard normal distribution.

    The result has shape `(k, num_features)`. `data` is accepted but not
    used when drawing the centroids.
    """
    shape = (k, num_features)
    return np.random.normal(size=shape)
def _should_stop(old_centroids, centroids, iterations, max_iterations):
    """Decide whether the k-means loop should terminate.

    Args:
        old_centroids: centroids from the previous iteration, or None if no
            iteration has run yet.
        centroids: the current centroids.
        iterations: number of iterations completed so far.
        max_iterations: maximum number of iterations allowed.

    Returns:
        True if the iteration budget is used up or the centroids are
        element-wise identical to the previous ones; False otherwise.
    """
    # `iterations` counts completed passes, so `>=` caps the run at exactly
    # `max_iterations` passes (a strict `>` would allow one extra pass).
    if iterations >= max_iterations:
        return True
    # Nothing to compare against before the first pass.
    if old_centroids is None:
        return False
    # Converged when no centroid coordinate changed; return a plain bool.
    return bool(np.all(np.equal(old_centroids, centroids)))
def _get_labels(data, centroids):
    """Label every row of `data` with the index of its closest centroid.

    Distances come from `_distance`. The labels are returned as a float
    array with one entry per row of `data`.
    """
    num_points = data.shape[0]
    nearest = [
        np.argmin(_distance(data[row, :], centroids))
        for row in range(num_points)
    ]
    # dtype=float keeps the labels as a float array, even when empty.
    return np.array(nearest, dtype=float)
def _get_centroids(data, labels, k):
    """Recompute the `k` centroids from the current label assignment.

    Each centroid is the NaN-ignoring arithmetic mean (via `_mean`) of the
    rows of `data` that carry its label. A cluster with no rows is re-seeded
    from a randomly chosen row of `data`.

    Args:
        data: 2-D array with one datapoint per row; entries may be NaN.
        labels: array with one cluster index per row of `data`.
        k: number of clusters.

    Returns:
        Array of shape `(k, data.shape[1])` holding the new centroids.
    """
    num_points = data.shape[0]
    centroids = np.zeros((k, data.shape[1]))
    for j in range(k):
        members = data[labels == j, :]
        if members.shape[0] == 0:
            # Empty cluster: re-seed from a row drawn from the whole dataset.
            # Drawing the index from range(k) would only ever use the first
            # k rows and would go out of bounds when k exceeds the row count.
            seed_row = data[np.random.choice(num_points), :]
            # Zero out missing entries, as `_mean` does, so the centroid is
            # NaN-free and the equality-based convergence check still works.
            centroids[j, :] = np.nan_to_num(seed_row)
        else:
            centroids[j, :] = _mean(members)
    return centroids
if __name__ == '__main__':
    # Demo: two unit-covariance Gaussian blobs centred at [1, 1] and
    # [-1, -1], 1000 samples each, stacked into one dataset.
    covariance = np.eye(2)
    upper_blob = np.random.multivariate_normal(np.ones(2), covariance, size=1000)
    lower_blob = np.random.multivariate_normal(-np.ones(2), covariance, size=1000)
    data = np.r_[upper_blob, lower_blob]

    # Knock out each entry independently with probability 0.2.
    missing = np.random.random(data.shape) < .2
    data[missing] = np.nan

    # Cluster into two groups and plot the resulting centroids in red.
    centroids, labels = kmeans(data, 2, 200)
    plt.scatter(centroids[:, 0], centroids[:, 1], c='r')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment