diff --git a/experiments/mnist/datasets.py b/experiments/mnist/datasets.py
index 0c24c98..5f8eff3 100644
--- a/experiments/mnist/datasets.py
+++ b/experiments/mnist/datasets.py
@@ -18,7 +18,9 @@
 import array
 import gzip
 import os
+import pickle
 import struct
+import tarfile
 import urllib.request
 from os import path
 
@@ -47,6 +49,27 @@ def _one_hot(x, k, dtype=np.float32):
     return np.array(x[:, None] == np.arange(k), dtype)
 
 
+def _unzip(file):
+    """Extract the tar archive at ``file`` into the ``_DATA`` directory."""
+    # The context manager closes the archive even if extraction fails.
+    with tarfile.open(file) as archive:
+        # NOTE(security): extractall() trusts member paths; only use this on
+        # archives from a trusted source.
+        archive.extractall(_DATA)
+
+
+def _unpickle(file):
+    """Load and return the pickled object stored at ``file``.
+
+    Byte strings in the pickle are kept as ``bytes`` (``encoding="bytes"``),
+    so dictionary keys come back as e.g. ``b"data"`` rather than ``"data"``.
+    """
+    # NOTE(security): unpickling can execute arbitrary code; only load files
+    # from a trusted source.
+    with open(file, "rb") as fo:
+        return pickle.load(fo, encoding="bytes")
+
+
 def mnist_raw():
     """Download and parse the raw MNIST dataset."""
     # CVDF mirror of http://yann.lecun.com/exdb/mnist/
@@ -93,3 +116,61 @@ def mnist(permute_train=False):
         train_labels = train_labels[perm]
 
     return train_images, train_labels, test_images, test_labels
+
+
+def cifar_raw():
+    """Download, unzip and parse the raw CIFAR-10 dataset.
+
+    Returns:
+        A tuple ``(train_images, train_labels, test_images, test_labels)``.
+        The training arrays are the five ``data_batch_*`` files concatenated
+        in order; the test arrays come from ``test_batch``.
+    """
+    filename = "cifar-10-python.tar.gz"
+    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    _download(url, filename)
+    _unzip(path.join(_DATA, filename))
+
+    batch_dir = path.join(_DATA, "cifar-10-batches-py")
+
+    data = []
+    labels = []
+    for index in range(1, 6):
+        batch = _unpickle(path.join(batch_dir, "data_batch_{}".format(index)))
+        data.append(batch[b"data"])
+        labels.append(batch[b"labels"])
+    train_images = np.concatenate(data)
+    train_labels = np.concatenate(labels)
+
+    test_batch = _unpickle(path.join(batch_dir, "test_batch"))
+    test_images = test_batch[b"data"]
+    test_labels = np.array(test_batch[b"labels"])
+
+    return train_images, train_labels, test_images, test_labels
+
+
+def cifar(permute_train=False):
+    """Download, parse and process CIFAR-10 data to unit scale and one-hot labels.
+
+    Args:
+        permute_train: If true, shuffle the training examples with a
+            fixed-seed (0) permutation so the order is reproducible.
+
+    Returns:
+        A tuple ``(train_images, train_labels, test_images, test_labels)``
+        with images divided by 255 and labels one-hot encoded over 10 classes.
+    """
+    train_images, train_labels, test_images, test_labels = cifar_raw()
+
+    train_images = train_images / np.float32(255.0)
+    test_images = test_images / np.float32(255.0)
+    train_labels = _one_hot(train_labels, 10)
+    test_labels = _one_hot(test_labels, 10)
+
+    if permute_train:
+        # Same permutation for images and labels keeps the pairs aligned.
+        perm = np.random.RandomState(0).permutation(train_images.shape[0])
+        train_images = train_images[perm]
+        train_labels = train_labels[perm]
+
+    return train_images, train_labels, test_images, test_labels