image_utils.py
import os

from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms

# ImageNet mean and std, so inputs match what the pre-trained VGG expects
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
# inverse of `normalize`: maps an ImageNet-normalized tensor back toward the [0, 1] range
denormalize = transforms.Normalize(mean=[-mean / std for mean, std in zip(IMAGENET_MEAN, IMAGENET_STD)],
                                   std=[1 / std for std in IMAGENET_STD])
unloader = transforms.ToPILImage()
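
# Sanity-check sketch: Normalize computes (x - mean) / std, so with mean' = -mean/std and
# std' = 1/std the `denormalize` transform above is the algebraic inverse of `normalize`:
#     x = torch.rand(3, 8, 8)
#     assert torch.allclose(denormalize(normalize(x)), x, atol=1e-5)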


class ImageFolder(torch.utils.data.Dataset):
    """Flat folder of images: every file in `root_path` is loaded as an RGB image."""

    def __init__(self, root_path, transform):
        super(ImageFolder, self).__init__()
        self.file_names = sorted(os.listdir(root_path))
        self.root_path = root_path
        self.transform = transform

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, index):
        # join root and file name properly instead of concatenating the strings
        image = Image.open(os.path.join(self.root_path, self.file_names[index])).convert("RGB")
        return self.transform(image)


def get_transformer(imsize=None, cropsize=None):
    """Build the preprocessing pipeline: optional Resize, optional RandomCrop, then ToTensor + Normalize."""
    transformer = []
    if imsize:
        transformer.append(transforms.Resize(imsize))
    if cropsize:
        transformer.append(transforms.RandomCrop(cropsize))
    transformer.append(transforms.ToTensor())
    transformer.append(normalize)
    return transforms.Compose(transformer)
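
# Usage sketch ("./train_images/" is a hypothetical directory of training images):
#     dataset = ImageFolder("./train_images/", get_transformer(imsize=256, cropsize=240))
#     loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)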


def imsave(tensor, path):
    """Denormalize an image tensor (or batch), arrange it into a grid, and save it to `path`."""
    if tensor.is_cuda:
        tensor = tensor.cpu()
    tensor = torchvision.utils.make_grid(tensor)
    torchvision.utils.save_image(denormalize(tensor).clamp_(0.0, 1.0), path)


def imload(path, imsize=None, cropsize=None):
    """Load a single image as a normalized tensor of shape (1, C, H, W)."""
    transformer = get_transformer(imsize, cropsize)
    return transformer(Image.open(path).convert("RGB")).unsqueeze(0)
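

if __name__ == "__main__":
    # Quick round-trip sketch; "input.jpg" and "output.png" are hypothetical paths.
    image = imload("input.jpg", imsize=512)  # (1, 3, H, W), ImageNet-normalized
    imsave(image, "output.png")              # denormalized and saved as an image grid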