diff --git a/alphastats/DataSet.py b/alphastats/DataSet.py index e90a050b..98ea4a40 100644 --- a/alphastats/DataSet.py +++ b/alphastats/DataSet.py @@ -34,7 +34,7 @@ plotly.io.templates.default = "simple_white+alphastats_colors" -class DataSet(Preprocess, Statistics, Plot, Enrichment): +class DataSet(Statistics, Plot, Enrichment): """Analysis Object""" def __init__(self, loader, metadata_path=None, sample_column=None): @@ -83,6 +83,25 @@ def __init__(self, loader, metadata_path=None, sample_column=None): print("DataSet has been created.") self.overview() + def preprocess(self, **kwargs): + pp = Preprocess( + self.filter_columns, + self.rawinput, + self.index_column, + self.sample, + self.metadata, + self.preprocessing_info, + self.mat, + ) + + self.mat, self.metadata, self.preprocessing_info = pp.preprocess(**kwargs) + self.preprocessed = True + + def reset_preprocessing(self): + """Reset all preprocessing steps""" + self.create_matrix() + print("All preprocessing steps are reset.") + def _create_metadata(self): samples = list(self.mat.index) self.metadata = pd.DataFrame({"sample": samples}) diff --git a/alphastats/DataSet_Preprocess.py b/alphastats/DataSet_Preprocess.py index d6cf9ed3..970a066a 100644 --- a/alphastats/DataSet_Preprocess.py +++ b/alphastats/DataSet_Preprocess.py @@ -1,4 +1,3 @@ -import itertools import logging import numpy as np @@ -9,7 +8,6 @@ import streamlit as st from sklearn.experimental import enable_iterative_imputer # noqa -from tqdm import tqdm from alphastats.utils import ignore_warning @@ -18,7 +16,27 @@ class Preprocess: imputation_methods = ["mean", "median", "knn", "randomforest"] normalization_methods = ["vst", "zscore", "quantile"] - def _remove_sampels(self, sample_list: list): + def __init__( + self, + filter_columns, + rawinput, + index_column, + sample, + metadata, + preprocessing_info, + mat, + ): + self.filter_columns = filter_columns + + self.rawinput = rawinput + self.index_column = index_column + self.sample = sample + + self.metadata = metadata # changed + self.preprocessing_info = preprocessing_info # changed + self.mat = mat # changed + + def _remove_samples(self, sample_list: list): # exclude samples for analysis self.mat = self.mat.drop(sample_list) self.metadata = self.metadata[~self.metadata[self.sample].isin(sample_list)] @@ -31,10 +49,6 @@ def _subset(self): ) return self.mat[self.mat.index.isin(self.metadata[self.sample].tolist())] - def preprocess_print_info(self): - """Print summary of preprocessing steps""" - print(pd.DataFrame(self.preprocessing_info.items())) - def _remove_na_values(self, cut_off): if ( self.preprocessing_info.get("Missing values were removed") @@ -215,42 +229,6 @@ def _normalization(self, method: str): self.preprocessing_info.update({"Normalization": method}) - def reset_preprocessing(self): - """Reset all preprocessing steps""" - self.create_matrix() - print("All preprocessing steps are reset.") - - @ignore_warning(RuntimeWarning) - def _compare_preprocessing_modes(self, func, params_for_func) -> list: - dataset = self - - preprocessing_modes = list( - itertools.product(self.normalization_methods, self.imputation_methods) - ) - - results_list = [] - - del params_for_func["compare_preprocessing_modes"] - params_for_func["dataset"] = params_for_func.pop("self") - - # TODO: make this progress transparent in GUI - for preprocessing_mode in tqdm(preprocessing_modes): - # reset preprocessing - dataset.reset_preprocessing() - print( - f"Normalization {preprocessing_mode[0]}, Imputation {str(preprocessing_mode[1])}" - ) - dataset.mat.replace([np.inf, -np.inf], np.nan, inplace=True) - - dataset.preprocess( - subset=True, - normalization=preprocessing_mode[0], - imputation=preprocessing_mode[1], - ) - - res = func(**params_for_func) - results_list.append(res) - # TODO this needs to be reimplemented # @ignore_warning(RuntimeWarning) # def _compare_preprocessing_modes(self, func, params_for_func) -> list: @@ -370,7 +348,7 @@ def preprocess( self._filter() if remove_samples is not None: - self._remove_sampels(sample_list=remove_samples) + self._remove_samples(sample_list=remove_samples) if subset: self.mat = self._subset() @@ -394,4 +372,5 @@ def preprocess( "Matrix: Number of ProteinIDs/ProteinGroups": self.mat.shape[1], } ) - self.preprocessed = True + + return self.mat, self.metadata, self.preprocessing_info