-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransforms.py
130 lines (115 loc) · 3.4 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
from pprint import pformat
import torch
from torchvision.transforms import (
ColorJitter,
Compose,
GaussianBlur,
Lambda,
RandomApply,
RandomGrayscale,
RandomHorizontalFlip,
RandomVerticalFlip,
Resize,
RandomResizedCrop,
ToPILImage,
ToTensor,
CenterCrop,
)
# Module-level logger; image_transform logs the transform list it assembles here.
logger = logging.getLogger(__name__)
# https://sentinel.esa.int/web/sentinel/user-guides/sentinel-2-msi
def _s2_l1c_to_rgb(image):
    """Render a Sentinel-2 L1C band stack as an 8-bit RGB tensor.

    Channels 3, 2 and 1 of the first dimension are taken in that order,
    scaled so that a value of 3558 maps to 255, clipped to [0, 255] and cast
    to ``torch.uint8`` (the cast truncates any fractional part).
    """
    # NOTE(review): assuming the standard 13-band L1C ordering, indices 3/2/1
    # are presumably B4/B3/B2 (red/green/blue) — confirm against the loader.
    bands = image[[3, 2, 1]]
    scaled = bands / 3558 * 255
    return scaled.clamp(min=0, max=255).to(dtype=torch.uint8)
# https://sentinel.esa.int/web/sentinel/user-guides/sentinel-2-msi
def _s2_l2a_to_rgb(image):
    """Render a Sentinel-2 L2A band stack as an 8-bit RGB tensor.

    Channels 3, 2 and 1 of the first dimension are taken in that order,
    scaled so that a value of 2000 maps to 255, clipped to [0, 255] and cast
    to ``torch.uint8`` (the cast truncates any fractional part).
    """
    # NOTE(review): indices 3/2/1 are presumably B4/B3/B2 (red/green/blue)
    # for the L2A stack used here — confirm against the loader.
    bands = image[[3, 2, 1]]
    scaled = bands / 2000 * 255
    return scaled.clamp(min=0, max=255).to(dtype=torch.uint8)
# The helpers below are module-level functions rather than inline lambdas so
# that the Compose returned by image_transform stays picklable (lambdas are
# not), e.g. for DataLoader workers started with the "spawn" method.


def _s2_l2a_insert_empty_band(image):
    """Insert an all-zero band at channel index 10 of a Sentinel-2 L2A stack.

    The input is cast to float32 first, matching the ``S2_L1C`` branch. The
    zero band is built with ``zeros_like``, so it shares the input's dtype
    and device; this lets CUDA and integer-typed inputs concatenate cleanly.
    The result has one more channel than the input.
    """
    # NOTE(review): presumably this recreates the B10 (cirrus) slot that L2A
    # products omit, so L2A stacks line up with the 13-band L1C layout — confirm.
    image = image.float()
    empty_band = torch.zeros_like(image[:1])
    return torch.cat((image[:10], empty_band, image[10:]), dim=0)


def _to_float(image):
    """Cast a tensor to float32."""
    return image.float()


def _to_tensor_if_needed(image):
    """Return tensors unchanged; convert anything else (e.g. PIL) via ToTensor."""
    return image if isinstance(image, torch.Tensor) else ToTensor()(image)


def _scale_by_10000(image):
    """Divide by 10000 and clip the result to [0, 1]."""
    # NOTE(review): presumably maps Sentinel-2 digital numbers to [0, 1]
    # reflectance — confirm the data's scale factor.
    return torch.clamp(image / 10000.0, 0, 1)


def image_transform(
    modality, clip_transform, clip_normalization=False, use_augmentations=False
):
    """Build the preprocessing pipeline for one input modality.

    Args:
        modality: One of ``"RGB"``, ``"RGB_L2A"``, ``"RGB_L1C"``,
            ``"S2_L2A"`` or ``"S2_L1C"``.
        clip_transform: A composed transform exposing a ``.transforms`` list.
            Only its ``Resize``, ``CenterCrop`` and ``RandomResizedCrop``
            steps are reused. Its last step is reused as the normalization
            when ``clip_normalization`` is true.
        clip_normalization: If true, finish with
            ``clip_transform.transforms[-1]``. Otherwise divide by 10000 and
            clip to [0, 1].
        use_augmentations: If true, add random horizontal/vertical flips.
            For modalities whose name contains ``"RGB"``, also add colour
            jitter, random grayscale and Gaussian blur.

    Returns:
        A ``Compose`` of the selected steps.

    Raises:
        ValueError: If ``modality`` is not one of the names above.
    """
    transforms = []
    if modality == "RGB":
        pass  # input is fed to the shared steps below as-is
    elif modality == "RGB_L2A":
        transforms.extend([_s2_l2a_to_rgb, ToPILImage("RGB")])
    elif modality == "RGB_L1C":
        transforms.extend([_s2_l1c_to_rgb, ToPILImage("RGB")])
    elif modality == "S2_L2A":
        transforms.append(Lambda(_s2_l2a_insert_empty_band))
    elif modality == "S2_L1C":
        transforms.append(Lambda(_to_float))
    else:
        raise ValueError(f"Unknown modality {modality}")

    # Keep only the resize/crop steps of the CLIP pipeline; every other step
    # in clip_transform is dropped here.
    geometric_types = (Resize, CenterCrop, RandomResizedCrop)
    transforms.extend(
        t for t in clip_transform.transforms if isinstance(t, geometric_types)
    )

    if use_augmentations:
        transforms.extend(
            [
                RandomHorizontalFlip(p=0.5),
                RandomVerticalFlip(p=0.5),
            ]
        )
        if "RGB" in modality:
            transforms.extend(
                [
                    RandomApply([ColorJitter(0.4, 0.4, 0.4)], p=0.8),
                    RandomGrayscale(p=0.2),
                    # NOTE(review): torchvision reads kernel_size=(3, 7) as a
                    # fixed 3x7 (width x height) kernel, not a random size
                    # between 3 and 7 — confirm this is intended.
                    RandomApply(
                        [GaussianBlur(kernel_size=(3, 7), sigma=(0.1, 2.0))], p=0.4
                    ),
                ]
            )

    transforms.append(Lambda(_to_tensor_if_needed))

    if clip_normalization:
        # NOTE(review): assumes the last CLIP step is a 3-channel Normalize,
        # which would not broadcast over multi-band S2_* tensors — confirm
        # callers enable this only for RGB modalities.
        normalization = clip_transform.transforms[-1]
    else:
        # NOTE(review): for RGB modalities the tensor is presumably already in
        # [0, 1] (ToTensor output), so dividing by 10000 leaves values near
        # zero — confirm this combination is not used.
        normalization = Lambda(_scale_by_10000)
    transforms.append(normalization)

    # TODO: Remove this
    logger.info("Using %s with \n %s", modality, pformat(transforms))
    return Compose(transforms)