diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index 061ae0bbe..415a1a404 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -8,6 +8,7 @@ import matplotlib import numpy as np +import numpy.typing as npt from matplotlib import cm, colors, pyplot as plt from matplotlib.axes import Axes from matplotlib.collections import LineCollection @@ -47,11 +48,11 @@ class VisualizeSign(Enum): all = 4 -def _prepare_image(attr_visual: ndarray) -> ndarray: +def _prepare_image(attr_visual: npt.NDArray) -> npt.NDArray: return np.clip(attr_visual.astype(int), 0, 255) -def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray: +def _normalize_scale(attr: npt.NDArray, scale_factor: float) -> npt.NDArray: assert scale_factor != 0, "Cannot normalize by scale factor = 0" if abs(scale_factor) < 1e-5: warnings.warn( @@ -64,7 +65,9 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray: return np.clip(attr_norm, -1, 1) -def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) -> float: +def _cumulative_sum_threshold( + values: npt.NDArray, percentile: Union[int, float] +) -> float: # given values should be non-negative assert percentile >= 0 and percentile <= 100, ( "Percentile for thresholding must be " "between 0 and 100 inclusive." @@ -72,15 +75,16 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) -> sorted_vals = np.sort(values.flatten()) cum_sums = np.cumsum(sorted_vals) threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0] + # pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`. return sorted_vals[threshold_id] def _normalize_attr( - attr: ndarray, + attr: npt.NDArray, sign: str, outlier_perc: Union[int, float] = 2, reduction_axis: Optional[int] = None, -) -> ndarray: +) -> npt.NDArray: attr_combined = attr if reduction_axis is not None: attr_combined = np.sum(attr, axis=reduction_axis) @@ -130,7 +134,7 @@ def _initialize_cmap_and_vmin_vmax( def _visualize_original_image( plt_axis: Axes, - original_image: Optional[ndarray], + original_image: Optional[npt.NDArray], **kwargs: Any, ) -> None: assert ( @@ -143,7 +147,7 @@ def _visualize_original_image( def _visualize_heat_map( plt_axis: Axes, - norm_attr: ndarray, + norm_attr: npt.NDArray, cmap: Union[str, Colormap], vmin: float, vmax: float, @@ -155,8 +159,8 @@ def _visualize_heat_map( def _visualize_blended_heat_map( plt_axis: Axes, - original_image: ndarray, - norm_attr: ndarray, + original_image: npt.NDArray, + norm_attr: npt.NDArray, cmap: Union[str, Colormap], vmin: float, vmax: float, @@ -176,8 +180,8 @@ def _visualize_blended_heat_map( def _visualize_masked_image( plt_axis: Axes, sign: str, - original_image: ndarray, - norm_attr: ndarray, + original_image: npt.NDArray, + norm_attr: npt.NDArray, **kwargs: Any, ) -> None: assert VisualizeSign[sign].value != VisualizeSign.all.value, ( @@ -190,8 +194,8 @@ def _visualize_masked_image( def _visualize_alpha_scaling( plt_axis: Axes, sign: str, - original_image: ndarray, - norm_attr: ndarray, + original_image: npt.NDArray, + norm_attr: npt.NDArray, **kwargs: Any, ) -> None: assert VisualizeSign[sign].value != VisualizeSign.all.value, ( @@ -210,8 +214,8 @@ def _visualize_alpha_scaling( def visualize_image_attr( - attr: ndarray, - original_image: Optional[ndarray] = None, + attr: npt.NDArray, + original_image: Optional[npt.NDArray] = None, method: str = "heat_map", sign: str = "absolute_value", plt_fig_axis: Optional[Tuple[Figure, Axes]] = None, @@ -417,8 +421,8 @@ def visualize_image_attr( def visualize_image_attr_multiple( - attr: ndarray, - original_image: Union[None, ndarray], + attr: npt.NDArray, + original_image: Union[None, npt.NDArray], methods: List[str], signs: List[str], titles: Optional[List[str]] = None, @@ -526,9 +530,9 @@ def visualize_image_attr_multiple( def visualize_timeseries_attr( - attr: ndarray, - data: ndarray, - x_values: Optional[ndarray] = None, + attr: npt.NDArray, + data: npt.NDArray, + x_values: Optional[npt.NDArray] = None, method: str = "overlay_individual", sign: str = "absolute_value", channel_labels: Optional[List[str]] = None, diff --git a/tests/attr/test_gradient_shap.py b/tests/attr/test_gradient_shap.py index d96042179..5193e5bba 100644 --- a/tests/attr/test_gradient_shap.py +++ b/tests/attr/test_gradient_shap.py @@ -5,6 +5,7 @@ from typing import cast, Tuple import numpy as np +import numpy.typing as npt import torch from captum._utils.typing import Tensor from captum.attr._core.gradient_shap import GradientShap @@ -132,7 +133,7 @@ def generate_baselines_with_inputs(inputs: Tensor) -> Tensor: inp_shape = cast(Tuple[int, ...], inputs.shape) return torch.arange(0.0, inp_shape[1] * 2.0).reshape(2, inp_shape[1]) - def generate_baselines_returns_array() -> ndarray: + def generate_baselines_returns_array() -> npt.NDArray: return np.arange(0.0, num_in * 4.0).reshape(4, num_in) # 10-class classification model diff --git a/tests/utils/models/linear_models/_test_linear_classifier.py b/tests/utils/models/linear_models/_test_linear_classifier.py index 39097ddd7..c144a394d 100644 --- a/tests/utils/models/linear_models/_test_linear_classifier.py +++ b/tests/utils/models/linear_models/_test_linear_classifier.py @@ -5,6 +5,7 @@ import captum._utils.models.linear_model.model as pytorch_model_module import numpy as np +import numpy.typing as npt import sklearn.datasets as datasets import torch from tests.helpers.evaluate_linear_model import evaluate @@ -107,7 +108,7 @@ def compare_to_sk_learn( o_sklearn["l1_reg"] = alpha * sklearn_h.norm(p=1, dim=-1) rel_diff = cast( - np.ndarray, + npt.NDArray, # pyre-fixme[6]: For 1st argument expected `int` but got `Union[int, Tensor]`. (sum(o_sklearn.values()) - sum(o_pytorch.values())), ) / abs(sum(o_sklearn.values()))