write_imagenet.py

'''
This code is directly taken from FFCV-Imagenet https://github.com/libffcv/ffcv-imagenet
'''
from torch.utils.data import Subset, ConcatDataset
from ffcv.writer import DatasetWriter
from ffcv.fields import IntField, RGBImageField
from torchvision.datasets import CIFAR10, ImageFolder
from argparse import ArgumentParser
from fastargs import Section, Param
from fastargs.validation import And, OneOf
from fastargs.decorators import param, section
from fastargs import get_current_config
from customdataset import CustomizeDataset

Section('cfg', 'arguments to give the writer').params(
    # 'imagenet_aug' concatenates the original ImageNet folder with a folder of generated images (see main()).
    dataset=Param(And(str, OneOf(['cifar', 'imagenet', 'imagenet_aug'])), 'Which dataset to write', default='imagenet'),
    split=Param(And(str, OneOf(['train', 'val'])), 'Train or val set', required=True),
    data_dir=Param(str, 'Where to find the PyTorch dataset', required=True),
    write_path=Param(str, 'Where to write the new dataset', required=True),
    write_mode=Param(str, 'Mode: raw, smart or jpg', required=False, default='smart'),
    max_resolution=Param(int, 'Max image side length', required=True),
    num_workers=Param(int, 'Number of workers to use', default=16),
    chunk_size=Param(int, 'Chunk size for writing', default=100),
    jpeg_quality=Param(float, 'Quality of jpeg images', default=90),
    subset=Param(int, 'How many images to use (-1 for all)', default=-1),
    compress_probability=Param(float, 'compress probability', default=None)
)

@section('cfg')
@param('dataset')
@param('split')
@param('data_dir')
@param('write_path')
@param('max_resolution')
@param('num_workers')
@param('chunk_size')
@param('subset')
@param('jpeg_quality')
@param('write_mode')
@param('compress_probability')
def main(dataset, split, data_dir, write_path, max_resolution, num_workers,
         chunk_size, subset, jpeg_quality, write_mode,
         compress_probability):
    if dataset == 'cifar':
        my_dataset = CIFAR10(root=data_dir, train=(split == 'train'), download=True)
    elif dataset == 'imagenet':
        my_dataset = ImageFolder(root=data_dir)
    elif dataset == 'imagenet_aug':
        # Concatenate the real ImageNet split with a folder of generated images.
        my_dataset_1 = ImageFolder(root=data_dir)
        data_dir_2 = "path/to/generated/imagenet/train/or/val"  # set this to the generated train/val folder
        my_dataset_2 = ImageFolder(root=data_dir_2)
        my_dataset = ConcatDataset([my_dataset_1, my_dataset_2])
    else:
        raise ValueError('Unrecognized dataset', dataset)

    if subset > 0:
        my_dataset = Subset(my_dataset, range(subset))

    writer = DatasetWriter(write_path, {
        'image': RGBImageField(write_mode=write_mode,
                               max_resolution=max_resolution,
                               compress_probability=compress_probability,
                               jpeg_quality=jpeg_quality),
        'label': IntField(),
    }, num_workers=num_workers)

    writer.from_indexed_dataset(my_dataset, chunksize=chunk_size)

if __name__ == '__main__':
    config = get_current_config()
    parser = ArgumentParser()
    config.augment_argparse(parser)
    config.collect_argparse_args(parser)
    config.validate(mode='stderr')
    config.summary()
    main()
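
# Example invocation (a sketch; the paths and resolution below are illustrative placeholders,
# not values from the original repository). fastargs turns each parameter of the 'cfg' section
# above into a dotted command-line flag of the form --cfg.<param>, so a training split can be
# written with something like:
#
#   python write_imagenet.py \
#       --cfg.dataset=imagenet \
#       --cfg.split=train \
#       --cfg.data_dir=/path/to/imagenet/train \
#       --cfg.write_path=/path/to/output/train.ffcv \
#       --cfg.max_resolution=512 \
#       --cfg.write_mode=smart \
#       --cfg.num_workers=16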