Last active
March 28, 2025 06:55
-
-
Save pbloem/e2a46efe5b1fd4c098cd249d8f60d2c2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gzip
import os
import pickle
from urllib import request

import numpy as np
import torch
def load_mnist(final=False, flatten=True, verbose=False, normalize=True):
    """
    Load the MNIST data as torch tensors.

    If the cache file 'mnist.pkl' is not present in the current working directory, the raw data is first downloaded
    and cached by calling ``init()``.

    :param final: If true, return the canonical train/test split. If false, hold out the last 5000 training instances
        as validation data and keep the test data hidden.
    :param flatten: If true, each instance is flattened into a vector, so that the images are returned as a matrix
        with 784 (= 28 * 28) columns. If false, the images keep the shape stored in the cache, (n, 1, 28, 28).
    :param verbose: Currently unused; kept so that existing callers passing it keep working.
    :param normalize: If true, the pixel values are converted to floats and divided by 255, so that they lie in
        [0, 1]. If false, the raw uint8 pixel values are returned.
    :return: Two tuples and an integer: (xtrain, ytrain), (xval, yval), num_cls. The first contains a tensor of
        training images and the corresponding class labels as a uint8 tensor. The second contains the
        test/validation data in the same format. The last integer is the number of classes (always 10 for this
        function).
    """
    val_size = 5000   # number of training instances held out when final=False
    num_classes = 10  # the digits 0-9

    if not os.path.isfile('mnist.pkl'):
        init()

    # load() returns numpy arrays; torch.from_numpy keeps their dtype (uint8).
    xtrain, ytrain, xtest, ytest = (torch.from_numpy(array) for array in load())

    if normalize:
        xtrain = xtrain.to(torch.float) / 255.
        xtest = xtest.to(torch.float) / 255.

    if flatten:
        # Keep the instance dimension and collapse everything else.
        xtrain = xtrain.reshape(xtrain.shape[0], -1)
        xtest = xtest.reshape(xtest.shape[0], -1)

    if not final:
        # The validation data is the tail of the training data; the test data is not returned.
        return (xtrain[:-val_size], ytrain[:-val_size]), (xtrain[-val_size:], ytrain[-val_size:]), num_classes

    return (xtrain, ytrain), (xtest, ytest), num_classes
# Numpy-only MNIST loader. Courtesy of Hyeonseok Jung
# https://github.com/hsjeong5/MNIST-for-Numpy
# Pairs of [key in the cached dictionary, name of the gzipped IDX file to download].
# The keys ending in "_images" refer to image files, those ending in "_labels" to label files.
filename = [
    ["training_images", "train-images-idx3-ubyte.gz"],
    ["test_images", "t10k-images-idx3-ubyte.gz"],
    ["training_labels", "train-labels-idx1-ubyte.gz"],
    ["test_labels", "t10k-labels-idx1-ubyte.gz"],
]
def download_mnist():
    """
    Download the four gzipped MNIST files listed in ``filename``.

    Each file is fetched from the mirror below and written to the current working directory under its original
    name. Progress is reported with print statements.
    """
    # Mirror of the original distribution at "http://yann.lecun.com/exdb/mnist/"
    base_url = "https://peterbloem.nl/files/mnist/"

    for _, archive in filename:
        print(f"Downloading {archive}...")
        request.urlretrieve(base_url + archive, archive)

    print("Download complete.")
def save_mnist():
    """
    Convert the downloaded gzipped MNIST files into a single pickle file, 'mnist.pkl'.

    The pickle contains a dictionary mapping each key in ``filename`` to a uint8 numpy array: image files become
    arrays of shape (n, 1, 28, 28), label files become vectors of length n. The gzipped files are expected in the
    current working directory (where ``download_mnist`` puts them).
    """
    image_header = 16  # bytes skipped at the start of an image file before the pixel data
    label_header = 8   # bytes skipped at the start of a label file before the label data

    mnist = {}
    for key, archive in filename:
        # Decide by key rather than by position, so the result does not depend on the order of `filename`.
        is_images = key.endswith("_images")
        with gzip.open(archive, 'rb') as f:
            raw = f.read()

        if is_images:
            mnist[key] = np.frombuffer(raw, np.uint8, offset=image_header).reshape(-1, 1, 28, 28)
        else:
            mnist[key] = np.frombuffer(raw, np.uint8, offset=label_header)

    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist, f)

    print("Save complete.")
def init():
    """
    Prepare the local MNIST cache: download the raw files, then convert them to 'mnist.pkl'.
    """
    download_mnist()
    save_mnist()
def load():
    """
    Read the cached MNIST data from 'mnist.pkl' in the current working directory.

    :return: A 4-tuple (training_images, training_labels, test_images, test_labels), taken from the entries of the
        pickled dictionary with those keys.
    """
    # NOTE: unpickling can execute arbitrary code. 'mnist.pkl' should only ever be the cache written locally by
    # save_mnist(); never replace it with a file from an untrusted source.
    with open("mnist.pkl", 'rb') as f:
        mnist = pickle.load(f)

    return mnist["training_images"], mnist["training_labels"], mnist["test_images"], mnist["test_labels"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment