diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index e98282cc9..986f61c45 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -109,6 +109,28 @@ def _normalize_attr( return _normalize_scale(attr_combined, threshold) +def _create_default_plot( + # pyre-fixme[2]: Parameter must be annotated. + plt_fig_axis, + # pyre-fixme[2]: Parameter must be annotated. + use_pyplot, + # pyre-fixme[2]: Parameter must be annotated. + fig_size, + **pyplot_kwargs: Any, +) -> Tuple[Figure, Axes]: + # Create plot if figure, axis not provided + if plt_fig_axis is not None: + plt_fig, plt_axis = plt_fig_axis + else: + if use_pyplot: + plt_fig, plt_axis = plt.subplots(figsize=fig_size, **pyplot_kwargs) + else: + plt_fig = Figure(figsize=fig_size) + plt_axis = plt_fig.subplots(**pyplot_kwargs) + return plt_fig, plt_axis + # Figure.subplots returns Axes or array of Axes + + def _initialize_cmap_and_vmin_vmax( sign: str, ) -> Tuple[Union[str, Colormap], float, float]: @@ -338,16 +360,7 @@ def visualize_image_attr( >>> # Displays blended heat map visualization of computed attributions. >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map") """ - # Create plot if figure, axis not provided - if plt_fig_axis is not None: - plt_fig, plt_axis = plt_fig_axis - else: - if use_pyplot: - plt_fig, plt_axis = plt.subplots(figsize=fig_size) - else: - plt_fig = Figure(figsize=fig_size) - plt_axis = plt_fig.subplots() - # Figure.subplots returns Axes or array of Axes + plt_fig, plt_axis = _create_default_plot(plt_fig_axis, use_pyplot, fig_size) if original_image is not None: if np.max(original_image) <= 1.0: @@ -362,8 +375,10 @@ def visualize_image_attr( ) # Remove ticks and tick labels from plot. - plt_axis.xaxis.set_ticks_position("none") - plt_axis.yaxis.set_ticks_position("none") + if plt_axis.xaxis is not None: + plt_axis.xaxis.set_ticks_position("none") + if plt_axis.yaxis is not None: + plt_axis.yaxis.set_ticks_position("none") plt_axis.set_yticklabels([]) plt_axis.set_xticklabels([]) plt_axis.grid(visible=False) @@ -528,6 +543,161 @@ def visualize_image_attr_multiple( return plt_fig, plt_axis +def _plot_attrs_as_axvspan( + # pyre-fixme[2]: Parameter must be annotated. + attr_vals, + # pyre-fixme[2]: Parameter must be annotated. + x_vals, + # pyre-fixme[2]: Parameter must be annotated. + ax, + # pyre-fixme[2]: Parameter must be annotated. + x_values, + # pyre-fixme[2]: Parameter must be annotated. + cmap, + # pyre-fixme[2]: Parameter must be annotated. + cm_norm, + # pyre-fixme[2]: Parameter must be annotated. + alpha_overlay, +) -> None: + # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. + half_col_width = (x_values[1] - x_values[0]) / 2.0 + + for icol, col_center in enumerate(x_vals): + left = col_center - half_col_width + right = col_center + half_col_width + ax.axvspan( + xmin=left, + xmax=right, + # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function. + facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore + edgecolor=None, + alpha=alpha_overlay, + ) + + +def _visualize_overlay_individual( + # pyre-fixme[2]: Parameter must be annotated. + num_channels, + # pyre-fixme[2]: Parameter must be annotated. + plt_axis_list, + # pyre-fixme[2]: Parameter must be annotated. + x_values, + # pyre-fixme[2]: Parameter must be annotated. + data, + # pyre-fixme[2]: Parameter must be annotated. + channel_labels, + # pyre-fixme[2]: Parameter must be annotated. + norm_attr, + # pyre-fixme[2]: Parameter must be annotated. + cmap, + # pyre-fixme[2]: Parameter must be annotated. + cm_norm, + # pyre-fixme[2]: Parameter must be annotated. + alpha_overlay, + # pyre-fixme[2]: Parameter must be annotated. + **kwargs: Any, +) -> None: + # helper method for visualize_timeseries_attr + pyplot_kwargs = kwargs.get("pyplot_kwargs", {}) + for chan in range(num_channels): + plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs) + if channel_labels is not None: + plt_axis_list[chan].set_ylabel(channel_labels[chan]) + + _plot_attrs_as_axvspan( + norm_attr[chan], + x_values, + plt_axis_list[chan], + x_values, + cmap, + cm_norm, + alpha_overlay, + ) + + plt.subplots_adjust(hspace=0) + pass + + +def _visualize_overlay_combined( + # pyre-fixme[2]: Parameter must be annotated. + num_channels, + # pyre-fixme[2]: Parameter must be annotated. + plt_axis_list, + # pyre-fixme[2]: Parameter must be annotated. + x_values, + # pyre-fixme[2]: Parameter must be annotated. + data, + # pyre-fixme[2]: Parameter must be annotated. + channel_labels, + # pyre-fixme[2]: Parameter must be annotated. + norm_attr, + # pyre-fixme[2]: Parameter must be annotated. + cmap, + # pyre-fixme[2]: Parameter must be annotated. + cm_norm, + # pyre-fixme[2]: Parameter must be annotated. + alpha_overlay, + **kwargs: Any, +) -> None: + pyplot_kwargs = kwargs.get("pyplot_kwargs", {}) + + cycler = plt.cycler("color", matplotlib.colormaps["Dark2"].colors) # type: ignore + plt_axis_list[0].set_prop_cycle(cycler) + + for chan in range(num_channels): + label = channel_labels[chan] if channel_labels else None + plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs) + + _plot_attrs_as_axvspan( + norm_attr, + x_values, + plt_axis_list[0], + x_values, + cmap, + cm_norm, + alpha_overlay, + ) + + plt_axis_list[0].legend(loc="best") + + +def _visualize_colored_graph( + # pyre-fixme[2]: Parameter must be annotated. + num_channels, + # pyre-fixme[2]: Parameter must be annotated. + plt_axis_list, + # pyre-fixme[2]: Parameter must be annotated. + x_values, + # pyre-fixme[2]: Parameter must be annotated. + data, + # pyre-fixme[2]: Parameter must be annotated. + channel_labels, + # pyre-fixme[2]: Parameter must be annotated. + norm_attr, + # pyre-fixme[2]: Parameter must be annotated. + cmap, + # pyre-fixme[2]: Parameter must be annotated. + cm_norm, + **kwargs: Any, +) -> None: + # helper method for visualize_timeseries_attr + pyplot_kwargs = kwargs.get("pyplot_kwargs", {}) + for chan in range(num_channels): + points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + + lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs) + lc.set_array(norm_attr[chan, :]) + plt_axis_list[chan].add_collection(lc) + plt_axis_list[chan].set_ylim( + 1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :]) + ) + if channel_labels is not None: + plt_axis_list[chan].set_ylabel(channel_labels[chan]) + + plt.subplots_adjust(hspace=0) + + def visualize_timeseries_attr( attr: npt.NDArray, data: npt.NDArray, @@ -686,8 +856,8 @@ def visualize_timeseries_attr( num_subplots = num_channels if ( - TimeseriesVisualizationMethod[method] - == TimeseriesVisualizationMethod.overlay_combined + TimeseriesVisualizationMethod[method].value + == TimeseriesVisualizationMethod.overlay_combined.value ): num_subplots = 1 attr = np.sum(attr, axis=0) # Merge attributions across channels @@ -700,17 +870,9 @@ def visualize_timeseries_attr( x_values = np.arange(timeseries_length) # Create plot if figure, axis not provided - if plt_fig_axis is not None: - plt_fig, plt_axis = plt_fig_axis - else: - if use_pyplot: - plt_fig, plt_axis = plt.subplots( # type: ignore - figsize=fig_size, nrows=num_subplots, sharex=True - ) - else: - plt_fig = Figure(figsize=fig_size) - plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True) # type: ignore - # Figure.subplots returns Axes or array of Axes + plt_fig, plt_axis = _create_default_plot( + plt_fig_axis, use_pyplot, fig_size, nrows=num_subplots, sharex=True + ) if not isinstance(plt_axis, ndarray): plt_axis_list = np.array([plt_axis]) @@ -720,91 +882,30 @@ def visualize_timeseries_attr( norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None) # Set default colormap and bounds based on sign. - if VisualizeSign[sign] == VisualizeSign.all: - default_cmap: Union[str, LinearSegmentedColormap] = ( - LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"]) - ) - vmin, vmax = -1, 1 - elif VisualizeSign[sign] == VisualizeSign.positive: - default_cmap = "Greens" - vmin, vmax = 0, 1 - elif VisualizeSign[sign] == VisualizeSign.negative: - default_cmap = "Reds" - vmin, vmax = 0, 1 - elif VisualizeSign[sign] == VisualizeSign.absolute_value: - default_cmap = "Blues" - vmin, vmax = 0, 1 - else: - raise AssertionError("Visualize Sign type is not valid.") + default_cmap, vmin, vmax = _initialize_cmap_and_vmin_vmax(sign) cmap = cmap if cmap is not None else default_cmap cmap = cm.get_cmap(cmap) # type: ignore cm_norm = colors.Normalize(vmin, vmax) - # pyre-fixme[53]: Captured variable `cm_norm` is not annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _plot_attrs_as_axvspan(attr_vals, x_vals, ax) -> None: - # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. - half_col_width = (x_values[1] - x_values[0]) / 2.0 - for icol, col_center in enumerate(x_vals): - left = col_center - half_col_width - right = col_center + half_col_width - ax.axvspan( - xmin=left, - xmax=right, - # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function. - facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore - edgecolor=None, - alpha=alpha_overlay, - ) - - if ( - TimeseriesVisualizationMethod[method] - == TimeseriesVisualizationMethod.overlay_individual - ): - for chan in range(num_channels): - plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs) - if channel_labels is not None: - plt_axis_list[chan].set_ylabel(channel_labels[chan]) - - _plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis_list[chan]) - - plt.subplots_adjust(hspace=0) - - elif ( - TimeseriesVisualizationMethod[method] - == TimeseriesVisualizationMethod.overlay_combined - ): - # Dark colors are better in this case - cycler = plt.cycler("color", matplotlib.colormaps["Dark2"]) # type: ignore - plt_axis_list[0].set_prop_cycle(cycler) - - for chan in range(num_channels): - label = channel_labels[chan] if channel_labels else None - plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs) - - _plot_attrs_as_axvspan(norm_attr, x_values, plt_axis_list[0]) - - plt_axis_list[0].legend(loc="best") - - elif ( - TimeseriesVisualizationMethod[method] - == TimeseriesVisualizationMethod.colored_graph - ): - for chan in range(num_channels): - points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2) - segments = np.concatenate([points[:-1], points[1:]], axis=1) - - lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs) - lc.set_array(norm_attr[chan, :]) - plt_axis_list[chan].add_collection(lc) - plt_axis_list[chan].set_ylim( - 1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :]) - ) - if channel_labels is not None: - plt_axis_list[chan].set_ylabel(channel_labels[chan]) - - plt.subplots_adjust(hspace=0) - + visualization_methods: Dict[str, Callable[..., Union[None, AxesImage]]] = { + "overlay_individual": _visualize_overlay_individual, + "overlay_combined": _visualize_overlay_combined, + "colored_graph": _visualize_colored_graph, + } + kwargs = { + "num_channels": num_channels, + "plt_axis_list": plt_axis_list, + "x_values": x_values, + "data": data, + "channel_labels": channel_labels, + "norm_attr": norm_attr, + "cmap": cmap, + "cm_norm": cm_norm, + "alpha_overlay": alpha_overlay, + "pyplot_kwargs": pyplot_kwargs, + } + if method in visualization_methods: + visualization_methods[method](**kwargs) else: raise AssertionError("Invalid visualization method: {}".format(method))