Created: February 9, 2021 18:32
-
-
Save tehZevo/9605f25c2af486b51ca6af4ad5575736 to your computer and use it in GitHub Desktop.
Train an autoencoder on MNIST digits, but with a softmax latent vector. Then plot the means of digits by argmax(latent)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Dense, Flatten, InputLayer, Reshape | |
from keras.models import Sequential | |
from keras.datasets import mnist | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# Load MNIST and pool the train and test splits into one dataset,
# scaling pixel intensities from [0, 255] to [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x = np.concatenate((x_train, x_test), axis=0) / 255.0

# Hidden-layer activation and the number of softmax latent slots.
activation = "swish"
latent_size = 10
# Encoder: flatten the 28x28 image, pass through two hidden layers,
# and emit a softmax distribution over `latent_size` slots.
encoder = Sequential()
encoder.add(InputLayer(input_shape=x.shape[1:]))
encoder.add(Flatten())
encoder.add(Dense(64, activation=activation))
encoder.add(Dense(64, activation=activation))
encoder.add(Dense(latent_size, activation="softmax"))

# Decoder: map the latent distribution back to a full image.
decoder = Sequential()
decoder.add(InputLayer(input_shape=(latent_size,)))
decoder.add(Dense(64, activation=activation))
decoder.add(Dense(64, activation=activation))
# Linear output layer: one unit per pixel, reshaped back to image form.
decoder.add(Dense(np.prod(x.shape[1:])))
decoder.add(Reshape(x.shape[1:]))
# Compose the full autoencoder: pixels -> softmax code -> pixels.
model = Sequential([encoder, decoder])
model.compile(optimizer="adam", loss="mse")

# Print the architecture of each sub-network and the combined model.
for net in (encoder, decoder, model):
    net.summary()

# Train as an autoencoder: each input is its own reconstruction target.
model.fit(x, x, epochs=2)
# Encode every image into its softmax latent vector.
encodings = encoder.predict(x)
# Hard-assign each image to the latent slot with the highest probability.
argmaxes = np.argmax(encodings, axis=-1)
print(argmaxes.shape)

# One subplot per latent slot, showing the mean image of its members.
plt.figure(figsize=(2 * latent_size, 2))
for i in range(latent_size):
    plt.subplot(1, latent_size, i + 1)
    # Vectorized boolean-mask selection of the images assigned to slot i
    # (replaces the original O(n) Python-level zip/list comprehension).
    members = x[argmaxes == i]
    if len(members) > 0:
        mean_img = members.mean(axis=0)
    else:
        # Guard: an unused latent slot would otherwise make np.mean emit a
        # RuntimeWarning and return NaN, breaking imshow. Show a blank image.
        mean_img = np.zeros(x.shape[1:])
    print(mean_img.shape)
    plt.title("latent[{}]".format(i))
    plt.imshow(mean_img)
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.