Created
March 19, 2017 13:39
-
-
Save posquit0/ba3a5d7f9251bad736c087a3276d6d09 to your computer and use it in GitHub Desktop.
k-means with numpy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class KMeans(object):
    """K-Means clustering (Lloyd's algorithm) implemented with NumPy.

    Parameters
    ----------
    k : int
        Number of clusters.
    init : str
        Strategy used to pick the initial centers. Only ``'forgy'``
        (choose ``k`` distinct samples at random) is supported.
    max_iter : int
        Upper bound on the number of assignment/update rounds.

    Attributes (available after ``fit``)
    ------------------------------------
    centers : ndarray of shape (k, n_features)
        Final cluster centers, always stored with an inexact (float) dtype.
    labels : ndarray of shape (n_samples,)
        Index of the center assigned to each training sample.
    clusters : dict
        Maps a cluster index to the list of samples assigned to it.
        Clusters that received no sample have no entry.
    """

    def __init__(self, k, init='forgy', max_iter=50):
        self.k = k
        self.init = init
        self.max_iter = max_iter

    def _center_dtype(self):
        """Return the dtype used to store centers.

        Centers are means, so integer samples must not force an integer
        dtype (that would silently truncate every center).
        """
        dtype = self.samples.dtype
        if np.issubdtype(dtype, np.inexact):
            return dtype
        return np.dtype(np.float64)

    def _init_centers(self):
        """Choose the initial centers according to ``self.init``.

        Raises
        ------
        ValueError
            If ``self.init`` is not a supported strategy, or (raised by
            ``random.sample``) if ``k`` exceeds the number of samples.
        """
        # Compare by value: `is` on strings only works by interning accident.
        if self.init == 'forgy':
            # Sample row indices rather than rows converted to lists; the
            # selection is the same and avoids a full copy through Python lists.
            chosen = random.sample(range(self.n_samples), self.k)
            centers = self.samples[chosen]
        else:
            raise ValueError(
                "Unsupported init method {!r}; expected 'forgy'".format(self.init)
            )
        self.centers = np.array(centers, dtype=self._center_dtype())

    def _nearest_center_indices(self, samples):
        """Return, for each sample, the index of its closest center.

        Distances are Euclidean; ties go to the lowest center index.
        Builds an (n_samples, k, n_features) intermediate array.
        """
        samples = np.asarray(samples)
        if len(samples) == 0:
            return np.empty(0, dtype=int)
        deltas = samples[:, np.newaxis, :] - self.centers[np.newaxis, :, :]
        distances = np.linalg.norm(deltas, axis=2)
        return np.argmin(distances, axis=1).astype(int)

    def _step_assignment(self):
        """Assign every training sample to its nearest center."""
        self.labels = self._nearest_center_indices(self.samples)
        self.clusters = {}
        for label in np.unique(self.labels):
            members = self.samples[self.labels == label]
            self.clusters[int(label)] = list(members)

    def _step_update(self):
        """Move each center to the mean of the samples assigned to it.

        A center whose cluster is empty keeps its previous position, so the
        number of centers always stays ``k`` and label indices stay stable.
        """
        new_centers = np.array(self.centers, dtype=self._center_dtype())
        for label, members in self.clusters.items():
            new_centers[label] = np.mean(members, axis=0)
        self.centers = new_centers

    def _next(self):
        """Run one full iteration: assignment followed by update."""
        self._step_assignment()
        self._step_update()

    def _is_converged(self):
        """Return True when the last iteration left every center unchanged."""
        return bool(np.array_equal(self.prev_centers, self.centers))

    def fit(self, x):
        """Cluster the samples in ``x``.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Training samples.

        Raises
        ------
        ValueError
            If ``x`` is not two-dimensional, if ``k`` is larger than the
            number of samples, or if ``init`` is unsupported.
        """
        self.samples = np.asarray(x)
        self.n_samples, self.n_features = self.samples.shape
        # Choose initial centers
        self._init_centers()
        # Repeat the assignment/update rounds
        for _ in range(self.max_iter):
            self.prev_centers = self.centers.copy()
            self._next()
            # Stop once the centers no longer move (a local minimum of the
            # within-cluster sum of squares has been reached).
            if self._is_converged():
                break

    def predict(self, x):
        """Return the index of the nearest fitted center for each sample in ``x``.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)

        Returns
        -------
        ndarray of int, shape (n_samples,)
        """
        return self._nearest_center_indices(x)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment