-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_transformer.py
31 lines (25 loc) · 1.2 KB
/
image_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
import torch
import torchvision.transforms as transforms
# Transformations
class TwoCropTransform:
    """Produce a pair of views of one input for contrastive-style training.

    The first view is whatever the caller-supplied ``transform`` returns.
    The second view comes from a fixed augmentation pipeline assembled in
    ``__init__``:

    ``RandomResizedCrop(img_size)`` -> ``RandomHorizontalFlip()`` ->
    ``ColorJitter(0.8, 0.8, 0.8, 0.2)`` (applied with probability 0.8) ->
    ``RandomGrayscale(p=0.2)`` -> ``ToTensor()``.

    Args:
        transform: Callable used to build the first view.
        img_size: Target size handed to ``RandomResizedCrop``.
    """

    def __init__(self, transform, img_size):
        self.transform = transform
        self.img_size = img_size
        jitter = transforms.ColorJitter(
            brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
        )
        pipeline = [
            transforms.RandomResizedCrop(size=img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ]
        self.data_transforms = transforms.Compose(pipeline)

    def __call__(self, x):
        """Return ``[self.transform(x), self.data_transforms(x)]``.

        The two callables are invoked in that order. ``x`` is presumably a
        PIL image (anything ``ToTensor`` accepts) -- confirm with callers.
        """
        return [view(x) for view in (self.transform, self.data_transforms)]
def rotation(input):
    """Rotate each sample of a batch by a random multiple of 90 degrees.

    Rotation labels are drawn so the four classes stay as balanced as
    possible: the pattern ``[0, 1, 2, 3]`` is repeated ``batch // 4 + 1``
    times, shuffled with NumPy's global RNG (``np.random.permutation``) and
    truncated to ``batch`` entries.

    Args:
        input: Batched tensor of rank >= 4; dims 2 and 3 are rotated
            (presumably ``(N, C, H, W)`` -- confirm with callers). Rotations
            by 90/270 degrees only fit back into the output when dims 2 and
            3 have the same size; otherwise a RuntimeError is raised.

    Returns:
        ``(image, target)``: ``image`` is a new tensor with the same shape,
        dtype and device as ``input`` where sample ``i`` equals
        ``torch.rot90(input[i], target[i], [1, 2])``; ``target`` is a
        ``torch.long`` tensor of labels in ``{0, 1, 2, 3}`` on
        ``input.device``.
    """
    batch = input.shape[0]
    # Same single RNG draw as before, so labels are unchanged for a given
    # NumPy seed.
    labels = np.random.permutation([0, 1, 2, 3] * (batch // 4 + 1))[:batch]
    # Label 0 means "no rotation", so starting from a copy covers it.
    image = input.clone()
    # Rotate all samples that share a label in one call (at most 3 calls
    # instead of `batch`), and keep `k` a Python int so no per-sample
    # device->host synchronisation is needed.
    for k in range(1, 4):
        idx = np.flatnonzero(labels == k)
        if idx.size:
            idx = torch.as_tensor(idx, device=input.device)
            # dims [2, 3] of the batch == dims [1, 2] of a single sample.
            image[idx] = torch.rot90(input[idx], k, [2, 3])
    target = torch.tensor(labels, device=input.device).long()
    return image, target