Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor analysis ii #367

Open
wants to merge 8 commits into
base: refactor_analysis_I
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions alphastats/gui/pages/04_Analysis.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import streamlit as st

from alphastats.gui.utils.analysis import PlottingOptions
from alphastats.gui.utils.analysis_helper import (
display_df,
display_plot,
do_analysis,
)
from alphastats.gui.utils.options import get_plotting_options, get_statistic_options
from alphastats.gui.utils.options import (
get_plotting_options,
get_statistic_options,
)
from alphastats.gui.utils.ui_helper import (
StateKeys,
convert_df,
Expand Down Expand Up @@ -44,18 +48,27 @@

c1, c2 = st.columns([0.33, 0.67])
with c1:
plotting_options = PlottingOptions.get_values()
method = st.selectbox(
"Analysis",
options=["<select>"]
+ ["------- plots -------"]
+ list(get_plotting_options(st.session_state).keys())
+ plotting_options
+ [
key
for key in list(get_plotting_options(st.session_state).keys())
if key not in plotting_options
]
+ ["------- statistics -------"]
+ list(get_statistic_options(st.session_state).keys()),
)

if method in (plotting_options := get_plotting_options(st.session_state)):
if method in (
list((plot_options := get_plotting_options(st.session_state)).keys())
+ plotting_options
):
analysis_result, analysis_object, parameters = do_analysis(
method, options_dict=plotting_options
method, options_dict=plot_options
)
show_plot = analysis_result is not None

Expand Down
91 changes: 91 additions & 0 deletions alphastats/gui/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@
from alphastats.plots.VolcanoPlot import VolcanoPlot


class PlottingOptions:
    """Namespace of display names for the available plotting analyses.

    Each upper-case class attribute holds the label that identifies one plot type.
    """

    PCA_PLOT = "PCA Plot"
    UMAP_PLOT = "UMAP Plot"
    TSNE_PLOT = "t-SNE Plot"
    VOLCANO_PLOT = "Volcano Plot"

    @classmethod
    def get_values(cls):
        """Return the string constants declared directly on this class.

        Attributes whose name starts with a double underscore and attributes
        that are not strings (such as this method) are left out. Only the
        class's own namespace is inspected, so inherited attributes are not
        included. Values are returned in declaration order.
        """
        values = []
        for name, attribute in vars(cls).items():
            if name.startswith("__"):
                # Skips dunder entries such as __module__ / __doc__.
                continue
            if isinstance(attribute, str):
                values.append(attribute)
        return values


class Analysis(ABC):
"""Abstract class for analysis widgets."""

Expand Down Expand Up @@ -89,6 +107,79 @@ def show_widget(self):
self._parameters["column"] = column


class DimensionReductionAnalysis(Analysis, ABC):
    """Shared base for the dimension reduction plot widgets.

    Provides the input widgets common to these plots; this class itself does
    not define `do_analysis`.
    """

    def show_widget(self):
        """Render the common inputs and store the choices in `self._parameters`.

        Shows a select box offering `None` plus every metadata column (stored
        under "group") and a checkbox (stored under "circle").
        """
        # `None` comes first so that "no coloring" is the default selection.
        color_choices = [None, *self._dataset.metadata.columns.to_list()]

        group = st.selectbox("Color according to", options=color_choices)
        circle = st.checkbox("circle")

        chosen = {"circle": circle, "group": group}
        self._parameters.update(chosen)


class PCAPlotAnalysis(DimensionReductionAnalysis):
    """Analysis widget that produces a PCA plot."""

    def do_analysis(self):
        """Create the PCA plot from the gathered parameters.

        Reads "group" and "circle" from `self._parameters` and forwards them
        to the dataset's `plot_pca`.

        Returns:
            A 3-tuple of the `plot_pca` result, `None` (this analysis has no
            separate analysis object), and the parameters that were used.
        """
        params = self._parameters
        figure = self._dataset.plot_pca(
            group=params["group"],
            circle=params["circle"],
        )
        return figure, None, params


class UMAPPlotAnalysis(DimensionReductionAnalysis):
    """Analysis widget that produces a UMAP plot."""

    def do_analysis(self):
        """Create the UMAP plot from the gathered parameters.

        Reads "group" and "circle" from `self._parameters` and forwards them
        to the dataset's `plot_umap`.

        Returns:
            A 3-tuple of the `plot_umap` result, `None` (this analysis has no
            separate analysis object), and the parameters that were used.
        """
        params = self._parameters
        figure = self._dataset.plot_umap(
            group=params["group"],
            circle=params["circle"],
        )
        return figure, None, params


class TSNEPlotAnalysis(DimensionReductionAnalysis):
    """Analysis widget that produces a t-SNE plot."""

    def show_widget(self):
        """Render the shared inputs plus the t-SNE specific sliders.

        After the base-class widgets, two sliders are shown: the iteration
        limit (250 to 2000, default 1000) stored under "n_iter", and the
        perplexity (5 to 50, default 30) stored under "perplexity".
        """
        super().show_widget()

        iteration_choices = range(250, 2001)
        n_iter = st.select_slider(
            "Maximum number of iterations for the optimization",
            iteration_choices,
            value=1000,
        )

        perplexity_choices = range(5, 51)
        perplexity = st.select_slider("Perplexity", perplexity_choices, value=30)

        tsne_settings = {"n_iter": n_iter, "perplexity": perplexity}
        self._parameters.update(tsne_settings)

    def do_analysis(self):
        """Create the t-SNE plot from the gathered parameters.

        Reads "group", "circle", "perplexity" and "n_iter" from
        `self._parameters` and forwards them to the dataset's `plot_tsne`.

        Returns:
            A 3-tuple of the `plot_tsne` result, `None` (this analysis has no
            separate analysis object), and the parameters that were used.
        """
        params = self._parameters
        plot_kwargs = {
            name: params[name]
            for name in ("group", "circle", "perplexity", "n_iter")
        }
        figure = self._dataset.plot_tsne(**plot_kwargs)
        return figure, None, params


class VolcanoPlotAnalysis(GroupCompareAnalysis):
"""Widget for Volcano Plot analysis."""

Expand Down
55 changes: 21 additions & 34 deletions alphastats/gui/utils/analysis_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import pandas as pd
import streamlit as st

from alphastats.gui.utils.analysis import VolcanoPlotAnalysis
from alphastats.gui.utils.analysis import (
PCAPlotAnalysis,
PlottingOptions,
TSNEPlotAnalysis,
UMAPPlotAnalysis,
VolcanoPlotAnalysis,
)
from alphastats.gui.utils.ui_helper import StateKeys, convert_df
from alphastats.keys import Cols

Expand Down Expand Up @@ -129,30 +135,32 @@ def do_analysis(
Currently, analysis_object is only not-None for Volcano Plot.
# TODO unify the API of all analysis methods
"""
method_dict = options_dict.get(method)
options = {
PlottingOptions.VOLCANO_PLOT: VolcanoPlotAnalysis,
PlottingOptions.PCA_PLOT: PCAPlotAnalysis,
PlottingOptions.UMAP_PLOT: UMAPPlotAnalysis,
PlottingOptions.TSNE_PLOT: TSNEPlotAnalysis,
}

if method == "Volcano Plot":
analysis = VolcanoPlotAnalysis(st.session_state[StateKeys.DATASET])
if (analysis_class := options.get(method)) is not None:
analysis = analysis_class(st.session_state[StateKeys.DATASET])
analysis.show_widget()

if st.button("Run analysis .."):
return analysis.do_analysis()
with st.spinner("Running analysis .."):
return analysis.do_analysis()
return None, None, {}

elif method == "t-SNE Plot":
parameters = st_tsne_options(method_dict)
method_dict = options_dict.get(method)

elif method == "Differential Expression Analysis - T-test":
# old, to be refactored logic:
if method == "Differential Expression Analysis - T-test":
parameters = helper_compare_two_groups()
parameters.update({"method": "ttest"})

elif method == "Differential Expression Analysis - Wald-test":
parameters = helper_compare_two_groups()
parameters.update({"method": "wald"})

elif method == "PCA Plot" or method == "UMAP Plot":
parameters = helper_plot_dimensionality_reduction(method_dict=method_dict)

else:
parameters = st_general(method_dict=method_dict)

Expand All @@ -165,6 +173,7 @@ def do_analysis(
return None, None, {}


# TODO: this can be deleted once all analyses have adopted the new pattern (cf. analysis.py:Analysis())
def helper_plot_dimensionality_reduction(method_dict):
group = st.selectbox(
method_dict["settings"]["group"].get("label"),
Expand Down Expand Up @@ -235,25 +244,3 @@ def helper_compare_two_groups():
chosen_parameter_dict.update({"group1": group1, "group2": group2})

return chosen_parameter_dict


def st_tsne_options(method_dict):
chosen_parameter_dict = helper_plot_dimensionality_reduction(
method_dict=method_dict
)

n_iter = st.select_slider(
"Maximum number of iterations for the optimization",
range(250, 2001),
value=1000,
)
perplexity = st.select_slider("Perplexity", range(5, 51), value=30)

chosen_parameter_dict.update(
{
"n_iter": n_iter,
"perplexity": perplexity,
}
)

return chosen_parameter_dict
31 changes: 0 additions & 31 deletions alphastats/gui/utils/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,37 +41,6 @@ def get_plotting_options(state):
},
"function": dataset.plot_intensity,
},
"PCA Plot": {
"settings": {
"group": {
"options": metadata_options,
"label": "Color according to",
},
"circle": {"label": "Circle"},
},
"function": dataset.plot_pca,
},
"UMAP Plot": {
"settings": {
"group": {
"options": metadata_options,
"label": "Color according to",
},
"circle": {"label": "Circle"},
},
"function": dataset.plot_umap,
},
"t-SNE Plot": {
"settings": {
"group": {
"options": metadata_options,
"label": "Color according to",
},
"circle": {"label": "Circle"},
},
"function": dataset.plot_tsne,
},
"Volcano Plot": {},
"Clustermap": {"function": dataset.plot_clustermap},
# "Dendrogram": {"function": state[StateKeys.DATASET].plot_dendrogram}, # TODO why commented?
}
Expand Down
3 changes: 2 additions & 1 deletion alphastats/plots/DimensionalityReduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import plotly.express as px
import plotly.graph_objects as go
import sklearn
from sklearn.manifold._t_sne import TSNE

from alphastats.DataSet_Preprocess import Preprocess
from alphastats.keys import Cols
Expand Down Expand Up @@ -119,7 +120,7 @@ def _pca(self):
}

def _tsne(self, **kwargs):
tsne = sklearn.manifold.TSNE(n_components=2, verbose=1, **kwargs)
tsne = TSNE(n_components=2, verbose=1, **kwargs)
self.components = tsne.fit_transform(self.prepared_df)
self.labels = {
"0": "Dimension 1",
Expand Down
Loading