-
Notifications
You must be signed in to change notification settings - Fork 24
/
dogs.py
106 lines (83 loc) · 4.38 KB
/
dogs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import scipy.io
from os.path import join
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir
class Dogs(VisionDataset):
    """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'

    # Number of entries expected in both ``Images`` and ``Annotation`` once the
    # archives have been extracted; used only by the "already downloaded" check.
    _num_breeds = 120

    # Split description files read by ``load_split``.
    _split_files = ('train_list.mat', 'test_list.mat')

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(Dogs, self).__init__(root, transform=transform, target_transform=target_transform)

        self.loader = default_loader
        self.train = train

        if download:
            self.download()

        split = self.load_split()

        self.images_folder = join(self.root, 'Images')
        self.annotations_folder = join(self.root, 'Annotation')
        self._breeds = list_dir(self.images_folder)

        # Each split entry names an image without its extension; add it here so
        # ``__getitem__`` can join the name directly onto ``images_folder``.
        self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]
        self._flat_breed_images = self._breed_images

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self._flat_breed_images)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Index into the selected split.

        Returns:
            tuple: ``(image, target)`` where ``image`` is the loaded (and, if a
            ``transform`` was given, transformed) image and ``target`` is the
            class index (passed through ``target_transform`` if one was given).
        """
        image_name, target = self._flat_breed_images[index]
        image = self.loader(join(self.images_folder, image_name))

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def _is_downloaded(self):
        """Return True when the extracted folders and both split files are present."""
        images_dir = join(self.root, 'Images')
        annotations_dir = join(self.root, 'Annotation')

        if not (os.path.isdir(images_dir) and os.path.isdir(annotations_dir)):
            return False
        if not (len(os.listdir(images_dir)) == len(os.listdir(annotations_dir)) == self._num_breeds):
            return False

        # ``load_split`` needs these files too; without this check a partially
        # extracted dataset would be reported as complete and fail later.
        return all(os.path.isfile(join(self.root, name)) for name in self._split_files)

    def download(self):
        """Download and extract the ``images``, ``annotation`` and ``lists`` archives.

        Nothing is downloaded when the extracted data is already present under
        ``root``. Each archive is deleted after it has been extracted.
        """
        import tarfile

        if self._is_downloaded():
            print('Files already downloaded and verified')
            return

        for filename in ['images', 'annotation', 'lists']:
            tar_filename = filename + '.tar'
            tar_path = join(self.root, tar_filename)
            url = self.download_url_prefix + '/' + tar_filename

            download_url(url, self.root, tar_filename, None)
            print('Extracting downloaded file: ' + tar_path)

            with tarfile.open(tar_path, 'r') as tar_file:
                if hasattr(tarfile, 'data_filter'):
                    # The archive comes over plain HTTP with no checksum, so
                    # refuse members that would be written outside ``root``.
                    tar_file.extractall(self.root, filter='data')
                else:
                    # Older interpreters have no extraction filters.
                    tar_file.extractall(self.root)

            os.remove(tar_path)

    def load_split(self):
        """Read the image list and labels of the selected split.

        Returns:
            list: ``(annotation, label)`` pairs, where ``annotation`` is the
            image name as stored in the split file (no extension) and ``label``
            is a plain ``int`` equal to the stored label minus one.
        """
        list_name = 'train_list.mat' if self.train else 'test_list.mat'

        # Parse the file once; both fields live in the same .mat file.
        contents = scipy.io.loadmat(join(self.root, list_name))

        split = [item[0][0] for item in contents['annotation_list']]
        # Convert to a built-in int before shifting so the targets handed to
        # callers are not NumPy scalars carrying the file's storage dtype.
        labels = [int(item[0]) - 1 for item in contents['labels']]

        return list(zip(split, labels))

    def stats(self):
        """Print a one-line summary of the split and return per-class counts.

        Returns:
            dict: Maps each class index to the number of samples with that
            class. Empty when the split has no samples.
        """
        from collections import Counter

        counts = dict(Counter(target for _, target in self._flat_breed_images))

        n_samples = len(self._flat_breed_images)
        n_classes = len(counts)
        # An empty split has no classes; report 0.0 instead of dividing by zero.
        average = float(n_samples) / float(n_classes) if n_classes else 0.0

        print("%d samples spanning %d classes (avg %f per class)" % (n_samples, n_classes, average))

        return counts
if __name__ == '__main__':
    # Smoke test: build the training split first, then the test split, from the
    # local './dogs' directory without attempting a download.
    train_dataset, test_dataset = (
        Dogs('./dogs', train=use_train, download=False) for use_train in (True, False)
    )