-
Notifications
You must be signed in to change notification settings - Fork 0
/
heatmap.py
267 lines (224 loc) · 11.6 KB
/
heatmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
# Slightly modified from https://github.com/z1069614715/objectdetection_script/blob/master/yolo-gradcam/yolov8_heatmap.py
# Import necessary packages and suppress warnings
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
# Import various libraries for deep learning, computer vision, and utilities
import torch, yaml, cv2, os, shutil, sys
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from argparse import ArgumentParser
# Import specific modules from Ultralytics and PyTorch Grad-CAM libraries
from ultralytics.nn.tasks import attempt_load_weights
from ultralytics.utils.torch_utils import intersect_dicts
from ultralytics.utils.ops import xywh2xyxy, non_max_suppression
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM, EigenCAM, HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
# Resize and pad image while meeting stride-multiple constraints
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resize ``im`` preserving aspect ratio, then pad it with a constant border.

    Args:
        im: image array whose first two axes are (height, width).
        new_shape: target (height, width), or one int for a square target.
        color: border value passed to ``cv2.copyMakeBorder``.
        auto: keep only the padding remainder modulo ``stride`` (minimum rectangle).
        scaleFill: if ``auto`` is False, stretch to exactly ``new_shape`` with no padding.
        scaleup: if False the image may only shrink, never grow.
        stride: modulus used when ``auto`` is True.

    Returns:
        ``(padded_image, (width_ratio, height_ratio), (half_pad_w, half_pad_h))``.
    """
    height, width = im.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    target_h, target_w = new_shape[0], new_shape[1]

    scale = min(target_h / height, target_w / width)
    if not scaleup:
        # Shrink only; never enlarge the source image.
        scale = min(scale, 1.0)
    ratio = scale, scale
    resized_wh = int(round(width * scale)), int(round(height * scale))
    pad_w = target_w - resized_wh[0]
    pad_h = target_h - resized_wh[1]

    if auto:
        pad_w, pad_h = np.mod(pad_w, stride), np.mod(pad_h, stride)
    elif scaleFill:
        pad_w, pad_h = 0.0, 0.0
        resized_wh = (target_w, target_h)
        ratio = target_w / width, target_h / height

    # Split the padding evenly between the two opposite sides.
    pad_w /= 2
    pad_h /= 2

    if (width, height) != resized_wh:
        im = cv2.resize(im, resized_wh, interpolation=cv2.INTER_LINEAR)

    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, ratio, (pad_w, pad_h)
class ActivationsAndGradients:
    """Extract activations and register gradients from targeted intermediate layers.

    This shadows pytorch_grad_cam's class of the same name. Its ``__call__``
    returns ``[[class_scores, boxes_xywh]]`` (rows sorted by best class score)
    instead of the raw model output, which is what ``yolov8_target`` consumes.
    """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            # Forward hook that records the layer's activations.
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use a backward hook to record gradients; a forward hook
            # attaches a tensor hook to the output instead.
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        """Forward hook: store a detached CPU copy of the (optionally reshaped) output."""
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, input, output):
        """Forward hook: register a tensor hook that records the output's gradient."""
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            # A gradient hook can only be registered on a tensor that requires grad.
            return

        def _store_grad(grad):
            if self.reshape_transform is not None:
                grad = self.reshape_transform(grad)
            # Backward reaches later layers first, so prepend to keep
            # self.gradients in the same layer order as self.activations.
            self.gradients = [grad.cpu().detach()] + self.gradients

        output.register_hook(_store_grad)

    def post_process(self, result):
        """Reorder the first image's predictions by descending best class score.

        Args:
            result: raw prediction tensor. Assumed to be (batch, 4 + num_classes, num_anchors),
                with xywh boxes in the first four channels. TODO: confirm against the model head.

        Returns:
            ``(scores, boxes_xywh, boxes_xyxy)``:
            - ``scores`` is (N, num_classes) and ``boxes_xywh`` is (N, 4); both are tensors still attached to the graph.
            - ``boxes_xyxy`` is a NumPy array.
            - All rows are ordered by descending max class score.
            - Only batch element 0 is used.
        """
        logits_ = result[:, 4:]
        boxes_ = result[:, :4]
        _, order = torch.sort(logits_.max(1)[0], descending=True)
        order = order[0]
        scores = torch.transpose(logits_[0], dim0=0, dim1=1)[order]
        boxes = torch.transpose(boxes_[0], dim0=0, dim1=1)[order]
        return scores, boxes, xywh2xyxy(boxes).cpu().detach().numpy()

    def __call__(self, x):
        """Clear recorded state, run the model on ``x`` and return ``[[scores, boxes_xywh]]``."""
        self.gradients = []
        self.activations = []
        model_output = self.model(x)
        # NOTE(review): model_output[0] is assumed to be the raw prediction tensor
        # of an Ultralytics model in eval mode — confirm for other heads.
        scores, boxes_xywh, _ = self.post_process(model_output[0])
        return [[scores, boxes_xywh]]

    def release(self):
        """Remove every hook registered by this instance."""
        for handle in self.handles:
            handle.remove()
