Simple autoencoder with a useful PIL-based image tiler
# MXNet Autoencoder
# based on example from SherlockLiao
# https://github.com/SherlockLiao/mxnet-gluon-tutorial/blob/master/08-AutoEncoder/simple_autoencoder.py

import operator
import bisect
import os

import numpy as np
import mxnet as mx
from mxnet import gluon
from PIL import Image

def gcd(a, b):
    while b > 0:
        a, b = b, a % b
    return a

def lcm(a, b):
    # use integer division so the result stays an int on Python 3
    return a * b // gcd(a, b)

def find_res(cnt, ratio=.5):
    # given a count of images, find an arrangement of width and height
    # that is as close to .5 in ratio as possible
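    # e.g. find_res(128) -> [16, 8], i.e. 16 images per row, 8 rows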
    vals = set([gcd(cnt, x) for x in range(1, cnt)])
    rats = []
    for width in vals:
        height = cnt / width
        rat = width / height
        rats.append((rat, width, height))
    rats = sorted(rats, key=operator.itemgetter(0))
    keys = [rat[0] for rat in rats]
    idx = bisect.bisect_left(keys, ratio)
    idx = min(len(rats) - 1, idx)
    width = max(rats[idx][1:])
    height = min(rats[idx][1:])
    return list(map(int, (width, height)))
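
# helpers to rescale an array into [0, 1] (norm_ip clips to the given bounds first)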
def norm_ip(img, low, high):
    img = np.clip(img, low, high)
    img = (img - low) / (high - low + 1e-5)
    return img

def norm_range(t, value_range=None):
    if value_range is not None:
        return norm_ip(t, value_range[0], value_range[1])
    else:
        return norm_ip(t, t.min(), t.max())

def tile_image(imgs):
    # based on code found at
    # https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format
    imgcnt = imgs.shape[0]
    # prow is the number of images per row, pcol the number of rows
    (prow, pcol) = find_res(imgcnt)
    # each image arrives flattened; recover its square side length (28 for MNIST)
    sqr = int(round(imgs.shape[1] ** .5))
    imgs = imgs.reshape(imgs.shape[0], sqr, sqr)
    tiled = []
    total = prow * pcol
    for i in range(0, total, prow):
        tiled.append(np.hstack(imgs[i:i+prow, :, :]))
    return np.vstack(tiled)
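
# quick sanity check for tile_image (illustrative only):
#   tile_image(np.random.rand(128, 28 * 28)).shape == (8 * 28, 16 * 28)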

if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')

def to_img(x):
    # map a [-1, 1] batch back to [0, 1] images shaped (N, 1, 28, 28) (not used below)
    x = 0.5 * (x + 1)
    x = x.clip(0, 1)
    x = x.reshape((x.shape[0], 1, 28, 28))
    return x

num_epochs = 100
batch_size = 128
learning_rate = 1e-3
ctx = mx.gpu()
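# assumes a GPU is available; use ctx = mx.cpu() to train on the CPU instead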

def transform(data, label):
    return (data.astype('float32') / 255 - 0.5) / 0.5, label.astype('float32')
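
# transform maps each pixel from [0, 255] to [-1, 1], matching the decoder's tanh output range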
mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
# mnist_train = gluon.data.vision.MNIST(train=True, transform=transform)
dataloader = gluon.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

class autoencoder(gluon.Block):
    def __init__(self):
        super(autoencoder, self).__init__()
        with self.name_scope():
            self.encoder = gluon.nn.Sequential('encoder_')
            with self.encoder.name_scope():
                self.encoder.add(gluon.nn.Dense(128, activation='relu'))
                self.encoder.add(gluon.nn.Dense(64, activation='relu'))
                self.encoder.add(gluon.nn.Dense(12, activation='relu'))
                self.encoder.add(gluon.nn.Dense(3))
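                # the final 3-unit layer is the bottleneck: each image is compressed to a 3-d code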
            self.decoder = gluon.nn.Sequential('decoder_')
            with self.decoder.name_scope():
                self.decoder.add(gluon.nn.Dense(12, activation='relu'))
                self.decoder.add(gluon.nn.Dense(64, activation='relu'))
                self.decoder.add(gluon.nn.Dense(128, activation='relu'))
                self.decoder.add(gluon.nn.Dense(28 * 28, activation='tanh'))

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
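
# illustrative usage of the encoder on its own (not part of training):
#   net = autoencoder()
#   net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
#   code = net.encoder(mx.nd.zeros((1, 28 * 28), ctx=ctx))   # shape (1, 3)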

def autoencode():
    model = autoencoder()
    model.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
    criterion = gluon.loss.L2Loss()
    optimizer = gluon.Trainer(model.collect_params(), 'adam',
                              {'learning_rate': learning_rate,
                               'wd': 1e-5})
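    # 'wd' applies L2 weight decay to all parameters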
    for epoch in range(num_epochs):
        running_loss = 0.0
        n_total = 0.0
        for data in dataloader:
            img, _ = data
            img = img.reshape((img.shape[0], -1)).as_in_context(ctx)
            with mx.autograd.record():
                output = model(img)
                loss = criterion(output, img)
            loss.backward()
            optimizer.step(img.shape[0])
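            # step(batch_size) rescales the accumulated gradients by 1/batch_size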
            running_loss += mx.nd.sum(loss).asscalar()
            n_total += img.shape[0]
        # ===================log========================
        print('epoch [{}/{}], loss:{:.4f}'
              .format(epoch + 1, num_epochs, running_loss / n_total))
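        # every 10 epochs, save a tiled PNG preview of the last batch's reconstructions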
        if epoch % 10 == 0:
            fn = './mlp_img/{}_autoencoder.png'.format(epoch)
            im = output.asnumpy()
            (minv, maxv) = (np.min(im), np.max(im))
            # rescale reconstructions to [0, 255] before converting to uint8
            im = (norm_ip(im, minv, maxv) * 255.0).astype(np.uint8)
            im = tile_image(im)
            # save the tiled grid at 2x size as a grayscale PNG
            sz = (im.shape[1] * 2, im.shape[0] * 2)
            im = Image.fromarray(im, mode='L').resize(sz)
            im.save(fn)
    model.save_params('./handwritten-digits-autoencoder.params')

if __name__ == "__main__":
    autoencode()