-
Notifications
You must be signed in to change notification settings - Fork 5
/
mnist.cc
66 lines (60 loc) · 2.16 KB
/
mnist.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "./mnist.h"

#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
// Swaps the byte order of a 32-bit integer: byte 0 <-> byte 3, byte 1 <-> byte 2.
// The loaders below apply it unconditionally to every header field read raw
// from an IDX file (whose fields are stored most-significant byte first), which
// assumes a little-endian host.
// NOTE(review): only the low four bytes of `i` participate; assumes 32-bit int.
int ReverseInt(int i) {
  const unsigned int bits = static_cast<unsigned int>(i);
  const unsigned int swapped = ((bits & 0x000000FFu) << 24) |
                               ((bits & 0x0000FF00u) << 8) |
                               ((bits & 0x00FF0000u) >> 8) |
                               ((bits >> 24) & 0x000000FFu);
  return static_cast<int>(swapped);
}
// Loads an IDX3 unsigned-byte image file (e.g. "train-images-idx3-ubyte") into
// `data`.
//
// On success `data` is resized to (n_rows * n_cols) x number_of_images. Each
// column holds one image with its pixels stored row-major (index
// r * n_cols + c) as the raw byte value converted to float (0..255, no
// normalisation).
//
// Failure handling (non-throwing):
//   - file cannot be opened, or the header is truncated / not an IDX3 ubyte
//     header: a message goes to std::cerr and `data` is left untouched;
//   - file ends before all pixels are read: missing pixels are stored as 0 and
//     a warning goes to std::cerr.
void MNIST::read_mnist_data(std::string filename, Matrix& data) {
  // IDX magic number for an unsigned-byte tensor with 3 dimensions (0x00000803).
  const int kImageMagic = 2051;

  std::ifstream file(filename, std::ios::binary);
  if (!file.is_open()) {
    std::cerr << "MNIST: cannot open image file " << filename << '\n';
    return;
  }

  int magic_number = 0;
  int number_of_images = 0;
  int n_rows = 0;
  int n_cols = 0;
  file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
  file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
  file.read(reinterpret_cast<char*>(&n_rows), sizeof(n_rows));
  file.read(reinterpret_cast<char*>(&n_cols), sizeof(n_cols));
  // Header fields are stored big-endian in the file.
  magic_number = ReverseInt(magic_number);
  number_of_images = ReverseInt(number_of_images);
  n_rows = ReverseInt(n_rows);
  n_cols = ReverseInt(n_cols);

  // Reject truncated headers and non-image files (e.g. a label file passed by
  // mistake) before they turn into a garbage-sized resize.
  if (!file || magic_number != kImageMagic || number_of_images < 0 ||
      n_rows < 0 || n_cols < 0) {
    std::cerr << "MNIST: " << filename << " is not a valid IDX3 image file\n";
    return;
  }

  const int pixels_per_image = n_rows * n_cols;
  data.resize(pixels_per_image, number_of_images);

  // One read() per image instead of one per pixel; the buffer is reused.
  std::vector<unsigned char> pixels(pixels_per_image);
  for (int i = 0; i < number_of_images; ++i) {
    file.read(reinterpret_cast<char*>(pixels.data()), pixels_per_image);
    // On a short read only gcount() bytes are valid: zero the rest so missing
    // pixels are stored as 0 instead of stale bytes from the previous image.
    std::fill(pixels.begin() + file.gcount(), pixels.end(), 0);
    for (int k = 0; k < pixels_per_image; ++k) {
      data(k, i) = static_cast<float>(pixels[k]);
    }
  }

  if (!file) {
    std::cerr << "MNIST: " << filename
              << " is truncated; missing pixels were set to 0\n";
  }
}
// Loads an IDX1 unsigned-byte label file (e.g. "train-labels-idx1-ubyte") into
// `labels` as one-hot columns.
//
// On success `labels` is resized to num_classes x number_of_labels. Column i is
// all zeros except a 1 in the row equal to the i-th label byte.
//
// Failure handling (non-throwing):
//   - num_classes <= 0, file cannot be opened, or the header is truncated /
//     not an IDX1 ubyte header: a message goes to std::cerr and `labels` is
//     left untouched;
//   - a label byte >= num_classes, or a label missing because the file ends
//     early: that column is left all zeros and a warning goes to std::cerr
//     (previously an out-of-range byte was written out of bounds and a missing
//     byte silently became class 0).
void MNIST::read_mnist_label(std::string filename, Matrix& labels, int num_classes) {
  // IDX magic number for an unsigned-byte tensor with 1 dimension (0x00000801).
  const int kLabelMagic = 2049;

  if (num_classes <= 0) {
    std::cerr << "MNIST: num_classes must be positive, got " << num_classes << '\n';
    return;
  }

  std::ifstream file(filename, std::ios::binary);
  if (!file.is_open()) {
    std::cerr << "MNIST: cannot open label file " << filename << '\n';
    return;
  }

  int magic_number = 0;
  int number_of_images = 0;
  file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
  file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
  // Header fields are stored big-endian in the file.
  magic_number = ReverseInt(magic_number);
  number_of_images = ReverseInt(number_of_images);

  if (!file || magic_number != kLabelMagic || number_of_images < 0) {
    std::cerr << "MNIST: " << filename << " is not a valid IDX1 label file\n";
    return;
  }

  labels.resize(num_classes, number_of_images);

  int skipped = 0;
  for (int i = 0; i < number_of_images; ++i) {
    // resize() is not relied on to zero the storage (Eigen's resize, for one,
    // leaves coefficients uninitialized; presumably Matrix is an Eigen type),
    // so clear the whole column before setting the hot entry.
    for (int k = 0; k < num_classes; ++k) {
      labels(k, i) = 0;
    }
    unsigned char label = 0;
    file.read(reinterpret_cast<char*>(&label), sizeof(label));
    // Guard against a truncated file and against out-of-bounds row indices.
    if (!file || label >= num_classes) {
      ++skipped;
      continue;
    }
    labels(static_cast<int>(label), i) = 1;
  }

  if (skipped > 0) {
    std::cerr << "MNIST: " << skipped << " label(s) in " << filename
              << " were missing or >= " << num_classes
              << "; their columns were left all zero\n";
  }
}
void MNIST::read() {
read_mnist_data(data_dir + "train-images-idx3-ubyte", train_data);
read_mnist_data(data_dir + "t10k-images-idx3-ubyte", test_data);
read_mnist_label(data_dir + "train-labels-idx1-ubyte", train_labels, 10);
read_mnist_label(data_dir + "t10k-labels-idx1-ubyte", test_labels, 10);
}