-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sczbp feature importance and misc (#881)
<!-- Reviews go much faster if the reviewer knows what to focus on! Help them out, e.g.: Reviewers can skip X, but should pay attention to Y. -->
- Loading branch information
Showing
44 changed files
with
1,954 additions
and
339 deletions.
There are no files selected for viewing
2 changes: 2 additions & 0 deletions
2
...configs/estimator_steps/miss_forest_imputation/miss_forest_imputation_20240416_132957.cfg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[placeholder] | ||
@estimator_steps = "miss_forest_imputation" |
2 changes: 2 additions & 0 deletions
2
...ical_registry_configs/estimator_steps/noop_imputation/noop_imputation_20240416_132957.cfg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[placeholder] | ||
@estimator_steps = "noop_imputation" |
3 changes: 3 additions & 0 deletions
3
..._registry_configs/estimator_steps/simple_imputation/simple_imputation_20240416_132957.cfg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[placeholder] | ||
@estimator_steps = "simple_imputation" | ||
strategy = "mean" |
3 changes: 3 additions & 0 deletions
3
.../estimator_steps_suggesters/imputation_suggester/imputation_suggester_20240416_132957.cfg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[placeholder] | ||
@estimator_steps_suggesters = "imputation_suggester" | ||
strategies = ["most_frequent", "mean"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
psycop/common/model_training_v2/trainer/task/estimator_steps/imputers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Literal | ||
|
||
import optuna | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from sklearn.ensemble import RandomForestRegressor | ||
from sklearn.experimental import enable_iterative_imputer # noqa | ||
from sklearn.impute import IterativeImputer, SimpleImputer | ||
|
||
from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry | ||
from psycop.common.model_training_v2.hyperparameter_suggester.suggesters.base_suggester import ( | ||
Suggester, | ||
) | ||
from psycop.common.model_training_v2.hyperparameter_suggester.suggesters.suggester_spaces import ( | ||
CategoricalSpace, | ||
CategoricalSpaceT, | ||
) | ||
from psycop.common.model_training_v2.trainer.task.model_step import ModelStep | ||
|
||
|
||
class IdentityTransformer(BaseEstimator, TransformerMixin): | ||
def __init__(self): | ||
pass | ||
|
||
def fit(self, input_array, y=None): # type: ignore # noqa | ||
return self | ||
|
||
def transform(self, input_array, y=None): # type: ignore # noqa | ||
return input_array | ||
|
||
|
||
@BaselineRegistry.estimator_steps.register("noop_imputation") | ||
def noop_imputation_step() -> ModelStep: | ||
return ("imputer", IdentityTransformer()) | ||
|
||
|
||
@BaselineRegistry.estimator_steps.register("simple_imputation") | ||
def simple_imputation_step( | ||
strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean", | ||
) -> ModelStep: | ||
return ("imputer", SimpleImputer(strategy=strategy)) | ||
|
||
|
||
@BaselineRegistry.estimator_steps.register("miss_forest_imputation") | ||
def miss_forest_imputation_step() -> ModelStep: | ||
"""Naive implementation of missforest using sklearn's IterativeImputer""" | ||
|
||
return ("imputer", IterativeImputer(estimator=RandomForestRegressor(), random_state=0)) | ||
|
||
|
||
IMPLEMENTED_STRATEGIES = ["mean", "median", "most_frequent", "miss_forest", "noop"] | ||
|
||
STRATEGY2STEP = { | ||
"mean": "simple_imputation", | ||
"median": "simple_imputation", | ||
"most_frequent": "simple_imputation", | ||
"miss_forest": "miss_forest_imputation", | ||
"noop": "noop_imputation", | ||
} | ||
|
||
|
||
@BaselineRegistry.estimator_steps_suggesters.register("imputation_suggester") | ||
class ImputationSuggester(Suggester): | ||
def __init__(self, strategies: CategoricalSpaceT): | ||
for strategy in strategies: | ||
if strategy not in IMPLEMENTED_STRATEGIES: | ||
raise ValueError(f"Imputation strategy {strategy} is not implemented") | ||
|
||
self.strategy = CategoricalSpace(choices=strategies) | ||
|
||
def suggest_hyperparameters(self, trial: optuna.Trial) -> dict[str, str]: | ||
strategy = self.strategy.suggest(trial, "imputation_strategy") | ||
estimator_step_str = STRATEGY2STEP[strategy] | ||
|
||
if strategy in ["miss_forest", "noop"]: | ||
return {"@estimator_steps": estimator_step_str} | ||
return {"@estimator_steps": estimator_step_str, "strategy": strategy} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file removed
0
psycop/projects/scz_bp/evaluation/model_performance/feature_importance/__init__.py
Empty file.
103 changes: 103 additions & 0 deletions
103
...jects/scz_bp/evaluation/model_performance/feature_importance/scz_bp_feature_importance.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# type: ignore | ||
import pickle as pkl | ||
import re | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
import polars as pl | ||
from sklearn.pipeline import Pipeline | ||
|
||
from psycop.common.global_utils.mlflow.mlflow_data_extraction import MlflowClientWrapper | ||
|
||
|
||
def scz_bp_parse_static_feature(full_string: str) -> str: | ||
"""Takes a static feature name and returns a human readable version of it.""" | ||
feature_name = full_string.replace("pred_", "") | ||
|
||
feature_capitalised = feature_name[0].upper() + feature_name[1:] | ||
|
||
manual_overrides = {"Age_in_years": "Age (years)"} | ||
|
||
if feature_capitalised in manual_overrides: | ||
feature_capitalised = manual_overrides[feature_capitalised] | ||
return feature_capitalised | ||
|
||
|
||
def scz_bp_parse_temporal_feature(full_string: str) -> str: | ||
feature_name = re.findall(r"pred_(.*)?_within", full_string)[0] | ||
if "_disorders" in feature_name: | ||
words = feature_name.split("_") | ||
words[0] = words[0].capitalize() | ||
feature_name = " ".join(word for word in words) | ||
|
||
lookbehind = re.findall(r"within_(.*)?_days", full_string)[0] | ||
resolve_multiple = re.findall(r"days_(.*)?_fallback", full_string)[0] | ||
|
||
remove = ["all_relevant_", "aktuelt_psykisk_", r"_layer_\d_*"] | ||
remove = "(%s)" % "|".join(remove) | ||
|
||
feature_name = re.sub(remove, "", feature_name) | ||
output_string = f"{feature_name} {lookbehind}-day {resolve_multiple} " | ||
return output_string | ||
|
||
|
||
def scz_bp_feature_name_to_readable(full_string: str) -> str: | ||
if "within" not in full_string: | ||
output_string = scz_bp_parse_static_feature(full_string) | ||
else: | ||
output_string = scz_bp_parse_temporal_feature(full_string=full_string) | ||
return output_string | ||
|
||
|
||
def scz_bp_generate_feature_importance_table( | ||
pipeline: Pipeline, clf_model_name: str = "classifier" | ||
) -> pd.DataFrame: | ||
# Get feature importance scores | ||
feature_importances = pipeline.named_steps[clf_model_name].feature_importances_ | ||
|
||
if hasattr(pipeline.named_steps[clf_model_name], "feature_names"): | ||
selected_feature_names = pipeline.named_steps[clf_model_name].feature_names | ||
elif hasattr(pipeline.named_steps[clf_model_name], "feature_name_"): | ||
selected_feature_names = pipeline.named_steps[clf_model_name].feature_name_ | ||
elif hasattr(pipeline.named_steps[clf_model_name], "feature_names_in_"): | ||
selected_feature_names = pipeline.named_steps[clf_model_name].feature_names_in_ | ||
else: | ||
raise ValueError("The classifier does not implement .feature_names or .feature_name_") | ||
|
||
# Create a DataFrame to store the feature names and their corresponding gain | ||
feature_table = pl.DataFrame( | ||
{"Feature Name": selected_feature_names, "Feature Importance": feature_importances} | ||
) | ||
|
||
# Sort the table by gain in descending order | ||
feature_table = feature_table.sort("Feature Importance", descending=True) | ||
# Get the top 100 features by gain | ||
top_100_features = feature_table.head(100).with_columns( | ||
# pl.col("Feature Importance").round(3), # noqa: ERA001 | ||
pl.col("Feature Name").apply(lambda x: scz_bp_feature_name_to_readable(x)) | ||
) | ||
|
||
pd_df = top_100_features.to_pandas() | ||
pd_df = pd_df.reset_index() | ||
pd_df["index"] = pd_df["index"] + 1 | ||
pd_df = pd_df.set_index("index") | ||
|
||
return pd_df | ||
|
||
|
||
if __name__ == "__main__": | ||
best_experiment = "sczbp/structured_text_xgboost_ddpm_3x_positives" | ||
best_run = MlflowClientWrapper().get_best_run_from_experiment( | ||
experiment_name=best_experiment, metric="all_oof_BinaryAUROC" | ||
) | ||
|
||
with best_run.download_artifact("sklearn_pipe.pkl").open("rb") as pipe_pkl: | ||
pipe = pkl.load(pipe_pkl) | ||
|
||
feat_imp = scz_bp_generate_feature_importance_table(pipeline=pipe, clf_model_name="classifier") | ||
pl.Config.set_tbl_rows(100) | ||
|
||
with (Path(__file__).parent / f"feat_imp_100_{best_experiment.split('/')[1]}.html").open( | ||
"w" | ||
) as html_file: | ||
html_file.write(feat_imp.to_html()) |
39 changes: 0 additions & 39 deletions
39
psycop/projects/scz_bp/evaluation/model_performance/feature_importance/scz_bp_gain.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.