Skip to content

Commit

Permalink
make Preprocessing a dedicated class
Browse files Browse the repository at this point in the history
  • Loading branch information
mschwoer committed Sep 12, 2024
1 parent fce7052 commit 9c52689
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 46 deletions.
21 changes: 20 additions & 1 deletion alphastats/DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
plotly.io.templates.default = "simple_white+alphastats_colors"


class DataSet(Preprocess, Statistics, Plot, Enrichment):
class DataSet(Statistics, Plot, Enrichment):
"""Analysis Object"""

def __init__(self, loader, metadata_path=None, sample_column=None):
Expand Down Expand Up @@ -83,6 +83,25 @@ def __init__(self, loader, metadata_path=None, sample_column=None):
print("DataSet has been created.")
self.overview()

def preprocess(self, **kwargs):
    """Run the preprocessing pipeline and store its results on this dataset.

    A ``Preprocess`` helper is built from the dataset's current state, all
    keyword arguments are forwarded unchanged to ``Preprocess.preprocess``,
    and the three values it returns replace ``self.mat``, ``self.metadata``
    and ``self.preprocessing_info``. Afterwards the dataset is flagged as
    preprocessed.

    Args:
        **kwargs: Passed through as-is to ``Preprocess.preprocess``.
    """
    preprocessor = Preprocess(
        self.filter_columns,
        self.rawinput,
        self.index_column,
        self.sample,
        self.metadata,
        self.preprocessing_info,
        self.mat,
    )
    outcome = preprocessor.preprocess(**kwargs)

    # The helper hands back (matrix, metadata, preprocessing info) in this order.
    self.mat, self.metadata, self.preprocessing_info = outcome
    self.preprocessed = True

def reset_preprocessing(self):
    """Undo all preprocessing by rebuilding the matrix via ``create_matrix``.

    Prints a confirmation message once the matrix has been recreated.
    """
    self.create_matrix()
    confirmation = "All preprocessing steps are reset."
    print(confirmation)

def _create_metadata(self):
samples = list(self.mat.index)
self.metadata = pd.DataFrame({"sample": samples})
Expand Down
69 changes: 24 additions & 45 deletions alphastats/DataSet_Preprocess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import logging

import numpy as np
Expand All @@ -9,7 +8,6 @@
import streamlit as st

from sklearn.experimental import enable_iterative_imputer # noqa
from tqdm import tqdm

from alphastats.utils import ignore_warning

Expand All @@ -18,7 +16,27 @@ class Preprocess:
imputation_methods = ["mean", "median", "knn", "randomforest"]
normalization_methods = ["vst", "zscore", "quantile"]

def _remove_sampels(self, sample_list: list):
def __init__(
    self,
    filter_columns,
    rawinput,
    index_column,
    sample,
    metadata,
    preprocessing_info,
    mat,
):
    """Store the dataset state that the preprocessing steps work on.

    Every argument is kept as an attribute of the same name. ``metadata``,
    ``preprocessing_info`` and ``mat`` are the pieces of state that
    ``preprocess`` returns to the caller once it has finished.
    """
    # Inputs that are only stored here.
    # NOTE(review): presumably read-only during preprocessing; confirm.
    self.filter_columns = filter_columns

    self.rawinput = rawinput
    self.index_column = index_column
    self.sample = sample

    # State that preprocessing modifies and that preprocess() hands back.
    self.metadata = metadata
    self.preprocessing_info = preprocessing_info
    self.mat = mat

def _remove_samples(self, sample_list: list):
# exclude samples for analysis
self.mat = self.mat.drop(sample_list)
self.metadata = self.metadata[~self.metadata[self.sample].isin(sample_list)]
Expand All @@ -31,10 +49,6 @@ def _subset(self):
)
return self.mat[self.mat.index.isin(self.metadata[self.sample].tolist())]

def preprocess_print_info(self):
    """Print the recorded preprocessing steps as a table.

    Each (key, value) entry of ``self.preprocessing_info`` becomes one row of
    a pandas DataFrame, which is then written to standard output.
    """
    summary = pd.DataFrame(self.preprocessing_info.items())
    print(summary)

def _remove_na_values(self, cut_off):
if (
self.preprocessing_info.get("Missing values were removed")
Expand Down Expand Up @@ -215,42 +229,6 @@ def _normalization(self, method: str):

self.preprocessing_info.update({"Normalization": method})

def reset_preprocessing(self):
    """Undo all preprocessing by rebuilding the matrix via ``create_matrix``.

    Prints a confirmation message once the matrix has been recreated.
    """
    self.create_matrix()
    confirmation = "All preprocessing steps are reset."
    print(confirmation)

@ignore_warning(RuntimeWarning)
def _compare_preprocessing_modes(self, func, params_for_func) -> list:
    """Run ``func`` once for every normalization/imputation combination.

    For each pair from ``normalization_methods`` x ``imputation_methods`` the
    dataset is reset, infinite values in the matrix are replaced by NaN, the
    dataset is preprocessed with that pair (``subset=True``), and ``func`` is
    called with ``params_for_func``.

    Args:
        func: Callable invoked as ``func(**params_for_func)`` after each
            preprocessing run.
        params_for_func: Keyword arguments for ``func``. Must contain the
            keys ``"compare_preprocessing_modes"`` (removed) and ``"self"``
            (re-inserted under the key ``"dataset"``). The dict is modified
            in place.

    Returns:
        The values returned by ``func``, one per preprocessing mode, in
        iteration order.

    Raises:
        KeyError: If ``params_for_func`` lacks one of the required keys.
    """
    dataset = self

    # Materialized as a list so the progress bar knows the total count.
    preprocessing_modes = list(
        itertools.product(self.normalization_methods, self.imputation_methods)
    )

    # Adapt the caller's arguments: drop the flag that triggered this
    # comparison and pass the dataset under the name ``func`` expects.
    del params_for_func["compare_preprocessing_modes"]
    params_for_func["dataset"] = params_for_func.pop("self")

    results_list = []

    # TODO: make this progress transparent in GUI
    for normalization, imputation in tqdm(preprocessing_modes):
        # Start every mode from the same un-preprocessed state.
        dataset.reset_preprocessing()
        print(f"Normalization {normalization}, Imputation {imputation}")
        dataset.mat.replace([np.inf, -np.inf], np.nan, inplace=True)

        dataset.preprocess(
            subset=True,
            normalization=normalization,
            imputation=imputation,
        )

        results_list.append(func(**params_for_func))

    # Previously the collected results were never returned, so callers
    # always received None despite the ``-> list`` annotation.
    return results_list

# TODO this needs to be reimplemented
# @ignore_warning(RuntimeWarning)
# def _compare_preprocessing_modes(self, func, params_for_func) -> list:
Expand Down Expand Up @@ -370,7 +348,7 @@ def preprocess(
self._filter()

if remove_samples is not None:
self._remove_sampels(sample_list=remove_samples)
self._remove_samples(sample_list=remove_samples)

if subset:
self.mat = self._subset()
Expand All @@ -394,4 +372,5 @@ def preprocess(
"Matrix: Number of ProteinIDs/ProteinGroups": self.mat.shape[1],
}
)
self.preprocessed = True

return self.mat, self.metadata, self.preprocessing_info

0 comments on commit 9c52689

Please sign in to comment.