From 9db90d4ae820c26308ccf1c7dd81c3c86d02ee3e Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:47:58 +0200 Subject: [PATCH 1/8] decouple get_gene_to_prot_id_mapping from session state --- alphastats/gui/utils/gpt_helper.py | 54 ++++++++++++++---------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index f7ec94be..92a228a0 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -1,4 +1,3 @@ -import copy from typing import Dict, List import pandas as pd @@ -130,20 +129,17 @@ def get_assistant_functions( ) -> List[Dict]: """ Get a list of assistant functions for function calling in the ChatGPT model. - You can call this function with no arguments, arguments are given for clarity on what changes the behavior of the function. For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling Args: - gene_to_prot_id_dict (dict, optional): A dictionary with gene names as keys and protein IDs as values. - metadata (pd.DataFrame, optional): The metadata dataframe (which sample has which disease/treatment/condition/etc). - subgroups_for_each_group (dict, optional): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). + gene_to_prot_id_dict (dict): A dictionary with gene names as keys and protein IDs as values. + metadata (pd.DataFrame): The metadata dataframe (which sample has which disease/treatment/condition/etc). + subgroups_for_each_group (dict): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). Returns: list[dict]: A list of assistant functions. """ - # TODO figure out how this relates to the parameter `subgroups_for_each_group` - subgroups_for_each_group_ = str( - get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata) - ) + gene_names = list(gene_to_prot_id_dict.keys()) + groups = [str(col) for col in metadata.columns.to_list()] return [ { "type": "function", @@ -153,21 +149,21 @@ def get_assistant_functions( "parameters": { "type": "object", "properties": { - "protein_id": { + "gene_name": { # this will be mapped to "protein_id" when calling the function "type": "string", - "enum": [i for i in gene_to_prot_id_dict], - "description": "Identifier for the protein of interest", + "enum": gene_names, + "description": "Identifier for the gene of interest", }, "group": { "type": "string", - "enum": [str(i) for i in metadata.columns.to_list()], + "enum": groups, "description": "Column name in the dataset for the group variable", }, "subgroups": { "type": "array", "items": {"type": "string"}, "description": f"Specific subgroups within the group to analyze. For each group you need to look up the subgroups in the dict" - f" {subgroups_for_each_group_} or present user with them first if you are not sure what to choose", + f" {subgroups_for_each_group} or present user with them first if you are not sure what to choose", }, "method": { "type": "string", @@ -198,7 +194,7 @@ def get_assistant_functions( "group": { "type": "string", "description": "The name of the group column in the dataset", - "enum": [str(i) for i in metadata.columns.to_list()], + "enum": groups, }, "method": { "type": "string", @@ -225,7 +221,7 @@ def get_assistant_functions( "color": { "type": "string", "description": "The name of the group column in the dataset to color the samples by", - "enum": [str(i) for i in metadata.columns.to_list()], + "enum": groups, }, "method": { "type": "string", @@ -303,29 +299,29 @@ def get_assistant_functions( def perform_dimensionality_reduction(group, method, circle, **kwargs): - dr = DimensionalityReduction( + dr = DimensionalityReduction( # TODO fix this call st.session_state[StateKeys.DATASET], group, method, circle, **kwargs ) return dr.plot -def get_gene_to_prot_id_mapping(gene_id: str) -> str: +def get_gene_to_prot_id_mapping( + gene_name: str, gene_to_prot_id_map: Dict[str, str] +) -> str: """Get protein id from gene id. If gene id is not present, return gene id, as we might already have a gene id. 'VCL;HEL114' -> 'P18206;A0A024QZN4;V9HWK2;B3KXA2;Q5JQ13;B4DKC9;B4DTM7;A0A096LPE1' + Args: - gene_id (str): Gene id + gene_name (str): Gene id Returns: str: Protein id or gene id if not present in the mapping. """ - import streamlit as st + if gene_name in gene_to_prot_id_map: + return gene_to_prot_id_map[gene_name] + + for gene, protein_id in gene_to_prot_id_map.items(): + if gene_name in gene.split(";"): + return protein_id - session_state_copy = dict(copy.deepcopy(st.session_state)) - if StateKeys.GENE_TO_PROT_ID not in session_state_copy: - session_state_copy[StateKeys.GENE_TO_PROT_ID] = {} - if gene_id in session_state_copy[StateKeys.GENE_TO_PROT_ID]: - return session_state_copy[StateKeys.GENE_TO_PROT_ID][gene_id] - for gene, prot_id in session_state_copy[StateKeys.GENE_TO_PROT_ID].items(): - if gene_id in gene.split(";"): - return prot_id - return gene_id + return gene_name From 33c5cf53c77b42af3c5b60949e054e948930a8e4 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:48:51 +0200 Subject: [PATCH 2/8] move mapping of gene name to execute_function --- alphastats/gui/utils/ollama_utils.py | 43 ++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index 5e7b45e2..4da18ba3 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -10,6 +10,7 @@ from alphastats.gui.utils.enrichment_analysis import get_enrichment_data from alphastats.gui.utils.gpt_helper import ( get_assistant_functions, + get_gene_to_prot_id_mapping, get_general_assistant_functions, get_subgroups_for_each_group, perform_dimensionality_reduction, @@ -133,6 +134,7 @@ def truncate_conversation_history(self, max_tokens: int = 100000): """ total_tokens = sum(len(m["content"].split()) for m in self.messages) while total_tokens > max_tokens and len(self.messages) > 1: + # TODO messages should still be displayed! removed_message = self.messages.pop(0) total_tokens -= len(removed_message["content"].split()) @@ -179,24 +181,39 @@ def execute_function( If the function is not implemented or the dataset is not available """ try: - if function_name == "get_gene_function": - # TODO log whats going on - return get_gene_function(**function_args) - elif function_name == "get_enrichment_data": - return get_enrichment_data(**function_args) - elif function_name == "perform_dimensionality_reduction": - return perform_dimensionality_reduction(**function_args) - elif function_name.startswith("plot_") or function_name.startswith( - "perform_" - ): + # first try to find the function in the non-Dataset functions + if ( + function := { + "get_gene_function": get_gene_function, + "get_enrichment_data": get_enrichment_data, + "perform_dimensionality_reduction": perform_dimensionality_reduction, + }.get(function_name) + ) is not None: + return function(**function_args) + + # special treatment for this one + elif function_name == "plot_intensity": + gene_name = function_args.pop("gene_name") + protein_id = get_gene_to_prot_id_mapping( + gene_name, self._gene_to_prot_id_map + ) + function_args["protein_id"] = protein_id + + return self.dataset.plot_intensity(**function_args) + + # fallback: try to find the function in the Dataset functions + else: plot_function = getattr( - self.dataset, function_name.split(".")[-1], None + self.dataset, + function_name.split(".")[-1], + None, # TODO why split? ) if plot_function: return plot_function(**function_args) raise ValueError( f"Function {function_name} not implemented or dataset not available" ) + except Exception as e: return f"Error executing {function_name}: {str(e)}" @@ -219,6 +236,7 @@ def handle_function_calls( """ new_artifacts = {} + funcs_and_args = "\n".join( [ f"Calling function: {tool_call.function.name} with arguments: {tool_call.function.arguments}" @@ -231,7 +249,6 @@ def handle_function_calls( for tool_call in tool_calls: function_name = tool_call.function.name - print(f"Calling function: {function_name}") function_args = json.loads(tool_call.function.arguments) function_result = self.execute_function(function_name, function_args) @@ -248,8 +265,10 @@ def handle_function_calls( "tool_call_id": tool_call.id, } ) + post_artefact_message_idx = len(self.messages) self.artifacts[post_artefact_message_idx] = new_artifacts.values() + logger.info( f"Calling 'chat.completions.create' {self.messages=} {self.tools=} .." ) From ba65923854890abdf2d9fc8b16a1aeea547ab79d Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:50:47 +0200 Subject: [PATCH 3/8] remove GENE_TO_PROT_ID from session state --- alphastats/gui/pages/05_LLM.py | 1 - alphastats/gui/utils/ui_helper.py | 4 ---- tests/gui/test_02_import_data.py | 2 -- tests/test_DataSet.py | 4 ---- tests/test_gpt.py | 4 ---- 5 files changed, 15 deletions(-) diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 460a141f..bbe6ba21 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -78,7 +78,6 @@ def llm_config(): genes_of_interest_colored_df[prot_ids_colname].tolist(), ) ) - st.session_state[StateKeys.GENE_TO_PROT_ID] = gene_to_prot_id_map with c2: display_figure(volcano_plot.plot) diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index c38de188..bc40b4a8 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -83,9 +83,6 @@ def init_session_state() -> None: if StateKeys.USER_SESSION_ID not in st.session_state: st.session_state[StateKeys.USER_SESSION_ID] = str(uuid.uuid4()) - if StateKeys.GENE_TO_PROT_ID not in st.session_state: - st.session_state[StateKeys.GENE_TO_PROT_ID] = {} - if StateKeys.ORGANISM not in st.session_state: st.session_state[StateKeys.ORGANISM] = 9606 # human @@ -97,7 +94,6 @@ class StateKeys: ## 02_Data Import # on 1st run ORGANISM = "organism" - GENE_TO_PROT_ID = "gene_to_prot_id" USER_SESSION_ID = "user_session_id" LOADER = "loader" # on sample run (function load_sample_data), removed on new session click diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index 50751dd1..84669918 100644 --- a/tests/gui/test_02_import_data.py +++ b/tests/gui/test_02_import_data.py @@ -18,7 +18,6 @@ def test_page_02_loads_without_input(): assert at.session_state[StateKeys.ORGANISM] == 9606 assert at.session_state[StateKeys.USER_SESSION_ID] is not None - assert at.session_state[StateKeys.GENE_TO_PROT_ID] == {} @patch("streamlit.file_uploader") @@ -31,7 +30,6 @@ def test_patched_page_02_loads_without_input(mock_file_uploader: MagicMock): assert at.session_state[StateKeys.ORGANISM] == 9606 assert at.session_state[StateKeys.USER_SESSION_ID] is not None - assert at.session_state[StateKeys.GENE_TO_PROT_ID] == {} @patch( diff --git a/tests/test_DataSet.py b/tests/test_DataSet.py index cd62f86c..507bec40 100644 --- a/tests/test_DataSet.py +++ b/tests/test_DataSet.py @@ -13,7 +13,6 @@ from alphastats.DataSet import DataSet from alphastats.dataset_factory import DataSetFactory from alphastats.DataSet_Preprocess import PreprocessingStateKeys -from alphastats.gui.utils.ui_helper import StateKeys from alphastats.loader.AlphaPeptLoader import AlphaPeptLoader from alphastats.loader.DIANNLoader import DIANNLoader from alphastats.loader.FragPipeLoader import FragPipeLoader @@ -517,9 +516,6 @@ def test_plot_intenstity_subgroup(self): self.assertEqual(len(plot_dict.get("data")), 3) def test_plot_intensity_subgroup_gracefully_handle_one_group(self): - import streamlit as st - - st.session_state[StateKeys.GENE_TO_PROT_ID] = {} plot = self.obj.plot_intensity( protein_id="K7ERI9;A0A024R0T8;P02654;K7EJI9;K7ELM9;K7EPF9;K7EKP1", group="disease", diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 3121e01b..7ca13bd0 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -9,10 +9,6 @@ from alphastats.gui.utils.uniprot_utils import extract_data, get_uniprot_data from alphastats.loader.MaxQuantLoader import MaxQuantLoader -if StateKeys.GENE_TO_PROT_ID not in st.session_state: - st.session_state[StateKeys.GENE_TO_PROT_ID] = {} - - logger = logging.getLogger(__name__) From 4d05375d676efb425d52829d2d52e3dd0e529d04 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:51:47 +0200 Subject: [PATCH 4/8] renaming --- alphastats/gui/utils/gpt_helper.py | 6 +++--- alphastats/gui/utils/ollama_utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index 92a228a0..c1af1c93 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -123,7 +123,7 @@ def get_general_assistant_functions() -> List[Dict]: def get_assistant_functions( - gene_to_prot_id_dict: Dict, + gene_to_prot_id_map: Dict, metadata: pd.DataFrame, subgroups_for_each_group: Dict, ) -> List[Dict]: @@ -132,13 +132,13 @@ def get_assistant_functions( For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling Args: - gene_to_prot_id_dict (dict): A dictionary with gene names as keys and protein IDs as values. + gene_to_prot_id_map (dict): A dictionary with gene names as keys and protein IDs as values. metadata (pd.DataFrame): The metadata dataframe (which sample has which disease/treatment/condition/etc). subgroups_for_each_group (dict): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). Returns: list[dict]: A list of assistant functions. """ - gene_names = list(gene_to_prot_id_dict.keys()) + gene_names = list(gene_to_prot_id_map.keys()) groups = [str(col) for col in metadata.columns.to_list()] return [ { diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index 4da18ba3..0ce8494d 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -109,7 +109,7 @@ def _get_tools(self) -> List[Dict[str, Any]]: if self.metadata is not None and self._gene_to_prot_id_map is not None: tools += ( *get_assistant_functions( - gene_to_prot_id_dict=self._gene_to_prot_id_map, + gene_to_prot_id_map=self._gene_to_prot_id_map, metadata=self.metadata, subgroups_for_each_group=get_subgroups_for_each_group( self.metadata From 3cc37bf627365c9686f9cef401b11f11ce0b16ea Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:53:26 +0200 Subject: [PATCH 5/8] renaming --- alphastats/gui/utils/gpt_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index c1af1c93..a932ea9f 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -305,7 +305,7 @@ def perform_dimensionality_reduction(group, method, circle, **kwargs): return dr.plot -def get_gene_to_prot_id_mapping( +def get_protein_id_for_gene_name( gene_name: str, gene_to_prot_id_map: Dict[str, str] ) -> str: """Get protein id from gene id. If gene id is not present, return gene id, as we might already have a gene id. From b7f02a7abb9c0114802e509846602070633eac1c Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:53:37 +0200 Subject: [PATCH 6/8] adapt IntensityPlot.py --- alphastats/plots/IntensityPlot.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/alphastats/plots/IntensityPlot.py b/alphastats/plots/IntensityPlot.py index 895910c1..d3e96546 100644 --- a/alphastats/plots/IntensityPlot.py +++ b/alphastats/plots/IntensityPlot.py @@ -8,7 +8,6 @@ import plotly.graph_objects as go import scipy -from alphastats.gui.utils.gpt_helper import get_gene_to_prot_id_mapping from alphastats.plots.PlotUtils import PlotUtils, plotly_object plotly.io.templates["alphastats_colors"] = plotly.graph_objects.layout.Template( @@ -54,7 +53,7 @@ def __init__( self.intensity_column = intensity_column self.preprocessing_info = preprocessing_info - self.protein_id = get_gene_to_prot_id_mapping(protein_id) + self.protein_id = protein_id self.group = group self.subgroups = subgroups self.method = method From 9d0903e01fd62130b3c4edb027fe33bd90761615 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:53:52 +0200 Subject: [PATCH 7/8] renaming --- alphastats/gui/utils/ollama_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index 0ce8494d..764dede4 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -10,8 +10,8 @@ from alphastats.gui.utils.enrichment_analysis import get_enrichment_data from alphastats.gui.utils.gpt_helper import ( get_assistant_functions, - get_gene_to_prot_id_mapping, get_general_assistant_functions, + get_protein_id_for_gene_name, get_subgroups_for_each_group, perform_dimensionality_reduction, ) @@ -194,7 +194,7 @@ def execute_function( # special treatment for this one elif function_name == "plot_intensity": gene_name = function_args.pop("gene_name") - protein_id = get_gene_to_prot_id_mapping( + protein_id = get_protein_id_for_gene_name( gene_name, self._gene_to_prot_id_map ) function_args["protein_id"] = protein_id From 3bd1e866e68e53fb7ecc4d77af535623238dce3c Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 22 Oct 2024 15:01:11 +0200 Subject: [PATCH 8/8] fix perform_dimensionality_reduction --- alphastats/gui/pages/05_LLM.py | 2 +- alphastats/gui/utils/gpt_helper.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index bbe6ba21..5db46f69 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -72,7 +72,7 @@ def llm_config(): # ) # ) # TODO unused? - gene_to_prot_id_map = dict( + gene_to_prot_id_map = dict( # TODO move this logic to dataset zip( genes_of_interest_colored_df[gene_names_colname].tolist(), genes_of_interest_colored_df[prot_ids_colname].tolist(), diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index a932ea9f..80462422 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -299,8 +299,16 @@ def get_assistant_functions( def perform_dimensionality_reduction(group, method, circle, **kwargs): - dr = DimensionalityReduction( # TODO fix this call - st.session_state[StateKeys.DATASET], group, method, circle, **kwargs + dataset = st.session_state[StateKeys.DATASET] + dr = DimensionalityReduction( + mat=dataset.mat, + metadate=dataset.metadata, + sample=dataset.sample, + preprocessing_info=dataset.preprocessing_info, + group=group, + circle=circle, + method=method, + **kwargs, ) return dr.plot