From f69899c8a47285b70a8408179adce69ae1cfa045 Mon Sep 17 00:00:00 2001 From: Thomas Bury Date: Tue, 30 Apr 2024 12:19:35 +0200 Subject: [PATCH] fix: :bug: fix shap imp calculation for multiclass --- .../feature_selection/variable_importance.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/arfs/feature_selection/variable_importance.py b/src/arfs/feature_selection/variable_importance.py index 82b3183..f995c42 100644 --- a/src/arfs/feature_selection/variable_importance.py +++ b/src/arfs/feature_selection/variable_importance.py @@ -398,15 +398,21 @@ def _compute_varimp_lgb( # X_SHAP_values (array-like of shape = [n_samples, n_features + 1] # or shape = [n_samples, (n_features + 1) * n_classes]) # index starts from 0 - n_feat = gbm_model.valid_features.shape[1] + n_features_plus_bias = gbm_model.valid_features.shape[1] + 1 + n_samples = gbm_model.valid_features.shape[0] y_freq_table = pd.Series(y.fillna(0)).value_counts(normalize=True) n_classes = y_freq_table.size - shap_matrix = np.delete( - shap_matrix, - list(range(n_feat, (n_feat + 1) * n_classes, n_feat + 1)), - axis=1, - ) - shap_imp = np.mean(np.abs(shap_matrix[:, :-1]), axis=0) + + # Reshape the array to [n_samples, n_features + 1, n_classes] + reshaped_values = shap_matrix.reshape(n_samples, n_classes, n_features_plus_bias) + + # Since we need (n_samples, n_features + 1, n_classes), transpose the second and third dimensions + reshaped_values = reshaped_values.transpose(0, 2, 1) + reshaped_values = reshaped_values[:, :-1, :] + reshaped_values.shape + # Sum the contributions for each class ignoring the bias term + # average on all the samples + shap_imp = np.abs(reshaped_values).sum(axis=-1).mean(axis=0) else: # for binary, only one class is returned, for regression a single column added as well shap_imp = np.mean(np.abs(shap_matrix[:, :-1]), axis=0)