diff --git a/modnet/preprocessing.py b/modnet/preprocessing.py index f7690d0..6160b23 100644 --- a/modnet/preprocessing.py +++ b/modnet/preprocessing.py @@ -797,6 +797,7 @@ def feature_selection( drop_thr: float = 0.2, n_jobs: int = None, ignore_names: Optional[List] = [], + random_state: int = None, ): """Compute the mutual information between features and targets, then apply relevance-redundancy rankings to choose the top `n` @@ -859,7 +860,11 @@ def feature_selection( else: df = self.df_featurized.copy() self.cross_nmi, self.feature_entropy = get_cross_nmi( - df, return_entropy=True, drop_thr=drop_thr, n_jobs=n_jobs + df, + return_entropy=True, + drop_thr=drop_thr, + n_jobs=n_jobs, + random_state=random_state, ) if self.cross_nmi.isna().sum().sum() > 0: @@ -889,6 +894,7 @@ def feature_selection( df, df_target, task_type, + random_state=random_state, )[name] LOG.info("Computing optimal features...")