diff --git a/pipit/trace.py b/pipit/trace.py index fcc9f75..7594693 100644 --- a/pipit/trace.py +++ b/pipit/trace.py @@ -585,10 +585,14 @@ def load_imbalance(self, metric="time.exc", num_processes=1): for function in functions: curr_series = flat_profile.loc[function] - top_n = curr_series.sort_values(ascending=False).iloc[0:num_display] + top_n = curr_series.sort_values(by=metric, ascending=False).iloc[ + 0:num_display + ] - imbalance_dict[mean_metric].append(curr_series.mean()) - imbalance_dict[imb_metric].append(top_n.values[0] / curr_series.mean()) + imbalance_dict[mean_metric].append(curr_series.mean().values[0]) + imbalance_dict[imb_metric].append( + (top_n.values[0] / curr_series.mean()).values[0] + ) imbalance_dict[imb_ranks].append(list(top_n.index)) imbalance_df = pd.DataFrame(imbalance_dict)