Skip to content

Instantly share code, notes, and snippets.

@tehZevo
Created February 9, 2021 18:32
Show Gist options
  • Save tehZevo/9605f25c2af486b51ca6af4ad5575736 to your computer and use it in GitHub Desktop.
Save tehZevo/9605f25c2af486b51ca6af4ad5575736 to your computer and use it in GitHub Desktop.
Train an autoencoder on MNIST digits with a softmax latent vector, then plot the mean image of the digits grouped by argmax(latent).
from keras.layers import Dense, Flatten, InputLayer, Reshape
from keras.models import Sequential
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
# Pool the MNIST train and test splits into a single image array and scale the
# pixel values into [0, 1]. The label arrays are loaded but not used below.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x = np.concatenate((x_train, x_test), axis=0)
x = x / 255.0
# Shared hyperparameters: number of latent units and hidden-layer nonlinearity.
latent_size = 10
activation = "swish"

# Encoder: image -> flattened pixels -> two 64-unit hidden layers -> softmax
# over `latent_size` units, so every encoding is a probability vector.
encoder = Sequential(
    [Flatten(input_shape=x.shape[1:])]
    + [Dense(64, activation=activation) for _ in range(2)]
    + [Dense(latent_size, activation="softmax")]
)

# Decoder: latent vector -> two 64-unit hidden layers -> one output per pixel
# (no activation specified), reshaped back to the image shape.
decoder = Sequential(
    [
        Dense(64, activation=activation, input_shape=[latent_size]),
        Dense(64, activation=activation),
    ]
    + [Dense(np.prod(x.shape[1:])), Reshape(x.shape[1:])]
)

# Chain encoder and decoder into one trainable autoencoder; the loss is the
# mean squared error between the reconstruction and the target.
model = Sequential([encoder, decoder])
model.compile(loss="mse", optimizer="adam")
# Print the layer tables for the encoder, the decoder, and the chained model.
encoder.summary()
decoder.summary()
model.summary()

# Train as an autoencoder: the target passed to fit is the input itself.
model.fit(x=x, y=x, epochs=2)
# Encode every image and hard-assign it to the latent unit with the largest
# softmax activation.
encodings = encoder.predict(x)
argmaxes = np.argmax(encodings, axis=-1)
print(argmaxes.shape)

# One row of panels: panel i shows the mean of all images assigned to unit i.
plt.figure(figsize=(2 * latent_size, 2))
for i in range(latent_size):
    plt.subplot(1, latent_size, i + 1)
    plt.title("latent[{}]".format(i))
    # Boolean-mask selection of the images assigned to this unit. This is a
    # single vectorized pass instead of a Python-level scan over every image
    # for every unit.
    members = x[argmaxes == i]
    if len(members) == 0:
        # A softmax unit can end up never winning the argmax. Averaging an
        # empty selection yields only NaN (or a 0-d value imshow rejects), so
        # leave this panel blank instead of crashing.
        print("latent[{}] has no assigned images".format(i))
        plt.axis("off")
        continue
    mean_image = members.mean(axis=0)
    print(mean_image.shape)
    plt.imshow(mean_image)
# Show once, after every panel has been drawn into the same figure.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment