Skip to content

Instantly share code, notes, and snippets.

@tdiggelm
Last active February 20, 2019 10:53
Show Gist options
  • Save tdiggelm/da7e4146dcbf1536ae75178768ac6ef2 to your computer and use it in GitHub Desktop.
Save tdiggelm/da7e4146dcbf1536ae75178768ac6ef2 to your computer and use it in GitHub Desktop.
Categorical scatter plot with matplotlib
import numpy as np
import matplotlib.pyplot as plt
def catplot(x, y, c, labels=None, title=None, n_categories=None,
            s=10, alpha=0.65, cmap='rainbow', fig=None, ax=None,
            border={'color': '0.7', 'linewidth': 1}, facecolor='white'):
    """Draw a categorical scatter plot with a discrete colorbar.

    Points ``(x, y)`` are coloured by ``c`` using ``cmap`` resampled to one
    colour per category. The colour limits are set to ``(-0.5, N - 0.5)`` and
    the colorbar ticks to ``range(N)``, so each integer code ``0..N-1`` sits in
    the middle of its own colour band.

    Parameters
    ----------
    x, y : array-like
        Point coordinates, passed straight to ``Axes.scatter``.
    c : array-like
        Per-point category codes, passed to ``Axes.scatter`` as ``c``.
    labels : sequence of str, optional
        Colorbar tick labels, one per category. Skipped when ``None`` or empty.
    title : str, optional
        Axes title. Skipped when falsy.
    n_categories : int, optional
        Number of categories ``N``. When it is missing or cannot be converted
        to ``int``, the number of unique values in ``c`` is used instead.
    s, alpha : scalar
        Marker size and opacity, passed to ``Axes.scatter``.
    cmap : str or Colormap
        Colormap to resample to ``N`` discrete colours.
    fig, ax : Figure / Axes, optional
        Drawing target. If both are omitted a new figure with a single axes is
        created; if only one is given, the other is derived from it.
    border : dict or None
        Spine style with optional keys ``'color'`` and ``'linewidth'``. A falsy
        value leaves the spines untouched. The default dict is only read,
        never mutated.
    facecolor : color or None
        Axes background colour. A falsy value leaves it untouched.

    Returns
    -------
    (fig, ax)
        The figure and axes that were drawn on.
    """
    # Resolve the drawing target so that either argument may be omitted.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1)
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    # Best-effort category count: fall back to the data when n_categories is
    # None or not convertible to int.
    try:
        n_colors = int(n_categories)
    except (TypeError, ValueError):
        n_colors = np.unique(c).shape[0]

    # plt.get_cmap(name, lut) resamples the colormap to n_colors entries;
    # the older plt.cm.get_cmap spelling is gone from recent Matplotlib.
    discrete_cmap = plt.get_cmap(cmap, n_colors)
    sp = ax.scatter(x, y, c=c, s=s, cmap=discrete_cmap, alpha=alpha)
    cb = fig.colorbar(sp, ticks=range(n_colors), ax=ax)
    # Explicit None/length test so array-like labels do not raise on truthiness.
    if labels is not None and len(labels) > 0:
        cb.set_ticklabels(labels)
    ax.set_aspect('auto')
    sp.set_clim(-0.5, n_colors - 0.5)  # centre each tick on its colour band

    # Coordinates of an embedding carry no meaning, so hide the axis ticks.
    ax.set_xticks([])
    ax.set_yticks([])
    if facecolor:
        ax.set_facecolor(facecolor)
    if border:
        spine_color = border.get('color', '0.7')
        spine_width = border.get('linewidth', 1)
        for side in ('top', 'right', 'bottom', 'left'):
            spine = ax.spines[side]
            spine.set_color(spine_color)
            spine.set_linewidth(spine_width)
            spine.set_visible(True)
    if title:
        ax.set_title(title)
    return fig, ax
if __name__ == "__main__":
    # Demo: project the scikit-learn digits data to 2-D with four different
    # models and draw one categorical scatter plot per model in a 2x2 grid.
    # matplotlib.pyplot is already imported as plt at the top of the file.
    from sklearn.manifold import Isomap, LocallyLinearEmbedding, TSNE
    from sklearn.decomposition import TruncatedSVD
    from sklearn.datasets import load_digits

    digits = load_digits(n_class=6)
    c = digits.target
    # Build one label per class actually present in the targets, so the label
    # count always matches the number of colorbar ticks catplot creates.
    # NOTE(review): digits.target_names appears to list all ten digits even
    # when n_class < 10 - confirm against the installed scikit-learn.
    labels = ['digit %d' % d for d in np.unique(c)]
    models = [
        ('t-SNE', TSNE(perplexity=10)),
        ('Isomap', Isomap()),
        ('PCA', TruncatedSVD(n_components=2)),
        # n_neighbors is passed by keyword: recent scikit-learn releases do
        # not accept it positionally.
        ('LLE', LocallyLinearEmbedding(n_neighbors=10, n_components=2,
                                       method='standard')),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    # axes.flat walks the 2x2 grid row by row, pairing each model with a panel.
    for (model_name, model), ax in zip(models, axes.flat):
        projection = model.fit_transform(digits.data)
        panel_title = '%s projection of %d MNIST digits' % (model_name, c.shape[0])
        catplot(projection[:, 0], projection[:, 1], c, labels, panel_title,
                fig=fig, ax=ax)
    plt.tight_layout()
    plt.savefig('mnist-projections.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment