Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #319 from Aarhus-Psychiatry-Research/bokajgd/issue312
Browse files Browse the repository at this point in the history
feat: adding mutual information based feature selection
  • Loading branch information
bokajgd authored Nov 14, 2022
2 parents 81296b1 + 0976ee1 commit 46fc7f0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/psycopt2d/config/preprocessing/default_preprocessing.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
convert_to_boolean: false
convert_datetimes_to_ordinal: false
imputation_method: "most_frequent" # (str): Options include 2most_frequent"
imputation_method: "most_frequent" # (str): Options include "most_frequent"
transform: null # (str|null): Transformation applied to all predictors after imputation. Options include "z-score-normalization"
feature_selection:
name: null
feature_selection:
name: null # (str) Options include "f_classif", "chi2" and "mutual_info_classif". Default to use is "mutual_info_classif".
params:
percentile: 10 # (int): Percent of features to keep. Defaults to 10.
17 changes: 16 additions & 1 deletion src/psycopt2d/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import wandb
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from sklearn.feature_selection import SelectPercentile, chi2, f_classif
from sklearn.feature_selection import (
SelectPercentile,
chi2,
f_classif,
mutual_info_classif,
)
from sklearn.impute import SimpleImputer
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedGroupKFold
Expand Down Expand Up @@ -90,6 +95,16 @@ def create_preprocessing_pipeline(cfg: FullConfigSchema):
),
),
)
if cfg.preprocessing.feature_selection.name == "mutual_info_classif":
steps.append(
(
"feature_selection",
SelectPercentile(
mutual_info_classif,
percentile=cfg.preprocessing.feature_selection.params["percentile"],
),
),
)

return Pipeline(steps)

Expand Down

0 comments on commit 46fc7f0

Please sign in to comment.