diff --git a/networkcommons/utils.py b/networkcommons/utils.py index 75f97b8..fa095f4 100644 --- a/networkcommons/utils.py +++ b/networkcommons/utils.py @@ -229,14 +229,17 @@ def targetlayer_formatter(df, n_elements=25, act_col='stat'): return dict_df -def handle_missing_values(df, threshold=0.1, fill=True): +def handle_missing_values(df, threshold=0.1, fill=np.mean): """ - Handles missing values in a DataFrame by filling them with the mean of the row or dropping the rows. + Handles missing values in a DataFrame by filling them with a specified function or value, or dropping the rows. Parameters: - df (pandas.DataFrame): The DataFrame containing the data. - threshold (float): The threshold for the share (0>> df = pd.DataFrame({'A': [1, 2, np.nan], 'B': [3, 2, np.nan], 'C': [np.nan, 7, 8]}) - >>> handle_missing_values(df, 0.5) + >>> handle_missing_values(df, 0.5, fill=np.mean) Number of genes filled: 1 Number of genes removed: 1 """ @@ -270,10 +273,17 @@ def handle_missing_values(df, threshold=0.1, fill=True): filled_count = (df[to_fill].isna().sum(axis=1) > 0).sum() - # Replace NAs with the mean of the row for rows to fill - if fill: - df.loc[to_fill] = df.loc[to_fill].apply(lambda row: row.fillna(row.mean()), axis=1) - print(f"Number of genes filled: {filled_count}") + # Replace NAs based on the fill argument + if callable(fill): + # If fill is a function (like np.mean, np.median), apply it row-wise + df.loc[to_fill] = df.loc[to_fill].apply(lambda row: row.fillna(fill(row)), axis=1) + print(f"Number of genes filled using function {fill.__name__}: {filled_count}") + elif isinstance(fill, (int, float)): + # If fill is a constant (like 0), use it directly + df.loc[to_fill] = df.loc[to_fill].fillna(fill) + print(f"Number of genes filled with value {fill}: {filled_count}") + elif fill is not None: + raise ValueError("fill parameter must be a callable, a numeric value, or None") # Drop rows with NA percentage greater than or equal to threshold df = df[~to_drop] diff 
--git a/tests/test_utils.py b/tests/test_utils.py
index 31f823f..b9cad19 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -173,21 +173,21 @@ def test_subset_df_with_nodes():
 
 def test_handle_missing_values_fill():
     df = pd.DataFrame({'A': [1, 2, np.nan], 'B': [3, 2, np.nan], 'C': [np.nan, 7, 8]})
-    result = utils.handle_missing_values(df, 0.5, fill=True)
+    result = utils.handle_missing_values(df, 0.5, fill=np.mean)
     expected = pd.DataFrame({'index': [0, 1], 'A': [1.0, 2.0], 'B': [3.0, 2.0], 'C': [2.0, 7.0]}).astype({'index': 'int64'})
     pd.testing.assert_frame_equal(result, expected)
 
 
 def test_handle_missing_values_fill_and_drop():
     df = pd.DataFrame({'A': [1, np.nan, np.nan], 'B': [np.nan, 2, np.nan], 'C': [np.nan, 7, np.nan]})
-    result = utils.handle_missing_values(df, 0.5, fill=True)
+    result = utils.handle_missing_values(df, 0.5, fill=np.mean)
     expected = pd.DataFrame({'index': [1], 'A': [4.5], 'B': [2.0], 'C': [7.0]}).astype({'index': 'int64'})
     pd.testing.assert_frame_equal(result, expected)
 
 
 def test_handle_missing_values_drop():
     df = pd.DataFrame({'A': [1, np.nan, np.nan], 'B': [np.nan, np.nan, np.nan], 'C': [np.nan, np.nan, 8]})
-    result = utils.handle_missing_values(df, 0.1, fill=False)
+    result = utils.handle_missing_values(df, 0.1, fill=None)
     expected = pd.DataFrame({'index': [], 'A': [], 'B': [], 'C': []}).astype({'index': 'int64'})
     pd.testing.assert_frame_equal(result, expected)
 