Skip to content

Commit

Permalink
FIX: rewritten problematic randomforest imputer that couldn't work wi…
Browse files Browse the repository at this point in the history
…th NAs xd
  • Loading branch information
Mikhail Lebedev committed May 17, 2024
1 parent 1c64535 commit 387867e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 22 deletions.
24 changes: 5 additions & 19 deletions alphastats/DataSet_Preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sklearn.impute
import streamlit as st

from sklearn.experimental import enable_iterative_imputer
from sklearn.experimental import enable_iterative_imputer #noqa
from alphastats.utils import ignore_warning


Expand Down Expand Up @@ -140,25 +140,11 @@ def _imputation(self, method: str):
imputation_array = imp.fit_transform(self.mat.values)

elif method == "randomforest":
randomforest = sklearn.ensemble.RandomForestRegressor(
max_depth=10,
bootstrap=True,
max_samples=0.5,
n_jobs=2,
random_state=0,
verbose=0, #  random forest takes a while print progress
imp = sklearn.ensemble.HistGradientBoostingRegressor(
max_depth=10, max_iter=100, random_state=0
)
imp = sklearn.impute.IterativeImputer(
random_state=0, estimator=randomforest
)

# the random forest imputer doesnt work with float32/float16..
#  so the values are multiplied and converted to integers
array_multi_mio = self.mat.values * 1000000
array_int = array_multi_mio.astype("int")

imputation_array = imp.fit_transform(array_int)
imputation_array = imputation_array / 1000000
imp = sklearn.impute.IterativeImputer(random_state=0, estimator=imp)
imputation_array = imp.fit_transform(self.mat.values)

else:
raise ValueError(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,9 @@ def test_preprocess_imputation_randomforest_values(self):
self.obj.preprocess(log2_transform=False, imputation="randomforest")
expected_mat = pd.DataFrame(
{
"a": [2.00000000e00, 0, 4.00000000e00],
"b": [5.00000000e00, 4.00000000e00, 4.0],
"c": [0, 1.00000000e01, 0],
"a": [2.0, 3.0, 4.0],
"b": [5.0, 4.0, 4.0],
"c": [10.0, 10.0, 10.0],
}
)
pd._testing.assert_frame_equal(self.obj.mat, expected_mat)
Expand Down

0 comments on commit 387867e

Please sign in to comment.