-
Notifications
You must be signed in to change notification settings - Fork 0
/
Augment.py
124 lines (105 loc) · 4.16 KB
/
Augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import random
from logging import getLogger
from PIL import ImageFilter
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import json
import os
# Module-level logger; getLogger() called without a name returns the root logger.
logger = getLogger()
class MultiCropDataset(datasets.ImageFolder):
    """Image/caption dataset that yields one high- and one low-resolution crop.

    Captions are read from ``<dataroot>/captions_all_json.json`` (a JSON object
    keyed by image id) and images are loaded from
    ``<dataroot>/images/<image_id>.jpg``.  Every image is passed through one
    random augmentation pipeline per requested crop (random resized crop,
    horizontal flip, colour distortion, Gaussian blur, tensor conversion and
    normalisation).

    Args:
        im_preprocessor: Callable applied to each augmented crop tensor.
        text_tokenizer: Callable applied to the caption entry of the image.
        dataroot: Dataset root; also forwarded to ``ImageFolder.__init__``.
        size_crops: Output size of each crop group.
        nmb_crops: Number of crops in each group.  ``__getitem__`` currently
            requires the total to be exactly two (224 first, then 96).
        min_scale_crops: Lower bound of the ``RandomResizedCrop`` scale range,
            one per crop group.
        max_scale_crops: Upper bound of the ``RandomResizedCrop`` scale range,
            one per crop group.
        batch_size: Stored on the instance; not used inside this class.
        num_workers: Stored on the instance; not used inside this class.
        size_dataset: When ``>= 0``, ``self.samples`` is truncated to this
            many entries.
        return_index: When true, ``__getitem__`` prepends the sample index to
            the returned tuple.
    """

    def __init__(
        self,
        im_preprocessor,
        text_tokenizer,
        dataroot,
        size_crops,
        nmb_crops,
        min_scale_crops,
        max_scale_crops,
        batch_size=512,
        num_workers=25,
        size_dataset=-1,
        return_index=False,
    ):
        super(MultiCropDataset, self).__init__(dataroot)

        caption_path = os.path.join(dataroot, "captions_all_json.json")
        # Use a context manager so the file handle is closed deterministically.
        with open(caption_path, "r") as caption_file:
            self.captions = json.load(caption_file)
        # A dict view is not subscriptable; __getitem__ indexes by position,
        # so the image ids must be materialised as a list.
        self.image_name = list(self.captions.keys())

        self.dataroot = dataroot
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.im_preprocessor = im_preprocessor
        self.text_tokenizer = text_tokenizer

        assert len(size_crops) == len(nmb_crops)
        assert len(min_scale_crops) == len(nmb_crops)
        assert len(max_scale_crops) == len(nmb_crops)

        # NOTE(review): only ``self.samples`` (which drives the inherited
        # ``__len__``) is truncated here; ``self.image_name`` is not. This
        # presumably relies on both listing the same images - confirm.
        if size_dataset >= 0:
            self.samples = self.samples[:size_dataset]
        self.return_index = return_index

        color_transform = [get_color_distortion(), PILRandomGaussianBlur()]
        mean = [0.485, 0.456, 0.406]
        # NOTE(review): the commonly quoted ImageNet std for the first channel
        # is 0.229; 0.228 is kept to preserve existing behaviour - confirm.
        std = [0.228, 0.224, 0.225]

        trans = []
        for size, count, min_scale, max_scale in zip(
            size_crops, nmb_crops, min_scale_crops, max_scale_crops
        ):
            randomresizedcrop = transforms.RandomResizedCrop(
                size,
                scale=(min_scale, max_scale),
            )
            pipeline = transforms.Compose([
                randomresizedcrop,
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Compose(color_transform),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ])
            # The same pipeline object is reused for every crop of this group;
            # each call draws its own random parameters.
            trans.extend([pipeline] * count)
        self.trans = trans

    def __getitem__(self, index):
        """Return the augmented crops and tokenized caption for ``index``.

        Returns:
            ``(high_res_image, low_res_image, caption)``, preceded by ``index``
            when ``return_index`` is set.
        """
        image_id = self.image_name[index]
        caption = self.captions[image_id]
        image_path = os.path.join(self.dataroot, "images/{}.jpg".format(image_id))
        image = self.loader(image_path)

        multi_crops = [transform(image) for transform in self.trans]
        assert len(multi_crops) == 2, "For now please specify only two multi-crops"
        # ToTensor produces C x H x W, so the resolution lives in the trailing
        # two dimensions; dimension 0 is the channel count.
        assert tuple(multi_crops[0].shape[-2:]) == (224, 224), "1st image should be high resolution, i.e., 224 * 224"
        assert tuple(multi_crops[1].shape[-2:]) == (96, 96), "2n image should be low resolution, i.e., 96 * 96"

        high_res_image = self.im_preprocessor(multi_crops[0])
        low_res_image = self.im_preprocessor(multi_crops[1])
        caption = self.text_tokenizer(caption)

        if self.return_index:
            return (index, high_res_image, low_res_image, caption)
        return (high_res_image, low_res_image, caption)
class PILRandomGaussianBlur(object):
    """
    Randomly apply a Gaussian blur to a PIL image.

    With probability ``p`` the image is blurred using a radius drawn
    uniformly from ``[radius_min, radius_max]``; otherwise the input is
    returned untouched.
    This transform was used in SimCLR - https://arxiv.org/abs/2002.05709
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob, self.radius_min, self.radius_max = p, radius_min, radius_max

    def __call__(self, img):
        # The coin flip comes from NumPy's global RNG, the radius from the
        # stdlib ``random`` module; the radius is only drawn when blurring.
        if np.random.rand() <= self.prob:
            blur_radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        return img
def get_color_distortion(s=1.0):
    """Build the colour-distortion transform; ``s`` scales its strength.

    The result composes a colour jitter (applied with probability 0.8) with a
    random grayscale conversion (probability 0.2).
    """
    main_strength = 0.8 * s
    hue_strength = 0.2 * s
    jitter = transforms.ColorJitter(
        main_strength, main_strength, main_strength, hue_strength
    )
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])