Refactor llm v #356

Open · wants to merge 8 commits into base: refactor_llm_IV
30 changes: 27 additions & 3 deletions alphastats/gui/pages/04_Analysis.py
@@ -82,6 +82,30 @@
         "Download as .csv", csv, method + ".csv", "text/csv", key="download-csv"
     )
 
-# TODO this is still quite rough, should be a list, mb add a button etc..
-if method == "Volcano Plot" and analysis_result is not None:
-    st.session_state[StateKeys.LLM_INPUT] = (analysis_object, parameters)
+
+@st.fragment
+def show_start_llm_button(method: str) -> None:
+    """Show the button to start the LLM analysis."""
+
+    msg = (
+        "(this will overwrite the existing LLM analysis!)"
+        if StateKeys.LLM_INTEGRATION in st.session_state
+        else ""
+    )
+
+    submitted = st.button(
+        f"Analyse with LLM ... {msg}",
+        disabled=(method != "Volcano Plot"),
+        help="Interpret the current analysis with an LLM (available for 'Volcano Plot' only).",
+    )
+    if submitted:
+        if StateKeys.LLM_INTEGRATION in st.session_state:
+            del st.session_state[StateKeys.LLM_INTEGRATION]
+        st.session_state[StateKeys.LLM_INPUT] = (analysis_object, parameters)
+
+        st.success("LLM analysis created!")
+        st.page_link("pages/05_LLM.py", label="=> Go to LLM page..")
+
+
+if analysis_result is not None:
+    show_start_llm_button(method)
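
Note for reviewers: the hunk above pairs an @st.fragment-scoped button with a session-state handoff to the LLM page. A minimal sketch of the same pattern in isolation, assuming Streamlit >= 1.37 (where st.fragment is stable); the keys, labels, and stand-in result below are invented for illustration, not taken from this PR:

# fragment_handoff_demo.py (hypothetical), run with: streamlit run fragment_handoff_demo.py
import streamlit as st

result = {"fold_change": 2.1, "p_value": 0.003}  # stand-in for an analysis result


@st.fragment
def show_handoff_button() -> None:
    """Clicking the button reruns only this fragment, not the whole page."""
    if st.button("Hand result to the next page"):
        # Drop any previous downstream state, then store the fresh input,
        # mirroring the del/re-set logic in show_start_llm_button above.
        st.session_state.pop("demo_llm_integration", None)
        st.session_state["demo_llm_input"] = result
        st.success("Result stored in session state!")


show_handoff_button()
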
54 changes: 30 additions & 24 deletions alphastats/gui/pages/05_LLM.py
@@ -24,7 +24,7 @@
     st.stop()
 
 
-st.markdown("### LLM Analysis")
+st.markdown("### LLM")
 
 
 @st.fragment
@@ -45,17 +45,19 @@ def llm_config():
     st.info(f"Expecting Ollama API at {base_url}.")
 
 
 st.markdown("#### Configure LLM")
 llm_config()
 
-st.markdown("#### Analysis")
 
+st.markdown("#### Analysis Input")
+
 if StateKeys.LLM_INPUT not in st.session_state:
     st.info("Create a Volcano plot first using the 'Analysis' page.")
     st.stop()
 
 volcano_plot, parameter_dict = st.session_state[StateKeys.LLM_INPUT]
 
+st.write(f"Parameters used for analysis: {parameter_dict}")
 c1, c2 = st.columns((1, 2))
 
 with c1:
@@ -65,13 +67,6 @@ def llm_config():
 gene_names_colname = st.session_state[StateKeys.LOADER].gene_names
 prot_ids_colname = st.session_state[StateKeys.LOADER].index_column
 
-# st.session_state[StateKeys.PROT_ID_TO_GENE] = dict(
-#     zip(
-#         genes_of_interest_colored_df[prot_ids_colname].tolist(),
-#         genes_of_interest_colored_df[gene_names_colname].tolist(),
-#     )
-# ) # TODO unused?
-
 gene_to_prot_id_map = dict( # TODO move this logic to dataset
     zip(
         genes_of_interest_colored_df[gene_names_colname].tolist(),
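
Note on the TODO above: dict(zip(...)) silently keeps only the last protein ID when a gene name occurs more than once in the dataframe, which is one argument for moving this logic into the dataset class. A small illustration with invented values:

from collections import defaultdict

# Invented example values; real data comes from genes_of_interest_colored_df.
gene_names = ["TP53", "EGFR", "TP53"]        # duplicate gene name
prot_ids = ["P04637", "P00533", "P04637-2"]  # e.g. an isoform entry

gene_to_prot_id_map = dict(zip(gene_names, prot_ids))
print(gene_to_prot_id_map)  # {'TP53': 'P04637-2', 'EGFR': 'P00533'}, last one wins

# A dataset-level version could collect all IDs per gene instead:
gene_to_prot_ids = defaultdict(list)
for gene, prot in zip(gene_names, prot_ids):
    gene_to_prot_ids[gene].append(prot)
print(dict(gene_to_prot_ids))  # {'TP53': ['P04637', 'P04637-2'], 'EGFR': ['P00533']}
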
@@ -87,7 +82,6 @@
     st.text("No proteins of interest found.")
     st.stop()
 
-# st.session_state["gene_functions"] = get_info(genes_of_interest_colored, organism)
 upregulated_genes = [
     key
     for key in genes_of_interest_colored
@@ -99,7 +93,7 @@
     if genes_of_interest_colored[key] == "down"
 ]
 
-st.subheader("Genes of interest")
+st.markdown("##### Genes of interest")
 c1, c2 = st.columns((1, 2), gap="medium")
 with c1:
     st.write("Upregulated genes")
@@ -109,12 +103,14 @@
     display_proteins([], downregulated_genes)
 
 
-st.subheader("Prompts generated based on gene functions")
-
+st.markdown("##### Prompts generated based on analysis input")
 
 with st.expander("System message", expanded=False):
     system_message = st.text_area(
-        "", value=get_system_message(st.session_state[StateKeys.DATASET]), height=150
+        "",
+        value=get_system_message(st.session_state[StateKeys.DATASET]),
+        height=150,
+        disabled=StateKeys.LLM_INTEGRATION in st.session_state,
     )
 
 with st.expander("Initial prompt", expanded=True):
@@ -124,17 +120,20 @@ def llm_config():
             parameter_dict, upregulated_genes, downregulated_genes
         ),
         height=200,
+        disabled=StateKeys.LLM_INTEGRATION in st.session_state,
     )
 
-llm_submitted = st.button("Run LLM analysis")
+st.markdown("##### LLM Analysis")
+
+llm_submitted = st.button("Run LLM analysis ...")
 
 if StateKeys.LLM_INTEGRATION not in st.session_state:
     if not llm_submitted:
         st.stop()
 
     try:
-        llm = LLMIntegration(
+        llm_integration = LLMIntegration(
            api_type=st.session_state[StateKeys.API_TYPE],
            system_message=system_message,
            api_key=st.session_state[StateKeys.OPENAI_API_KEY],
@@ -143,13 +142,14 @@ def llm_config():
             gene_to_prot_id_map=gene_to_prot_id_map,
         )
 
-        st.session_state[StateKeys.LLM_INTEGRATION] = llm
+        st.session_state[StateKeys.LLM_INTEGRATION] = llm_integration
 
         st.success(
-            f"{st.session_state[StateKeys.API_TYPE].upper()} integration initialized successfully!"
+            f"{st.session_state[StateKeys.API_TYPE]} integration initialized successfully!"
         )
 
-        llm.chat_completion(initial_prompt)
+        with st.spinner("Processing initial prompt..."):
+            llm_integration.chat_completion(initial_prompt)
 
     except AuthenticationError:
         st.warning(
@@ -159,11 +159,10 @@
 
 
 @st.fragment
-def llm_chat():
+def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
     """The chat interface for the LLM analysis."""
-    llm = st.session_state[StateKeys.LLM_INTEGRATION]
 
-    for message in llm.get_print_view(show_all=False):
+    for message in llm_integration.get_print_view(show_all=show_all):
         with st.chat_message(message["role"]):
             st.markdown(message["content"])
             for artifact in message["artifacts"]:
@@ -176,8 +175,15 @@ def llm_chat():
                 st.write(artifact)
 
     if prompt := st.chat_input("Say something"):
-        llm.chat_completion(prompt)
+        with st.spinner("Processing prompt..."):
+            llm_integration.chat_completion(prompt)
         st.rerun(scope="fragment")
 
 
-llm_chat()
+show_all = st.checkbox(
+    "Show system messages",
+    key="show_system_messages",
+    help="Show all messages in the chat interface.",
+)
+
+llm_chat(st.session_state[StateKeys.LLM_INTEGRATION], show_all)
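
Note: passing llm_integration and show_all into llm_chat() decouples the fragment from st.session_state, so the chat loop can be exercised against any object exposing the two methods it calls. A stub satisfying that interface, inferred from the calls above (this is not the real LLMIntegration class):

from typing import Any, Dict, List


class StubLLMIntegration:
    """Minimal stand-in exposing the two methods llm_chat() relies on."""

    def __init__(self) -> None:
        self._messages: List[Dict[str, Any]] = []

    def chat_completion(self, prompt: str) -> None:
        self._messages.append({"role": "user", "content": prompt, "artifacts": []})
        self._messages.append(
            {"role": "assistant", "content": f"echo: {prompt}", "artifacts": []}
        )

    def get_print_view(self, show_all: bool = False) -> List[Dict[str, Any]]:
        # The real class presumably filters out system/tool messages unless
        # show_all is True; this stub has none to filter.
        return self._messages


stub = StubLLMIntegration()
stub.chat_completion("hello")
for message in stub.get_print_view(show_all=False):
    print(message["role"], "->", message["content"])
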
15 changes: 8 additions & 7 deletions alphastats/gui/utils/analysis_helper.py
@@ -165,15 +165,10 @@ def gui_volcano_plot() -> Tuple[Optional[Any], Optional[Any], Optional[Dict]]:
         alpha = st.number_input(
             label="alpha", min_value=0.001, max_value=0.050, value=0.050
         )
+        chosen_parameter_dict.update({"alpha": alpha})
 
         min_fc = st.select_slider("Foldchange cutoff", range(0, 3), value=1)
-
-        plotting_parameter_dict = {
-            "labels": labels,
-            "draw_line": draw_line,
-            "alpha": alpha,
-            "min_fc": min_fc,
-        }
+        chosen_parameter_dict.update({"min_fc": min_fc})
 
     if method == "sam":
         perm = st.number_input(
@@ -191,6 +186,12 @@
     volcano_plot = gui_volcano_plot_differential_expression_analysis(
         chosen_parameter_dict
     )
+    plotting_parameter_dict = {
+        "labels": labels,
+        "draw_line": draw_line,
+        "alpha": alpha,
+        "min_fc": min_fc,
+    }
     volcano_plot._update(plotting_parameter_dict)
     volcano_plot._annotate_result_df()
     volcano_plot._plot()
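
Note: the net effect of these two hunks is that statistics parameters are collected into chosen_parameter_dict as each widget is read, while display-only options are grouped into plotting_parameter_dict just before rendering. A condensed sketch of the resulting flow; the analysis call is a placeholder and the widget values are hard-coded:

# Condensed sketch of the parameter flow after this change. The dict keys
# mirror the diff; run_analysis() is a placeholder, not the real
# gui_volcano_plot_differential_expression_analysis().
labels, draw_line, alpha, min_fc = True, True, 0.05, 1  # stand-ins for widget values

chosen_parameter_dict = {"method": "ttest"}  # invented statistics inputs
chosen_parameter_dict.update({"alpha": alpha})
chosen_parameter_dict.update({"min_fc": min_fc})


def run_analysis(params: dict) -> dict:
    return {"kind": "volcano", **params}  # placeholder computation


volcano_plot = run_analysis(chosen_parameter_dict)

# Display-only options are grouped separately, right before rendering:
plotting_parameter_dict = {
    "labels": labels,
    "draw_line": draw_line,
    "alpha": alpha,
    "min_fc": min_fc,
}
print(volcano_plot)
print(plotting_parameter_dict)
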