Last active
December 6, 2018 12:54
-
-
Save vadimkantorov/1ff9ceb82c0d06365f0a21d44e870d98 to your computer and use it in GitHub Desktop.
Example of interactive SVG scatter plot (with image thumbnails) produced by running t-SNE on MNIST
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64
import html
import random

import cv2
import torch
import torchvision
def svg(points, labels, thumbnails = None, legend_size = 1e-1, legend_font_size = 5e-2, circle_radius = 5e-3):
    """Render a 2-D scatter plot as a self-contained, interactive SVG string.

    Every point becomes a ``<circle>`` coloured by its class. Hovering a
    circle copies its thumbnail (stored as a JPEG data URI in the ``desc``
    attribute) into the ``preview`` image and its label into the ``label``
    text element.

    Args:
        points: torch tensor indexed as ``points[:, 0]`` / ``points[:, 1]``,
            i.e. one row per point with x and y in the first two columns.
        labels: sequence with one hashable, sortable label per point.
        thumbnails: optional sequence with one image tensor per point. Each
            image is scaled by 255 and permuted with ``(1, 2, 0)`` before
            JPEG encoding -- assumes channel-first tensors with values in
            [0, 1]; confirm against the caller. ``None`` produces a plot
            without previews.
        legend_size: side of the preview image, in viewBox units.
        legend_font_size: font size of the label text, in viewBox units.
        circle_radius: radius of every circle, in viewBox units.

    Returns:
        The SVG document as a ``str``.
    """
    # Rescale each coordinate into [0, 1] so the plot fills the 1x1 viewBox.
    lowest = points.min(0)[0]
    extent = points.max(0)[0] - lowest
    # A constant coordinate has zero extent; divide by 1 there instead of
    # producing NaN positions.
    extent = extent.masked_fill(extent == 0, 1)
    points = (points - lowest) / extent

    # Spread the class hues evenly over the HSL colour wheel.
    class_index = sorted(set(labels))
    hue_by_class = {label: 360.0 * i / len(class_index) for i, label in enumerate(class_index)}
    colors = [hue_by_class[label] for label in labels]

    if thumbnails is None:
        thumbnails_base64 = [''] * len(points)
    else:
        # b64encode returns bytes; without decoding, Python 3 would format it
        # as "b'...'" and corrupt the data URI, so no thumbnail would show.
        # The [..., ::-1] reverses the channel axis (presumably RGB -> BGR,
        # the order OpenCV works with).
        thumbnails_base64 = [
            base64.b64encode(cv2.imencode('.jpg', img.mul(255).permute(1, 2, 0).numpy()[..., ::-1])[1]).decode('ascii')
            for img in thumbnails
        ]

    # Labels end up inside a double-quoted XML attribute, so escape them.
    titles = [html.escape(str(label), quote = True) for label in labels]

    circle = '''<circle cx="{}" cy="{}" title="{}" fill="hsl({}, 50%, 50%)" r="{}" desc="data:image/jpeg;base64,{}" onmouseover="evt.target.ownerDocument.getElementById('preview').setAttribute('href', evt.target.getAttribute('desc')); evt.target.ownerDocument.getElementById('label').textContent = evt.target.getAttribute('title');" />'''
    # tolist() yields plain Python floats, so the coordinates are formatted
    # as numbers regardless of how the tensor type formats itself.
    circles = ''.join(map(
        circle.format,
        points[:, 0].tolist(),
        points[:, 1].tolist(),
        titles,
        colors,
        [circle_radius] * len(points),
        thumbnails_base64,
    ))

    legend = '\n'.join([
        '<image id="preview" x="0" y="{legend_size}" width="{legend_size}" height="{legend_size}" />',
        '<text id="label" x="0" y="{legend_size}" font-size="{legend_font_size}" />',
        '</svg>',
    ]).format(legend_size = legend_size, legend_font_size = legend_font_size)

    return '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1">' + circles + legend
#def scatter(points, labels, thumbnails = None, legend_size = 1e-1, legend_font_size = 5e-2, circle_radius = 5e-3, size = 512): | |
# points = (points - points.min(0)[0]) / (points.max(0)[0] - points.min(0)[0]) | |
# class_index = sorted(set(labels)) | |
# class_colors = [360.0 * i / len(class_index) for i in range(len(class_index))] | |
# colors = [class_colors[class_index.index(label)] for label in labels] | |
# thumbnails_base64 = [base64.b64encode(cv2.imencode('.jpg', np.asarray(img)[..., ::-1])[1]).decode('ascii') for img in thumbnails] if thumbnails is not None else [''] * len(points) | |
# return '<svg viewBox="0 0 1 1" width="{0}" height="{0}">'.format(size) + \ | |
# ''.join(map('''<circle cx="{}" cy="{}" title="{}" fill="hsl({}, 50%, 50%)" r="{}" desc="data:image/jpeg;base64,{}" onmouseover="evt.target.ownerDocument.getElementById('preview').setAttribute('href', evt.target.getAttribute('desc')); evt.target.ownerDocument.getElementById('label').textContent = evt.target.getAttribute('title');" />'''.format, points[:, 0], points[:, 1], labels, colors, [circle_radius] * len(points), thumbnails_base64)) + \ | |
# '''<image id="preview" x="0" y="{legend_size}" width="{legend_size}" height="{legend_size}" /> | |
# <text id="label" x="0" y="{legend_size}" font-size="{legend_font_size}" /> | |
# </svg>'''.format(legend_size = legend_size, legend_font_size = legend_font_size) | |
def tsne(features):
    """Embed feature vectors with scikit-learn's t-SNE.

    Args:
        features: torch tensor of feature vectors; it is converted with
            ``.numpy()``, so it has to be a CPU tensor that does not
            require grad.

    Returns:
        A torch tensor wrapping the array returned by
        ``sklearn.manifold.TSNE().fit_transform`` (default TSNE settings).
    """
    # TODO: reimplement following https://github.com/cemoody/topicsne/blob/master/tsne.py and https://nlml.github.io/in-raw-numpy/in-raw-numpy-t-sne/
    # Imported lazily so scikit-learn is only required when t-SNE is run.
    import sklearn.manifold
    # The parameter used to be misspelled ("featuers"), which made this body
    # silently read the module-level ``features`` instead of its argument.
    embedding = sklearn.manifold.TSNE().fit_transform(features.numpy())
    return torch.from_numpy(embedding)
# Load the MNIST training split (downloaded into ./MNIST if missing); the
# ToTensor transform turns every image into a torch tensor.
dataset = torchvision.datasets.MNIST('./MNIST', train = True, download = True, transform = torchvision.transforms.ToTensor())
thumbnails, labels = zip(*dataset)
# Flatten every image into one feature row.
features = torch.cat(thumbnails).view(len(thumbnails), -1)
# random.sample needs a sequence: on Python 3 zip() returns an iterator, so it
# is materialised first. The sample size is capped at the population size.
triples = list(zip(features, labels, thumbnails))
features, labels, thumbnails = zip(*random.sample(triples, k = min(10000, len(triples))))
features = torch.stack(features)
embeddings = tsne(features)
# The context manager closes (and therefore flushes) the output file.
with open('tsne.svg', 'w') as svg_file:
    svg_file.write(svg(embeddings, labels, thumbnails))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, I tried running this code and the images are not displayed when the svg file is opened in the browser. I have attached a snapshot here.
