Last active
November 5, 2024 16:40
-
-
Save pbloem/bd8348d58251872d9ca10de4816945e4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -- assignment 1 -- | |
import numpy as np | |
from urllib import request | |
import gzip | |
import pickle | |
import os | |
def load_synth(num_train=60_000, num_val=10_000, seed=0):
    """
    Load some very basic synthetic data that should be easy to classify. Two features, so that we can plot the
    decision boundary (which is an ellipse in the feature space).

    :param num_train: Number of training instances (must be non-negative).
    :param num_val: Number of test/validation instances (must be non-negative).
    :param seed: Seed passed to ``np.random.seed`` before sampling. Note that this seeds numpy's *global* random
      state, so any later ``np.random`` calls are affected as well.
    :return: Two tuples and an integer: (xtrain, ytrain), (xval, yval), num_cls. The first contains a matrix of training
      data with 2 features as a numpy floating point array, and the corresponding classification labels as a numpy
      integer array. The second contains the test/validation data in the same format. The last integer contains the
      number of classes (this is always 2 for this function).
    :raises ValueError: If ``num_train`` or ``num_val`` is negative.
    """
    # Negative sizes would not fail loudly: the slices below would just silently produce
    # wrongly sized splits (e.g. x[:-5]), so reject them up front.
    if num_train < 0 or num_val < 0:
        raise ValueError(f"num_train and num_val must be non-negative, got {num_train} and {num_val}")

    np.random.seed(seed)

    THRESHOLD = 0.6
    # Its symmetric part is positive definite, so the level set q == THRESHOLD is an ellipse.
    quad = np.asarray([[1, -0.05], [1, .4]])

    ntotal = num_train + num_val
    x = np.random.randn(ntotal, 2)

    # Compute the quadratic form x^T Q x for every instance (one scalar per row).
    q = np.einsum('bf, fk, bk -> b', x, quad, x)
    y = (q > THRESHOLD).astype(int)

    return (x[:num_train, :], y[:num_train]), (x[num_train:, :], y[num_train:]), 2
def load_mnist(final=False, flatten=True):
    """
    Load the MNIST data, downloading and caching it in ``mnist.pkl`` on first use.

    :param final: If true, return the canonical train/test split. If false, hold out the last 5000 training
      instances as validation data and keep the test data hidden.
    :param flatten: If true, each instance is flattened into a vector, so that the data is returned as a matrix with
      784 (= 28 * 28) columns. If false, the data is returned as a 3-tensor of shape (n, 28, 28), preserving each
      image as a matrix.
    :return: Two tuples and an integer: (xtrain, ytrain), (xval, yval), num_cls. The first contains the training
      data and the corresponding classification labels as a numpy integer array. The second contains the
      test/validation data in the same format. The last integer contains the number of classes (this is always 10
      for this function).
    """
    num_val = 5000  # size of the validation split carved off the training data when final=False
    side = 28       # MNIST images are side x side pixels

    if not os.path.isfile('mnist.pkl'):
        init()  # download the raw files and build the local pickle cache
    xtrain, ytrain, xtest, ytest = load()
    ntrain, ntest = xtrain.shape[0], xtest.shape[0]

    if flatten:
        xtrain = xtrain.reshape(ntrain, -1)
        xtest = xtest.reshape(ntest, -1)
    else:
        # save_mnist caches the images flattened (one row of 28*28 pixels per image), so without this
        # reshape flatten=False would silently return a matrix instead of the documented 3-tensor.
        xtrain = xtrain.reshape(ntrain, side, side)
        xtest = xtest.reshape(ntest, side, side)

    if not final:
        # Validation data comes from the end of the training set; the test set stays hidden.
        return (xtrain[:-num_val], ytrain[:-num_val]), (xtrain[-num_val:], ytrain[-num_val:]), 10

    return (xtrain, ytrain), (xtest, ytest), 10
# Numpy-only MNIST loader. Courtesy of Hyeonseok Jung
# https://github.com/hsjeong5/MNIST-for-Numpy

# Each entry pairs the key stored in mnist.pkl with the gzipped IDX file it is read from.
# Order matters: the two image files come first and the two label files last, because
# save_mnist() treats filename[:2] as images and filename[-2:] as labels.
filename = [
    ["training_images", "train-images-idx3-ubyte.gz"],
    ["test_images", "t10k-images-idx3-ubyte.gz"],
    ["training_labels", "train-labels-idx1-ubyte.gz"],
    ["test_labels", "t10k-labels-idx1-ubyte.gz"],
]
def download_mnist():
    """
    Fetch every gzipped MNIST file listed in ``filename`` into the current working directory.

    Each file is downloaded on every call and written under its own name, replacing any existing copy.
    """
    base_url = "https://peterbloem.nl/files/mnist/"  # previously: "http://yann.lecun.com/exdb/mnist/"
    for _key, gz_name in filename:
        print(f"Downloading {gz_name}...")
        request.urlretrieve(f"{base_url}{gz_name}", gz_name)
    print("Download complete.")
def save_mnist():
    """
    Parse the gzipped MNIST files named in ``filename`` and cache them as a dict of uint8 arrays in ``mnist.pkl``.

    The first two entries of ``filename`` are image files and the last two label files. Image files start with a
    16-byte header and label files with an 8-byte header; those bytes are skipped. Images are stored flattened,
    one row of 28 * 28 pixels per image.
    """
    def read_payload(gz_path, header_bytes):
        # Decompress the whole file and view everything after the header as unsigned bytes.
        with gzip.open(gz_path, 'rb') as fh:
            return np.frombuffer(fh.read(), np.uint8, offset=header_bytes)

    mnist = {}
    for key, gz_name in filename[:2]:
        mnist[key] = read_payload(gz_name, 16).reshape(-1, 28 * 28)
    for key, gz_name in filename[-2:]:
        mnist[key] = read_payload(gz_name, 8)

    with open("mnist.pkl", 'wb') as out:
        pickle.dump(mnist, out)
    print("Save complete.")
def init():
    """Download the raw MNIST files, then convert them into the local ``mnist.pkl`` cache."""
    for step in (download_mnist, save_mnist):
        step()
def load():
    """
    Read the cached MNIST arrays from ``mnist.pkl`` in the current working directory.

    :return: A 4-tuple (training_images, training_labels, test_images, test_labels).

    NOTE: this unpickles the file, so it should only be used on a cache this module wrote itself.
    """
    with open("mnist.pkl", 'rb') as fh:
        cache = pickle.load(fh)
    order = ("training_images", "training_labels", "test_images", "test_labels")
    return tuple(cache[key] for key in order)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
My bad. Thanks for the pointer.