Skip to content

Commit

Permalink
Support for 3D Conv-Net (#466)
Browse files Browse the repository at this point in the history
* Modify grad-cam and base-cam to support 3d conv.

* Add image examples for 3D convolutions.

* Modify get_cam_image to increase readbability.

---------

Co-authored-by: Jacob Gildenblat <jacob.gildenblat@gmail.com>
  • Loading branch information
kevinkevin556 and jacobgil authored May 28, 2024
1 parent f0371ab commit 3f6b14d
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 91 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
| -----------------|-----------------------|
| <img src="./examples/both_detection.png" width="256" height="256"> | <img src="./examples/cars_segmentation.png" width="256" height="200"> |

| Semantic Segmentation (3D) |
| -------------------------- |
| <img src="./examples/multiorgan_segmentation.gif" width="539">|

## Explaining similarity to other images / embeddings
<img src="./examples/embeddings.png">

Expand Down
Binary file added examples/multiorgan_segmentation.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
162 changes: 77 additions & 85 deletions pytorch_grad_cam/base_cam.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import ttach as tta
from typing import Callable, List, Tuple, Optional

from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
from pytorch_grad_cam.utils.image import scale_cam_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection


class BaseCAM:
def __init__(self,
model: torch.nn.Module,
target_layers: List[torch.nn.Module],
reshape_transform: Callable = None,
compute_input_gradient: bool = False,
uses_gradients: bool = True,
tta_transforms: Optional[tta.Compose] = None) -> None:
def __init__(
self,
model: torch.nn.Module,
target_layers: List[torch.nn.Module],
reshape_transform: Callable = None,
compute_input_gradient: bool = False,
uses_gradients: bool = True,
tta_transforms: Optional[tta.Compose] = None,
) -> None:
self.model = model.eval()
self.target_layers = target_layers

Expand All @@ -34,63 +38,64 @@ def __init__(self,
else:
self.tta_transforms = tta_transforms

self.activations_and_grads = ActivationsAndGradients(
self.model, target_layers, reshape_transform)
self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)

""" Get a vector of weights for every channel in the target layer.
Methods that return weights channels,
will typically need to only implement this function. """

def get_cam_weights(self,
input_tensor: torch.Tensor,
target_layers: List[torch.nn.Module],
targets: List[torch.nn.Module],
activations: torch.Tensor,
grads: torch.Tensor) -> np.ndarray:
def get_cam_weights(
self,
input_tensor: torch.Tensor,
target_layers: List[torch.nn.Module],
targets: List[torch.nn.Module],
activations: torch.Tensor,
grads: torch.Tensor,
) -> np.ndarray:
raise Exception("Not Implemented")

def get_cam_image(self,
input_tensor: torch.Tensor,
target_layer: torch.nn.Module,
targets: List[torch.nn.Module],
activations: torch.Tensor,
grads: torch.Tensor,
eigen_smooth: bool = False) -> np.ndarray:

weights = self.get_cam_weights(input_tensor,
target_layer,
targets,
activations,
grads)
weighted_activations = weights[:, :, None, None] * activations
def get_cam_image(
self,
input_tensor: torch.Tensor,
target_layer: torch.nn.Module,
targets: List[torch.nn.Module],
activations: torch.Tensor,
grads: torch.Tensor,
eigen_smooth: bool = False,
) -> np.ndarray:
weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
# 2D conv
if len(activations.shape) == 4:
weighted_activations = weights[:, :, None, None] * activations
# 3D conv
elif len(activations.shape) == 5:
weighted_activations = weights[:, :, None, None, None] * activations
else:
raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.")

if eigen_smooth:
cam = get_2d_projection(weighted_activations)
else:
cam = weighted_activations.sum(axis=1)
return cam

def forward(self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module],
eigen_smooth: bool = False) -> np.ndarray:

def forward(
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
) -> np.ndarray:
input_tensor = input_tensor.to(self.device)

if self.compute_input_gradient:
input_tensor = torch.autograd.Variable(input_tensor,
requires_grad=True)
input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)

self.outputs = outputs = self.activations_and_grads(input_tensor)

if targets is None:
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
targets = [ClassifierOutputTarget(
category) for category in target_categories]
targets = [ClassifierOutputTarget(category) for category in target_categories]

if self.uses_gradients:
self.model.zero_grad()
loss = sum([target(output)
for target, output in zip(targets, outputs)])
loss = sum([target(output) for target, output in zip(targets, outputs)])
loss.backward(retain_graph=True)

