Skip to content

Commit

Permalink
fix: 🐛 fix shap imp calculation for multiclass
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Bury committed Apr 30, 2024
1 parent 9fa0bf9 commit f69899c
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/arfs/feature_selection/variable_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,21 @@ def _compute_varimp_lgb(
# X_SHAP_values (array-like of shape = [n_samples, n_features + 1]
# or shape = [n_samples, (n_features + 1) * n_classes])
# index starts from 0
n_feat = gbm_model.valid_features.shape[1]
n_features_plus_bias = gbm_model.valid_features.shape[1] + 1
n_samples = gbm_model.valid_features.shape[0]
y_freq_table = pd.Series(y.fillna(0)).value_counts(normalize=True)
n_classes = y_freq_table.size
shap_matrix = np.delete(
shap_matrix,
list(range(n_feat, (n_feat + 1) * n_classes, n_feat + 1)),
axis=1,
)
shap_imp = np.mean(np.abs(shap_matrix[:, :-1]), axis=0)

# Reshape the array to [n_samples, n_classes, n_features + 1]
reshaped_values = shap_matrix.reshape(n_samples, n_classes, n_features_plus_bias)

# Since we need (n_samples, n_features + 1, n_classes), transpose the second and third dimensions
reshaped_values = reshaped_values.transpose(0, 2, 1)
reshaped_values = reshaped_values[:, :-1, :]
reshaped_values.shape
# Sum the absolute contributions across classes (the bias term was dropped above),
# then average over all samples
shap_imp = np.abs(reshaped_values).sum(axis=-1).mean(axis=0)
else:
# For binary classification only one class is returned; for regression a single column is added as well
shap_imp = np.mean(np.abs(shap_matrix[:, :-1]), axis=0)
Expand Down

0 comments on commit f69899c

Please sign in to comment.