Last active
November 21, 2017 23:00
-
-
Save krishpop/dda70498d31b46efdb78f495a6754ab8 to your computer and use it in GitHub Desktop.
TF Tutorial
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DataSet(object):
    """In-memory data set that serves (optionally shuffled) mini-batches.

    The arrays are passed as keyword arguments and stored as attributes
    under the same names:

    * ``_data`` -- required; samples are indexed along the first axis.
    * ``_labels`` -- required when ``labeled`` is true; must have the same
      first-axis length as ``_data``.
    * ``_test_data`` / ``_test_labels`` -- optional; only exposed through
      the ``test_data`` / ``test_labels`` properties and counted by ``len``.

    Deriving from ``object`` keeps the property setters working on
    Python 2, where a bare ``class DataSet:`` is an old-style class.
    """

    # Iteration state written by ``save`` that must not be fed back into the
    # constructor by ``load``; ``__init__`` recomputes all of it.
    _BOOKKEEPING_KEYS = ('_num_samples', '_index_in_epoch',
                         '_epochs_trained', '_batch_number')

    def __init__(self, shuffle=True, labeled=True, **data_dict):
        """Store the arrays and reset the iteration state.

        Args:
            shuffle: if true, shuffle the training samples now and again
                every time an epoch is exhausted.
            labeled: if true, ``_labels`` must be supplied and
                ``next_batch`` returns ``(data, labels)`` tuples.
            **data_dict: arrays to attach to the instance (see class
                docstring for the recognised keys).

        Raises:
            AssertionError: if ``_data`` is missing, or if ``labeled`` is
                true and ``_labels`` is missing or has a different number
                of samples than ``_data``.
        """
        assert '_data' in data_dict, "'_data' is required"
        if labeled:
            assert '_labels' in data_dict, \
                "'_labels' is required when labeled=True"
            assert (data_dict['_data'].shape[0]
                    == data_dict['_labels'].shape[0]), \
                "'_data' and '_labels' must have the same number of samples"
        self._labeled = labeled
        self._shuffle = shuffle
        self.__dict__.update(data_dict)
        # Flags may arrive as 0-d NumPy arrays (e.g. from an .npz archive);
        # normalise them to plain bools.
        self._labeled = bool(self._labeled)
        self._shuffle = bool(self._shuffle)
        self._num_samples = self._data.shape[0]
        self._index_in_epoch = 0
        self._epochs_trained = 0
        self._batch_number = 0
        if self._shuffle:
            self._shuffle_data()

    def __len__(self):
        """Return the number of training samples plus test samples, if any."""
        total = len(self._data)
        # The test split is optional, so only count it when it was supplied.
        test_data = getattr(self, '_test_data', None)
        if test_data is not None:
            total += len(test_data)
        return total

    @property
    def epochs_trained(self):
        """Number of times ``next_batch`` has wrapped around the data."""
        return self._epochs_trained

    @epochs_trained.setter
    def epochs_trained(self, new_epochs_trained):
        self._epochs_trained = new_epochs_trained

    @property
    def batch_number(self):
        """Number of batches served so far in the current epoch."""
        return self._batch_number

    @property
    def index_in_epoch(self):
        """Index of the first sample the next batch will start from."""
        return self._index_in_epoch

    @property
    def num_samples(self):
        """Number of training samples (first-axis length of ``_data``)."""
        return self._num_samples

    @property
    def data(self):
        """Training samples, in their current (possibly shuffled) order."""
        return self._data

    @property
    def labels(self):
        """Training labels; only available if ``_labels`` was supplied."""
        return self._labels

    @property
    def labeled(self):
        """Whether batches include labels."""
        return self._labeled

    @property
    def test_data(self):
        """Test samples; only available if ``_test_data`` was supplied."""
        return self._test_data

    @property
    def test_labels(self):
        """Test labels; only available if ``_test_labels`` was supplied."""
        return self._test_labels

    @classmethod
    def load(cls, filename):
        """Rebuild a data set from an archive written by ``save``.

        Iteration progress is not restored: the returned instance starts
        at epoch 0, and is reshuffled if shuffling was enabled.
        """
        # NpzFile keeps the file open and loads lazily, so read every array
        # before the handle is closed.
        with np.load(filename) as archive:
            data_dict = {key: archive[key] for key in archive.files}
        labeled = bool(data_dict.pop('_labeled', '_labels' in data_dict))
        shuffle = bool(data_dict.pop('_shuffle', True))
        for key in cls._BOOKKEEPING_KEYS:
            data_dict.pop(key, None)
        return cls(shuffle=shuffle, labeled=labeled, **data_dict)

    def save(self, filename):
        """Write every instance attribute to a compressed ``.npz`` archive."""
        np.savez_compressed(filename, **dict(self.__dict__))

    def _shuffle_data(self):
        """Shuffle the training samples in place, keeping labels aligned."""
        shuffled_idx = np.arange(self._num_samples)
        np.random.shuffle(shuffled_idx)
        self._data = self._data[shuffled_idx]
        if self._labeled:
            # Same permutation as the data so sample/label pairs stay matched.
            self._labels = self._labels[shuffled_idx]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` training samples.

        When fewer than ``batch_size`` samples remain in the current epoch,
        the batch is the remaining tail followed by the first samples of the
        next epoch (after reshuffling, if shuffling is enabled).

        Returns:
            ``(data_batch, labels_batch)`` if the set is labeled, otherwise
            just ``data_batch``.

        Raises:
            AssertionError: if ``batch_size`` exceeds the number of samples.
        """
        assert batch_size <= self._num_samples, \
            'batch_size must not exceed the number of samples'
        start = self._index_in_epoch
        labels_batch = None
        if start + batch_size > self._num_samples:
            # Epoch exhausted: take the tail, then top up from the next epoch.
            self._epochs_trained += 1
            self._batch_number = 0
            data_batch = self._data[start:]
            if self._labeled:
                labels_batch = self._labels[start:]
            remaining = batch_size - (self._num_samples - start)
            if self._shuffle:
                self._shuffle_data()
            data_batch = np.concatenate(
                [data_batch, self._data[:remaining]], axis=0)
            if self._labeled:
                labels_batch = np.concatenate(
                    [labels_batch, self._labels[:remaining]], axis=0)
            self._index_in_epoch = remaining
        else:
            stop = start + batch_size
            data_batch = self._data[start:stop]
            if self._labeled:
                labels_batch = self._labels[start:stop]
            self._index_in_epoch = stop
        self._batch_number += 1
        if self._labeled:
            return data_batch, labels_batch
        return data_batch
""" | |
## example
import pandas as pd
from sklearn.model_selection import train_test_split
filename = 'filename.csv'
csv_data = pd.read_csv(filename)
train, test = train_test_split(csv_data, train_size=.9)
# DataSet selects rows with integer index arrays, so pass NumPy arrays
# (DataFrame[...] with an integer array would select columns instead).
data_dict = {'_data': train.values, '_test_data': test.values}
data = DataSet(labeled=False, **data_dict)
num_steps = 1000
for _ in range(num_steps):
    batch = data.next_batch(100)
    # do something with batch
""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n", | |
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", | |
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", | |
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n" | |
] | |
} | |
], | |
"source": [ | |
"%matplotlib inline\n", | |
"\n", | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"from tensorflow.examples.tutorials.mnist import input_data\n", | |
"\n", | |
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.argmax(mnist.train.labels[1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(784,)" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mnist.train.images[1].shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.image.AxesImage at 0x11cb3fa50>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAADhFJREFUeJzt3V2MVPUZx/HfU9Eb9EJZuhLFxRqD\nUS/QrKYXSDRWFGMC3BhfYmiqrDGaFO1F8SXWBEXTVCvcoGskYuNbA2wkBquWNECThvBmfdkFtQYF\ngiyIiRovrO7Tizk0q+75n2HmzJxZnu8n2ezMeebMPB73x5kz/znnb+4uAPH8rOoGAFSD8ANBEX4g\nKMIPBEX4gaAIPxAU4QeCIvxAUIQfCGpCO1/MzPg6IdBi7m71PK6pPb+ZXWNmu83sIzNb3MxzAWgv\na/S7/WZ2gqQPJF0laZ+krZJudPfBxDrs+YEWa8ee/1JJH7n7x+7+raSXJc1t4vkAtFEz4T9D0t5R\n9/dly37AzPrMbJuZbWvitQCUrOUf+Ll7v6R+ibf9QCdpZs+/X9LUUffPzJYBGAeaCf9WSeea2dlm\ndpKkGyStK6ctAK3W8Nt+d//OzO6S9IakEyStdPf3S+sMQEs1PNTX0ItxzA+0XFu+5ANg/CL8QFCE\nHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQ\nhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIan6JYkM9sj6StJ30v6zt17\ny2gK7dPT05Os33bbbcn6/fffn6ynZoE2S08mOzQ0lKw/8MADyfrAwECyHl1T4c9c4e6HS3geAG3E\n234gqGbD75LeNLPtZtZXRkMA2qPZt/0z3X2/mf1c0ltmtsvdN41+QPaPAv8wAB2mqT2/u+/Pfg9L\nGpB06RiP6Xf3Xj4MBDpLw+E3s4lmdsrR25JmS3qvrMYAtFYzb/u7JQ1kwzUTJL3o7n8rpSsALWep\ncdjSX8ysfS8WyOTJk3Nr9957b3Ldm2++OVmfNGlSsl40Vt/MOH/R3+bevXuT9UsuuSS3dvjw8Ts6\n7e7pDZthqA8IivADQRF+ICjCDwRF+IGgCD8QFEN940DRabNLlizJrRX9/231cNuhQ4eS9ZSurq5k\nfdq0acn64OBgbu2CCy5opKVxgaE+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4/zjwNatW5P1iy++\nOLfW7Dh/aqxckq644opkvZlTZ2fOnJmsb9y4MVlP/bdPmFDGhas7E+P8AJIIPxAU4QeCIvxAUIQf\nCIrwA0ERfiAoxvk7wHnnnZesF43zf/7557m1ovPpi8bh77777mR90aJFyfrSpUtza59++mly3SJF\nf7sjIyO5tTvuuCO5bn9/f0M9dQLG+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIXj/Ga2UtJ1kobd\n/cJs2WmSXpE0TdIeSde7+xeFL8Y4f0OKvgeQGqtvdirqvr6+ZH3FihXJemqa7B07diTXnT9/frK+\nevXqZD31t3366acn1x3PU3iXOc7/nKRrfrRssaQN7n6upA3ZfQDjSGH43X2TpCM/WjxX0qrs9ipJ\n80ruC0CLNXrM3+3uB7Lbn0nqLqkfAG3S9IXM3N1Tx/Jm1icpfeAIoO0a3fMfNLMpkpT9Hs57oLv3\nu3uvu/c2+FoAWqDR8K+TtCC7vUDSq+W0A6BdCsNvZi9J+pek6Wa2z8xulfSYpKvM7ENJv8ruAxhH\nCo/53f3GnNKVJfeCHLt27arstYuuB7B79+5kPXWtgaJrBSxenB5BLppzoJXffzge8A0/ICjCDwRF\n+IGgCD8QFOEHgiL8QFDH7zzFgcyaNSu3VnQ6
cNFQ3tDQULI+ffr0ZH3Lli25tcmTJyfXLTrdvKj3\nOXPmJOvRsecHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY5z8O3HTTTbm1hQsXJtctOi22jku7J+up\nsfxmTsmVpOXLlyfrRZcGj449PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTj/ca5onL7K9Tdv3pxc\n95577knWGcdvDnt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiqcJzfzFZKuk7SsLtfmC17SNJCSUcv\nnH6fu69vVZNIe/HFF3NrPT09yXW7urqS9aLr/k+cODFZT3nwwQeTdcbxW6uePf9zkq4ZY/mf3X1G\n9kPwgXGmMPzuvknSkTb0AqCNmjnmv8vM3jGzlWZ2amkdAWiLRsO/QtI5kmZIOiDp8bwHmlmfmW0z\ns20NvhaAFmgo/O5+0N2/d/cRSc9IujTx2H5373X33kabBFC+hsJvZlNG3Z0v6b1y2gHQLvUM9b0k\n6XJJXWa2T9IfJF1uZjMkuaQ9km5vYY8AWsCaPV/7mF7MrH0vhlIUjfM//PDDyfq8efNyazt37kyu\nO2fOnGS96Lr+Ubl7ekKEDN/wA4Ii/EBQhB8IivADQRF+ICjCDwTFUF+dUlNNHzp0KLcW3euvv55b\nu/rqq5PrFl26+8knn2yop+MdQ30Akgg/EBThB4Ii/EBQhB8IivADQRF+ICim6M7MmjUrWX/88dwr\nlWnXrl3JdW+55ZaGejoePPLII7m12bNnJ9edPn162e1gFPb8QFCEHwiK8ANBEX4gKMIPBEX4gaAI\nPxBUmHH+1Pn4kvTUU08l68PDw7m1yOP4RVN0P/3007k1s7pOO0eLsOcHgiL8QFCEHwiK8ANBEX4g\nKMIPBEX4gaAKx/nNbKqk5yV1S3JJ/e6+zMxOk/SKpGmS9ki63t2/aF2rzZk/f36yXnTu+MaNG8ts\nZ9womqJ7zZo1yXpquxbNGVF0nQQ0p549/3eSfufu50v6paQ7zex8SYslbXD3cyVtyO4DGCcKw+/u\nB9x9R3b7K0lDks6QNFfSquxhqyTNa1WTAMp3TMf8ZjZN0kWStkjqdvcDWekz1Q4LAIwTdX+338xO\nlrRG0iJ3/3L097Ld3fPm4TOzPkl9zTYKoFx17fnN7ETVgv+Cu6/NFh80sylZfYqkMc98cfd+d+91\n994yGgZQjsLwW20X/6ykIXd/YlRpnaQF2e0Fkl4tvz0ArVI4RbeZzZS0WdK7kkayxfepdtz/V0ln\nSfpEtaG+IwXPVdkU3UVDVkNDQ8n64OBgbu3RRx9t6rm3b9+erBfp6enJrV122WXJdYuGQOfNS3+O\nW3Raburva9myZcl1i6boxtjqnaK78Jjf3f8pKe/JrjyWpgB0Dr7hBwRF+IGgCD8QFOEHgiL8QFCE\nHwiqcJy/1BercJy/yOrVq5P11Hh3M2PdkrRz585kvchZZ52VW5s0aVJy3WZ7L1o/NUX38uXLk+se\nPnw4WcfY6h3nZ88PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzp8pmsJ7/fr1ubXe3vRFikZGRpL1\nVo61F637zTffJOtFl89eunRpsj4wMJCso3yM8wNIIvxAUIQfCIrwA0ERfiAowg8ERfiBoBjnr1NX\nV1dubcmSJU09d19fejaztWvXJuvNnPdedO18pskefxjnB5BE+IGgCD8QFOEHgiL8QFCEHwiK8ANB\nFY7zm9lUSc9L6pbkkvrdfZmZPSRpoaRD2UPvc/f8k941vsf5gfGi3nH+esI/RdIUd99hZqdI2i5p\nnqTrJX3t7n+qtynCD7ReveGfUMcTHZB0ILv9lZkNSTqjufYAVO2YjvnNbJqkiyRtyRbdZWbvmNlK\nMzs1Z50+M9tmZtua6hRAqer+br+ZnSxpo6RH3H2tmXVLOqza5wBLVDs0+E3Bc/C2H2ix0o75JcnM\nTpT0mqQ3
3P2JMerTJL3m7hcWPA/hB1qstBN7rHZp2GclDY0OfvZB4FHzJb13rE0CqE49n/bPlLRZ\n0ruSjl6D+j5JN0qaodrb/j2Sbs8+HEw9F3t+oMVKfdtfFsIPtB7n8wNIIvxAUIQfCIrwA0ERfiAo\nwg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRVeAHPkh2W9Mmo+13Zsk7Uqb11al8SvTWq\nzN566n1gW8/n/8mLm21z997KGkjo1N46tS+J3hpVVW+87QeCIvxAUFWHv7/i10/p1N46tS+J3hpV\nSW+VHvMDqE7Ve34AFakk/GZ2jZntNrOPzGxxFT3kMbM9Zvaumb1d9RRj2TRow2b23qhlp5nZW2b2\nYfZ7zGnSKurtITPbn227t83s2op6m2pm/zCzQTN738x+my2vdNsl+qpku7X9bb+ZnSDpA0lXSdon\naaukG919sK2N5DCzPZJ63b3yMWEzmyXpa0nPH50Nycz+KOmIuz+W/cN5qrv/vkN6e0jHOHNzi3rL\nm1n616pw25U543UZqtjzXyrpI3f/2N2/lfSypLkV9NHx3H2TpCM/WjxX0qrs9irV/njaLqe3juDu\nB9x9R3b7K0lHZ5audNsl+qpEFeE/Q9LeUff3qbOm/HZJb5rZdjPrq7qZMXSPmhnpM0ndVTYzhsKZ\nm9vpRzNLd8y2a2TG67Lxgd9PzXT3iyXNkXRn9va2I3ntmK2ThmtWSDpHtWncDkh6vMpmspml10ha\n5O5fjq5Vue3G6KuS7VZF+PdLmjrq/pnZso7g7vuz38OSBlQ7TOkkB49Okpr9Hq64n/9z94Pu/r27\nj0h6RhVuu2xm6TWSXnD3tdniyrfdWH1Vtd2qCP9WSeea2dlmdpKkGyStq6CPnzCzidkHMTKziZJm\nq/NmH14naUF2e4GkVyvs5Qc6ZebmvJmlVfG267gZr9297T+SrlXtE///SLq/ih5y+vqFpH9nP+9X\n3Zukl1R7G/hf1T4buVXSJEkbJH0o6e+STuug3v6i2mzO76gWtCkV9TZTtbf070h6O/u5tuptl+ir\nku3GN/yAoPjADwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUP8DUODl2qszuRAAAAAASUVORK5C\nYII=\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x11ddd3510>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(mnist.train.images[1].reshape((28,28)), cmap=plt.cm.gray)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"$$y = xW + b$$\n", | |
"\n", | |
"$$ \\mathcal{L}(\\dot{y}, y) = -\\sum \\left[ y \\log(\\dot{y}) + (1 - y) \\log(1 - \\dot{y}) \\right] $$"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(10000, 784)" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mnist.test.images.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"sess.close()\n", | |
"tf.reset_default_graph()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n", | |
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", | |
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", | |
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n" | |
] | |
} | |
], | |
"source": [ | |
"sess = tf.InteractiveSession()\n", | |
"\n", | |
"np.random.seed(52)\n", | |
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n", | |
"\n", | |
"x = tf.placeholder(shape=[None,784], dtype=tf.float32)\n", | |
"y_ = tf.placeholder(shape=[None,10], dtype=tf.float32)\n", | |
"\n", | |
"w = tf.get_variable('w', shape=[784,10], dtype=tf.float32, initializer=tf.zeros_initializer())\n", | |
"b = tf.get_variable('b', shape=[10], dtype=tf.float32, initializer=tf.zeros_initializer())\n", | |
"\n", | |
"y = tf.add(tf.matmul(x, w, name='wx'), b, name='y')\n", | |
"\n", | |
"loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_, name='loss'), name='avg_loss')\n", | |
"optimizer = tf.train.GradientDescentOptimizer(0.5)\n", | |
"opt_op = optimizer.minimize(loss)\n", | |
"\n", | |
"tf.global_variables_initializer().run()\n", | |
"\n", | |
"batch_size = 100\n", | |
"\n", | |
"for step in range(1000):\n", | |
" batch_x, batch_y = mnist.train.next_batch(100)\n", | |
" _ = sess.run(opt_op, feed_dict={x: batch_x, y_: batch_y})\n", | |
"\n", | |
"predictions = sess.run(y, feed_dict={x: mnist.test.images})\n", | |
"\n", | |
"actual_predictions = np.argmax(predictions, axis=1)\n", | |
"nsamples, ndims = mnist.test.labels.shape\n", | |
"acc = np.sum(np.equal(actual_predictions, np.argmax(mnist.test.labels, axis=1))) / float(nsamples)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.91990000000000005" | |
] | |
}, | |
"execution_count": 64, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"acc" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Important TF Guides:
MNIST Beginners Tutorial