diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 460a141f..5db46f69 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -72,13 +72,12 @@ def llm_config(): # ) # ) # TODO unused? - gene_to_prot_id_map = dict( + gene_to_prot_id_map = dict( # TODO move this logic to dataset zip( genes_of_interest_colored_df[gene_names_colname].tolist(), genes_of_interest_colored_df[prot_ids_colname].tolist(), ) ) - st.session_state[StateKeys.GENE_TO_PROT_ID] = gene_to_prot_id_map with c2: display_figure(volcano_plot.plot) diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index f7ec94be..80462422 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -1,4 +1,3 @@ -import copy from typing import Dict, List import pandas as pd @@ -124,26 +123,23 @@ def get_general_assistant_functions() -> List[Dict]: def get_assistant_functions( - gene_to_prot_id_dict: Dict, + gene_to_prot_id_map: Dict, metadata: pd.DataFrame, subgroups_for_each_group: Dict, ) -> List[Dict]: """ Get a list of assistant functions for function calling in the ChatGPT model. - You can call this function with no arguments, arguments are given for clarity on what changes the behavior of the function. For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling Args: - gene_to_prot_id_dict (dict, optional): A dictionary with gene names as keys and protein IDs as values. - metadata (pd.DataFrame, optional): The metadata dataframe (which sample has which disease/treatment/condition/etc). - subgroups_for_each_group (dict, optional): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). + gene_to_prot_id_map (dict): A dictionary with gene names as keys and protein IDs as values. + metadata (pd.DataFrame): The metadata dataframe (which sample has which disease/treatment/condition/etc). + subgroups_for_each_group (dict): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). Returns: list[dict]: A list of assistant functions. """ - # TODO figure out how this relates to the parameter `subgroups_for_each_group` - subgroups_for_each_group_ = str( - get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata) - ) + gene_names = list(gene_to_prot_id_map.keys()) + groups = [str(col) for col in metadata.columns.to_list()] return [ { "type": "function", @@ -153,21 +149,21 @@ def get_assistant_functions( "parameters": { "type": "object", "properties": { - "protein_id": { + "gene_name": { # this will be mapped to "protein_id" when calling the function "type": "string", - "enum": [i for i in gene_to_prot_id_dict], - "description": "Identifier for the protein of interest", + "enum": gene_names, + "description": "Identifier for the gene of interest", }, "group": { "type": "string", - "enum": [str(i) for i in metadata.columns.to_list()], + "enum": groups, "description": "Column name in the dataset for the group variable", }, "subgroups": { "type": "array", "items": {"type": "string"}, "description": f"Specific subgroups within the group to analyze. For each group you need to look up the subgroups in the dict" - f" {subgroups_for_each_group_} or present user with them first if you are not sure what to choose", + f" {subgroups_for_each_group} or present user with them first if you are not sure what to choose", }, "method": { "type": "string", @@ -198,7 +194,7 @@ def get_assistant_functions( "group": { "type": "string", "description": "The name of the group column in the dataset", - "enum": [str(i) for i in metadata.columns.to_list()], + "enum": groups, }, "method": { "type": "string", @@ -225,7 +221,7 @@ def get_assistant_functions( "color": { "type": "string", "description": "The name of the group column in the dataset to color the samples by", - "enum": [str(i) for i in metadata.columns.to_list()], + "enum": groups, }, "method": { "type": "string", @@ -303,29 +299,37 @@ def get_assistant_functions( def perform_dimensionality_reduction(group, method, circle, **kwargs): + dataset = st.session_state[StateKeys.DATASET] dr = DimensionalityReduction( - st.session_state[StateKeys.DATASET], group, method, circle, **kwargs + mat=dataset.mat, + metadate=dataset.metadata, + sample=dataset.sample, + preprocessing_info=dataset.preprocessing_info, + group=group, + circle=circle, + method=method, + **kwargs, ) return dr.plot -def get_gene_to_prot_id_mapping(gene_id: str) -> str: +def get_protein_id_for_gene_name( + gene_name: str, gene_to_prot_id_map: Dict[str, str] +) -> str: """Get protein id from gene id. If gene id is not present, return gene id, as we might already have a gene id. 'VCL;HEL114' -> 'P18206;A0A024QZN4;V9HWK2;B3KXA2;Q5JQ13;B4DKC9;B4DTM7;A0A096LPE1' + Args: - gene_id (str): Gene id + gene_name (str): Gene id Returns: str: Protein id or gene id if not present in the mapping. """ - import streamlit as st + if gene_name in gene_to_prot_id_map: + return gene_to_prot_id_map[gene_name] + + for gene, protein_id in gene_to_prot_id_map.items(): + if gene_name in gene.split(";"): + return protein_id - session_state_copy = dict(copy.deepcopy(st.session_state)) - if StateKeys.GENE_TO_PROT_ID not in session_state_copy: - session_state_copy[StateKeys.GENE_TO_PROT_ID] = {} - if gene_id in session_state_copy[StateKeys.GENE_TO_PROT_ID]: - return session_state_copy[StateKeys.GENE_TO_PROT_ID][gene_id] - for gene, prot_id in session_state_copy[StateKeys.GENE_TO_PROT_ID].items(): - if gene_id in gene.split(";"): - return prot_id - return gene_id + return gene_name diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index 5e7b45e2..764dede4 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -11,6 +11,7 @@ from alphastats.gui.utils.gpt_helper import ( get_assistant_functions, get_general_assistant_functions, + get_protein_id_for_gene_name, get_subgroups_for_each_group, perform_dimensionality_reduction, ) @@ -108,7 +109,7 @@ def _get_tools(self) -> List[Dict[str, Any]]: if self.metadata is not None and self._gene_to_prot_id_map is not None: tools += ( *get_assistant_functions( - gene_to_prot_id_dict=self._gene_to_prot_id_map, + gene_to_prot_id_map=self._gene_to_prot_id_map, metadata=self.metadata, subgroups_for_each_group=get_subgroups_for_each_group( self.metadata @@ -133,6 +134,7 @@ def truncate_conversation_history(self, max_tokens: int = 100000): """ total_tokens = sum(len(m["content"].split()) for m in self.messages) while total_tokens > max_tokens and len(self.messages) > 1: + # TODO messages should still be displayed! removed_message = self.messages.pop(0) total_tokens -= len(removed_message["content"].split()) @@ -179,24 +181,39 @@ def execute_function( If the function is not implemented or the dataset is not available """ try: - if function_name == "get_gene_function": - # TODO log whats going on - return get_gene_function(**function_args) - elif function_name == "get_enrichment_data": - return get_enrichment_data(**function_args) - elif function_name == "perform_dimensionality_reduction": - return perform_dimensionality_reduction(**function_args) - elif function_name.startswith("plot_") or function_name.startswith( - "perform_" - ): + # first try to find the function in the non-Dataset functions + if ( + function := { + "get_gene_function": get_gene_function, + "get_enrichment_data": get_enrichment_data, + "perform_dimensionality_reduction": perform_dimensionality_reduction, + }.get(function_name) + ) is not None: + return function(**function_args) + + # special treatment for this one + elif function_name == "plot_intensity": + gene_name = function_args.pop("gene_name") + protein_id = get_protein_id_for_gene_name( + gene_name, self._gene_to_prot_id_map + ) + function_args["protein_id"] = protein_id + + return self.dataset.plot_intensity(**function_args) + + # fallback: try to find the function in the Dataset functions + else: plot_function = getattr( - self.dataset, function_name.split(".")[-1], None + self.dataset, + function_name.split(".")[-1], + None, # TODO why split? ) if plot_function: return plot_function(**function_args) raise ValueError( f"Function {function_name} not implemented or dataset not available" ) + except Exception as e: return f"Error executing {function_name}: {str(e)}" @@ -219,6 +236,7 @@ def handle_function_calls( """ new_artifacts = {} + funcs_and_args = "\n".join( [ f"Calling function: {tool_call.function.name} with arguments: {tool_call.function.arguments}" @@ -231,7 +249,6 @@ def handle_function_calls( for tool_call in tool_calls: function_name = tool_call.function.name - print(f"Calling function: {function_name}") function_args = json.loads(tool_call.function.arguments) function_result = self.execute_function(function_name, function_args) @@ -248,8 +265,10 @@ def handle_function_calls( "tool_call_id": tool_call.id, } ) + post_artefact_message_idx = len(self.messages) self.artifacts[post_artefact_message_idx] = new_artifacts.values() + logger.info( f"Calling 'chat.completions.create' {self.messages=} {self.tools=} .." ) diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index c38de188..bc40b4a8 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -83,9 +83,6 @@ def init_session_state() -> None: if StateKeys.USER_SESSION_ID not in st.session_state: st.session_state[StateKeys.USER_SESSION_ID] = str(uuid.uuid4()) - if StateKeys.GENE_TO_PROT_ID not in st.session_state: - st.session_state[StateKeys.GENE_TO_PROT_ID] = {} - if StateKeys.ORGANISM not in st.session_state: st.session_state[StateKeys.ORGANISM] = 9606 # human @@ -97,7 +94,6 @@ class StateKeys: ## 02_Data Import # on 1st run ORGANISM = "organism" - GENE_TO_PROT_ID = "gene_to_prot_id" USER_SESSION_ID = "user_session_id" LOADER = "loader" # on sample run (function load_sample_data), removed on new session click diff --git a/alphastats/plots/IntensityPlot.py b/alphastats/plots/IntensityPlot.py index 895910c1..d3e96546 100644 --- a/alphastats/plots/IntensityPlot.py +++ b/alphastats/plots/IntensityPlot.py @@ -8,7 +8,6 @@ import plotly.graph_objects as go import scipy -from alphastats.gui.utils.gpt_helper import get_gene_to_prot_id_mapping from alphastats.plots.PlotUtils import PlotUtils, plotly_object plotly.io.templates["alphastats_colors"] = plotly.graph_objects.layout.Template( @@ -54,7 +53,7 @@ def __init__( self.intensity_column = intensity_column self.preprocessing_info = preprocessing_info - self.protein_id = get_gene_to_prot_id_mapping(protein_id) + self.protein_id = protein_id self.group = group self.subgroups = subgroups self.method = method diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index 50751dd1..84669918 100644 --- a/tests/gui/test_02_import_data.py +++ b/tests/gui/test_02_import_data.py @@ -18,7 +18,6 @@ def test_page_02_loads_without_input(): assert at.session_state[StateKeys.ORGANISM] == 9606 assert at.session_state[StateKeys.USER_SESSION_ID] is not None - assert at.session_state[StateKeys.GENE_TO_PROT_ID] == {} @patch("streamlit.file_uploader") @@ -31,7 +30,6 @@ def test_patched_page_02_loads_without_input(mock_file_uploader: MagicMock): assert at.session_state[StateKeys.ORGANISM] == 9606 assert at.session_state[StateKeys.USER_SESSION_ID] is not None - assert at.session_state[StateKeys.GENE_TO_PROT_ID] == {} @patch( diff --git a/tests/test_DataSet.py b/tests/test_DataSet.py index cd62f86c..507bec40 100644 --- a/tests/test_DataSet.py +++ b/tests/test_DataSet.py @@ -13,7 +13,6 @@ from alphastats.DataSet import DataSet from alphastats.dataset_factory import DataSetFactory from alphastats.DataSet_Preprocess import PreprocessingStateKeys -from alphastats.gui.utils.ui_helper import StateKeys from alphastats.loader.AlphaPeptLoader import AlphaPeptLoader from alphastats.loader.DIANNLoader import DIANNLoader from alphastats.loader.FragPipeLoader import FragPipeLoader @@ -517,9 +516,6 @@ def test_plot_intenstity_subgroup(self): self.assertEqual(len(plot_dict.get("data")), 3) def test_plot_intensity_subgroup_gracefully_handle_one_group(self): - import streamlit as st - - st.session_state[StateKeys.GENE_TO_PROT_ID] = {} plot = self.obj.plot_intensity( protein_id="K7ERI9;A0A024R0T8;P02654;K7EJI9;K7ELM9;K7EPF9;K7EKP1", group="disease", diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 3121e01b..7ca13bd0 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -9,10 +9,6 @@ from alphastats.gui.utils.uniprot_utils import extract_data, get_uniprot_data from alphastats.loader.MaxQuantLoader import MaxQuantLoader -if StateKeys.GENE_TO_PROT_ID not in st.session_state: - st.session_state[StateKeys.GENE_TO_PROT_ID] = {} - - logger = logging.getLogger(__name__)