# In most of the saliency attribution papers, the saliency is
Expand All @@ -102,25 +107,24 @@ def forward(self,
# This gives you more flexibility in case you just want to
# use all conv layers for example, all Batchnorm layers,
# or something else.
cam_per_layer = self.compute_cam_per_layer(input_tensor,
targets,
eigen_smooth)
cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
return self.aggregate_multi_layers(cam_per_layer)

def get_target_width_height(self,
input_tensor: torch.Tensor) -> Tuple[int, int]:
width, height = input_tensor.size(-1), input_tensor.size(-2)
return width, height
def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]:
if len(input_tensor.shape) == 4:
width, height = input_tensor.size(-1), input_tensor.size(-2)
return width, height
elif len(input_tensor.shape) == 5:
depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)
return depth, width, height
else:
raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")

def compute_cam_per_layer(
self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module],
eigen_smooth: bool) -> np.ndarray:
activations_list = [a.cpu().data.numpy()
for a in self.activations_and_grads.activations]
grads_list = [g.cpu().data.numpy()
for g in self.activations_and_grads.gradients]
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
) -> np.ndarray:
activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
target_size = self.get_target_width_height(input_tensor)

cam_per_target_layer = []
Expand All @@ -134,36 +138,26 @@ def compute_cam_per_layer(
if i < len(grads_list):
layer_grads = grads_list[i]

cam = self.get_cam_image(input_tensor,
target_layer,
targets,
layer_activations,
layer_grads,
eigen_smooth)
cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
cam = np.maximum(cam, 0)
scaled = scale_cam_image(cam, target_size)
cam_per_target_layer.append(scaled[:, None, :])

return cam_per_target_layer

def aggregate_multi_layers(
self,
cam_per_target_layer: np.ndarray) -> np.ndarray:
def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:
cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
result = np.mean(cam_per_target_layer, axis=1)
return scale_cam_image(result)

def forward_augmentation_smoothing(self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module],
eigen_smooth: bool = False) -> np.ndarray:
def forward_augmentation_smoothing(
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
) -> np.ndarray:
cams = []
for transform in self.tta_transforms:
augmented_tensor = transform.augment_image(input_tensor)
cam = self.forward(augmented_tensor,
targets,
eigen_smooth)
cam = self.forward(augmented_tensor, targets, eigen_smooth)

# The ttach library expects a tensor of size BxCxHxW
cam = cam[:, None, :, :]
Expand All @@ -178,19 +172,18 @@ def forward_augmentation_smoothing(self,
cam = np.mean(np.float32(cams), axis=0)
return cam

def __call__(self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module] = None,
aug_smooth: bool = False,
eigen_smooth: bool = False) -> np.ndarray:

def __call__(
self,
input_tensor: torch.Tensor,
targets: List[torch.nn.Module] = None,
aug_smooth: bool = False,
eigen_smooth: bool = False,
) -> np.ndarray:
# Smooth the CAM result with test time augmentation
if aug_smooth is True:
return self.forward_augmentation_smoothing(
input_tensor, targets, eigen_smooth)
return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)

return self.forward(input_tensor,
targets, eigen_smooth)
return self.forward(input_tensor, targets, eigen_smooth)

def __del__(self):
self.activations_and_grads.release()
Expand All @@ -202,6 +195,5 @@ def __exit__(self, exc_type, exc_value, exc_tb):
self.activations_and_grads.release()
if isinstance(exc_value, IndexError):
# Handle IndexError here...
print(
f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
return True
13 changes: 12 additions & 1 deletion pytorch_grad_cam/grad_cam.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np

from pytorch_grad_cam.base_cam import BaseCAM


Expand All @@ -18,4 +19,14 @@ def get_cam_weights(self,
target_category,
activations,
grads):
return np.mean(grads, axis=(2, 3))
# 2D image
if len(grads.shape) == 4:
return np.mean(grads, axis=(2, 3))

# 3D image
elif len(grads.shape) == 5:
return np.mean(grads, axis=(2, 3, 4))

else:
raise ValueError("Invalid grads shape."
"Shape of grads should be 4 (2D image) or 5 (3D image).")
16 changes: 11 additions & 5 deletions pytorch_grad_cam/utils/image.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import math
from typing import Dict, List

import cv2
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from scipy.ndimage import zoom
from torchvision.transforms import Compose, Normalize, ToTensor
from typing import List, Dict
import math


def preprocess_image(
Expand Down Expand Up @@ -163,7 +165,11 @@ def scale_cam_image(cam, target_size=None):
img = img - np.min(img)
img = img / (1e-7 + np.max(img))
if target_size is not None:
if len(img.shape) > 3:
img = zoom(np.float32(img), [(t_s/i_s) for i_s, t_s in zip(img.shape, target_size[::-1])])
else:
img = cv2.resize(np.float32(img), target_size)

result.append(img)
result = np.float32(result)

Expand Down

0 comments on commit 3f6b14d

Please sign in to comment.