# Define custom YOLOv8 target class for specific detection tasks
class yolov8_target(torch.nn.Module):
    """Scalar objective that the CAM method backpropagates from.

    Constructor args:
        ouput_type: 'class' sums the best class score of each kept row,
            'box' sums the four box values of each kept row, 'all' sums both.
            (Name keeps the original spelling for compatibility.)
        conf: stop at the first row whose best class score is below this value.
        ratio: fraction of the rows, from the top, that may be considered.
    """

    def __init__(self, ouput_type, conf, ratio) -> None:
        super().__init__()
        self.ouput_type = ouput_type
        self.conf = conf
        self.ratio = ratio

    def forward(self, data):
        """Return the sum of the selected score and/or box terms.

        ``data`` is ``[scores, boxes]``: ``scores`` indexed per row with a max over
        classes, ``boxes`` with at least four columns. The early ``break`` assumes
        rows are sorted by descending best class score, as produced by
        ``ActivationsAndGradients.post_process`` in this file.

        Returns the int ``0`` when no row passes ``conf`` (``sum`` of an empty list).
        """
        post_result, pre_post_boxes = data
        use_class = self.ouput_type in ('class', 'all')
        use_box = self.ouput_type in ('box', 'all')
        result = []
        for i in trange(int(post_result.size(0) * self.ratio)):
            score = post_result[i].max()
            if float(score) < self.conf:
                break
            if use_class:
                result.append(score)
            # Independent `if` (not `elif`): for 'all' the box terms must be
            # added as well; previously they were unreachable.
            if use_box:
                for j in range(4):
                    result.append(pre_post_boxes[i, j])
        return sum(result)
class yolov8_heatmap:
    """Generate CAM heatmaps for an Ultralytics YOLOv8 detection model.

    Args:
        weight: path to the model checkpoint.
        device: torch device string ('cuda', 'cpu', 'mps', ...).
        method: name of a CAM class imported by this module (e.g. 'GradCAM').
        layer: list of indices into ``model.model`` whose outputs are explained.
        backward_type: 'class', 'box' or 'all' (see ``yolov8_target``).
        conf_threshold: confidence threshold for the CAM target and for NMS.
        ratio: fraction of sorted predictions considered by the CAM target.
        show_box: draw NMS detections on the output image.
        renormalize: renormalize the CAM inside each detected box.
    """

    def __init__(self, weight, device, method, layer, backward_type, conf_threshold, ratio, show_box, renormalize):
        device = torch.device(device)
        model = attempt_load_weights(weight, device)
        # Take class names from the loaded model instead of loading (and keeping)
        # the whole checkpoint a second time; fall back to the raw checkpoint.
        model_names = getattr(model, 'names', None)
        if model_names is None:
            model_names = torch.load(weight, map_location='cpu')['model'].names
        model.info()
        # Gradients w.r.t. intermediate activations are required by the CAM methods.
        for param in model.parameters():
            param.requires_grad_(True)
        model.eval()

        target = yolov8_target(backward_type, conf_threshold, ratio)
        target_layers = [model.model[l] for l in layer]
        # SECURITY NOTE: `method` comes from the command line and is passed to
        # eval(); restrict it to the known CAM class names if this script is ever
        # fed untrusted input.
        # NOTE(review): newer pytorch-grad-cam releases removed the `use_cuda`
        # argument — confirm against the installed version.
        cam = eval(method)(model, target_layers, use_cuda=device.type == 'cuda')
        # The CAM constructor already hooked target_layers with its own recorder.
        # Remove those hooks before swapping in the YOLO-aware recorder; otherwise
        # they keep firing on every forward pass and their lists are never cleared.
        cam.activations_and_grads.release()
        cam.activations_and_grads = ActivationsAndGradients(model, target_layers, None)

        self.weight = weight
        self.device = device
        self.layer = layer
        self.backward_type = backward_type
        self.conf_threshold = conf_threshold
        self.ratio = ratio
        self.show_box = show_box
        self.renormalize = renormalize
        self.model = model
        self.model_names = model_names
        self.target = target
        self.target_layers = target_layers
        self.method = cam
        self.colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(int)

    def post_process(self, result):
        """Apply NMS (IoU 0.65, ``conf_threshold``) and return the first image's detections."""
        result = non_max_suppression(result, conf_thres=self.conf_threshold, iou_thres=0.65)[0]
        return result

    @staticmethod
    def _split_detection(row):
        """Split one NMS output row into ``(box_xyxy, confidence, class_index)``.

        Ultralytics' ``non_max_suppression`` documents rows as
        ``(x1, y1, x2, y2, confidence, class, ...)``.
        """
        return row[:4], float(row[4]), int(row[5])

    def draw_detections(self, box, color, name, img):
        """Draw ``box`` (x1, y1, x2, y2) and its ``name`` label onto ``img`` in ``color``."""
        xmin, ymin, xmax, ymax = list(map(int, list(box)))
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)
        cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)
        return img

    def renormalize_cam_in_bounding_boxes(self, boxes, image_float_np, grayscale_cam):
        """Keep the CAM only inside ``boxes`` (rescaled per box) and overlay it on the image.

        Args:
            boxes: integer (x1, y1, x2, y2) rows in CAM pixel coordinates.
            image_float_np: float RGB image in [0, 1]; assumed to match the CAM size.
            grayscale_cam: 2-D CAM indexed as [row, column].

        Returns:
            The overlay produced by ``show_cam_on_image``.
        """
        renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
        for x1, y1, x2, y2 in boxes:
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(grayscale_cam.shape[1] - 1, x2), min(grayscale_cam.shape[0] - 1, y2)
            if x2 <= x1 or y2 <= y1:
                # Degenerate or fully clipped box: the slice would be empty.
                continue
            renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
        renormalized_cam = scale_cam_image(renormalized_cam)
        return show_cam_on_image(image_float_np, renormalized_cam, use_rgb=True)

    def process(self, img_path, save_path):
        """Compute the CAM for one image and save the visualization to ``save_path``.

        Unreadable files and images for which the CAM cannot be computed are
        skipped with a printed message.
        """
        img = cv2.imread(img_path)
        if img is None:
            print(f'Skipping {img_path}: not a readable image')
            return
        img = letterbox(img)[0]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.float32(img) / 255.0
        tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)

        try:
            grayscale_cam = self.method(tensor, [self.target])
        except AttributeError:
            # Presumably yolov8_target returned the int 0 (no prediction above
            # the threshold), which has no .backward().
            print(f'Skipping {img_path}: no prediction above conf_threshold={self.conf_threshold}')
            return
        grayscale_cam = grayscale_cam[0, :]
        cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)

        # Separate inference pass for the boxes; no gradients are needed here.
        with torch.no_grad():
            pred = self.model(tensor)[0]
        pred = self.post_process(pred)
        if self.renormalize:
            cam_image = self.renormalize_cam_in_bounding_boxes(
                pred[:, :4].cpu().detach().numpy().astype(np.int32), img, grayscale_cam)
        if self.show_box:
            for data in pred:
                box, conf, cls = self._split_detection(data.cpu().detach().numpy())
                cam_image = self.draw_detections(box, self.colors[cls], f'{self.model_names[cls]} {conf:.2f}', cam_image)

        Image.fromarray(cam_image).save(save_path)

    def __call__(self, img_path, save_path):
        """Process one image file, or every entry of a directory, into ``save_path``.

        WARNING: an existing ``save_path`` directory is deleted first.
        """
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
        os.makedirs(save_path, exist_ok=True)
        if os.path.isdir(img_path):
            for name in sorted(os.listdir(img_path)):
                self.process(os.path.join(img_path, name), os.path.join(save_path, name))
        else:
            self.process(img_path, os.path.join(save_path, 'result.png'))
# Parse input arguments and set up configuration parameters
def get_params(args):
    """Translate parsed command-line ``args`` into ``yolov8_heatmap`` keyword arguments.

    The weights path, device, CAM method, target layer and box drawing come from
    the CLI; the remaining options are fixed here.
    """
    return dict(
        weight=args.weights,
        device=args.device,
        method=args.method,  # GradCAMPlusPlus, GradCAM, XGradCAM, EigenCAM, HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
        layer=[args.layer],  # a list, because several layers may be targeted
        backward_type='class',  # one of 'class', 'box', 'all'
        conf_threshold=0.25,  # default confidence threshold
        ratio=0.5,  # fraction of sorted predictions used by the CAM target
        show_box=args.box,
        renormalize=False,
    )
def str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    ``type=bool`` cannot be used for this because ``bool('False')`` is True.
    """
    from argparse import ArgumentTypeError
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_arg_parser():
    """Build the command-line parser for model, method, layer and I/O selection."""
    parser = ArgumentParser(description='arguments')
    parser.add_argument('--weights', nargs='?', type=str, default='best.pt',
                        help='Path to model checkpoint')
    parser.add_argument('--method', nargs='?', type=str, default='XGradCAM',
                        help='CAM method for visualization')
    parser.add_argument('--layer', nargs='?', type=int, default=18,
                        help='Layer number to target')
    parser.add_argument('--input', nargs='?', type=str, default='DJI_0108.JPG',
                        help='Path to input image')
    # Fall back to CPU so the default does not crash on machines without CUDA.
    parser.add_argument('--device', nargs='?', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to run on: cuda, cpu, or mps')
    # Bare `--box` enables boxes; `--box True` / `--box False` also work.
    parser.add_argument('--box', nargs='?', type=str2bool, const=True, default=False,
                        help='Option to show bounding boxes')
    parser.add_argument('--output', type=str, default='result',
                        help='Output directory (deleted and recreated on every run)')
    return parser


def main(argv=None):
    """Parse ``argv`` (default: ``sys.argv[1:]``) and write heatmaps to the output directory."""
    args = build_arg_parser().parse_args(argv)
    model = yolov8_heatmap(**get_params(args))
    model(args.input, args.output)


if __name__ == '__main__':
    main()