Created
May 2, 2013 23:35
-
-
Save peterk87/5506262 to your computer and use it in GitHub Desktop.
Python: Hierarchical clustering plot and number of clusters over distances plot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from scipy.spatial.distance import * | |
from scipy.cluster.hierarchy import * | |
import pandas as pd | |
import numpy | |
import matplotlib as plt | |
from matplotlib.pylab import figure | |
import pylab as pl | |
import pp | |
def num_clusters(hc, d):
    """
    Get the number of flat clusters from a linkage matrix at a specified
    distance.

    Args:
        hc (numpy.ndarray): Linkage matrix, as returned by
            scipy.cluster.hierarchy.linkage.
        d (number): Distance threshold for defining flat clusters.

    Returns:
        int: Number of unique flat clusters produced when flat clusters are
        defined at the specified distance threshold.
    """
    # Imported locally so the fully qualified name below resolves both when
    # this function is called directly and when its source is shipped to a
    # worker process; the module top level only star-imports from
    # scipy.cluster.hierarchy and never binds the name ``scipy`` itself.
    import scipy.cluster.hierarchy

    labels = scipy.cluster.hierarchy.fcluster(hc, d, criterion='distance')
    return len(numpy.unique(labels))
def pp_run_num_clusters(hc, distances):
    """
    Parallel python (pp) is used to run the num_clusters function in parallel
    since it can be slow on a single core.

    Args:
        hc (numpy.ndarray): Linkage matrix, as returned by
            scipy.cluster.hierarchy.linkage.
        distances (iterable of number): Distance thresholds at which to count
            flat clusters.

    Returns:
        list: The number of unique flat clusters from the linkage matrix at
        each of the distances specified, in the same order as ``distances``.
    """
    job_server = pp.Server()
    try:
        jobs = [
            job_server.submit(
                num_clusters,
                (hc, d,),
                modules=('numpy', 'scipy.cluster.hierarchy',),
            )
            for d in distances
        ]
        # Calling a job object waits for that job and returns its result, so
        # gather every result first; the statistics printed afterwards then
        # describe the completed run rather than a partially finished one.
        counts = [job() for job in jobs]
        job_server.print_stats()
        return counts
    finally:
        # Shut down the worker processes so repeated calls do not leak them.
        job_server.destroy()
def plot_dendrogram_num_clusters(df, dist_metric, linkage_method, threshold):
    """
    Plot a dendrogram with clusters defined at a specified distance threshold
    and plot a line graph showing the number of clusters at all distances.

    Two figures are drawn: the dendrogram goes on the current pylab figure
    and the number-of-clusters line graph on a newly created figure. Both
    get a dashed red horizontal line at ``threshold``.

    Args:
        df (pandas.DataFrame): A dataframe of numeric values to be clustered.
            Its index supplies the dendrogram leaf labels.
        dist_metric (string): Distance metric for pdist function.
        linkage_method (string): Linkage method for linkage function.
        threshold (number): Distance threshold for defining clusters.

    Returns:
        None. The plots are produced as a side effect and progress messages
        are printed to standard output.
    """
    dm = pdist(df, metric=dist_metric)
    # Single-argument %-formatting prints the same text on Python 2 and 3.
    print('Distance matrix computed. Length: %d' % len(dm))
    hc = linkage(dm, method=linkage_method)
    print('Hierarchical clustering completed.')

    # Count the flat clusters at every distinct pairwise distance.
    distances = numpy.unique(dm)
    cluster_counts = pp_run_num_clusters(hc, distances)

    # Figure 1: dendrogram coloured by the threshold, with the cut marked.
    dendrogram(hc,
               leaf_label_func=lambda x: df.index[x],
               color_threshold=threshold)
    f = pl.gcf()
    f.get_axes()[0].axhline(y=threshold, linestyle='--', color='red')
    f.autofmt_xdate()
    f.set_size_inches(16, 6)

    # Figure 2: number of clusters (x) against distance (y), so the
    # horizontal threshold line is on the same axis as in the dendrogram.
    fig = pl.figure()
    ax = fig.add_subplot(111)
    ax.plot(cluster_counts, distances)
    ax.axhline(y=threshold, linestyle='--', color='red')
    fig.set_size_inches(16, 6)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment