diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index 986f61c45..223428244 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -75,6 +75,7 @@ def _cumulative_sum_threshold( sorted_vals = np.sort(values.flatten()) cum_sums = np.cumsum(sorted_vals) threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0] + # pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`. return sorted_vals[threshold_id]