From 22ab324d4b402d4ac958783ad19f8aa45003c5a5 Mon Sep 17 00:00:00 2001 From: Jorge <97254349+Jgmedina95@users.noreply.github.com> Date: Sat, 29 Jun 2024 18:16:24 -0400 Subject: [PATCH] Improving rdf and cleaning tools (#146) Debugging: pqa, cleaningtool, rdftool, and uniprot little bug in testing --------- Co-authored-by: Sam Cox --- .../tools/base_tools/analysis_tools/rdf_tool.py | 16 +++++++++++++--- mdagent/tools/base_tools/analysis_tools/rgy.py | 2 ++ .../base_tools/preprocess_tools/clean_tools.py | 2 +- .../tools/base_tools/preprocess_tools/uniprot.py | 2 +- .../tools/base_tools/util_tools/search_tools.py | 2 ++ notebooks/rdf/rdf.ipynb | 2 +- tests/test_preprocess/test_uniprot.py | 5 +++-- 7 files changed, 23 insertions(+), 8 deletions(-) diff --git a/mdagent/tools/base_tools/analysis_tools/rdf_tool.py b/mdagent/tools/base_tools/analysis_tools/rdf_tool.py index 6bdc16cf..5355499c 100644 --- a/mdagent/tools/base_tools/analysis_tools/rdf_tool.py +++ b/mdagent/tools/base_tools/analysis_tools/rdf_tool.py @@ -26,7 +26,8 @@ class RDFTool(BaseTool): name = "RDFTool" description = ( "Calculate the radial distribution function (RDF) of a trajectory " - "of a protein with respect to water molecules." + "of a protein with respect to water molecules using the trajectory file ID " + "(trajectory_fileid) and optionally the topology file ID (topology_fileid). " ) args_schema = RDFToolInput path_registry: Optional[PathRegistry] @@ -45,6 +46,9 @@ def _run(self, **input): elif "Invalid file extension" in str(e): print("File Extension Not Supported in RDF tool: ", str(e)) return ("Failed. File Extension Not Supported", str(e)) + elif "not in path registry" in str(e): + print("File ID not in Path Registry in RDF tool: ", str(e)) + return ("Failed. File ID not in Path Registry", str(e)) else: raise ValueError(f"Error during inputs in RDF tool {e}") @@ -106,6 +110,10 @@ def _arun(self, input): pass def validate_input(self, input): + input = input.get("input", input) + + input = input.get("action_input", input) + trajectory_id = input.get("trajectory_fileid", None) topology_id = input.get("topology_fileid", None) @@ -115,7 +123,9 @@ def validate_input(self, input): atom_indices = input.get("atom_indices", None) if not trajectory_id: - raise ValueError("Incorrect Inputs: Trajectory file ID is required") + raise ValueError( + "Incorrect Inputs: Trajectory file ID ('trajectory_fileid')is required" + ) # check if trajectory id is valid fileids = self.path_registry.list_path_names() @@ -131,7 +141,7 @@ def validate_input(self, input): if not topology_id: raise ValueError( "Incorrect Inputs: " - "Topology file is required for trajectory " + "Topology file (topology_fileid) is required for trajectory " "file with extension {}".format(ending) ) if topology_id not in fileids: diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdagent/tools/base_tools/analysis_tools/rgy.py index ffeba6a6..c0431ae1 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -92,6 +92,8 @@ def plot_rad_gyration(self, pdb_id: str) -> str: f"{self.path_registry.ckpt_figures}/{plot_name}", description=f"Plot of radii of gyration over time for {self.pdb_id}", ) + plt.close() + plt.clf() return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}" diff --git a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py index 821295b9..1be094f3 100644 --- a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py @@ -30,7 +30,7 @@ class CleaningToolFunctionInput(BaseModel): add_hydrogens: bool = Field( True, description="Whether to add hydrogens to the file." ) - add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.") + add_hydrogens_ph: float = Field(7.0, description="pH at which hydrogens are added.") class CleaningToolFunction(BaseTool): diff --git a/mdagent/tools/base_tools/preprocess_tools/uniprot.py b/mdagent/tools/base_tools/preprocess_tools/uniprot.py index 03b939d0..28dfae69 100644 --- a/mdagent/tools/base_tools/preprocess_tools/uniprot.py +++ b/mdagent/tools/base_tools/preprocess_tools/uniprot.py @@ -697,7 +697,7 @@ def get_ids( entry["primaryAccession"] for entry in accession ] if accession else [] if single_id: - return all_ids.pop() + return [all_ids[0]] if all_ids else [] return list(set(all_ids)) def get_gene_names(self, query: str, primary_accession: str | None = None) -> list: diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdagent/tools/base_tools/util_tools/search_tools.py index 9d687343..003c8847 100644 --- a/mdagent/tools/base_tools/util_tools/search_tools.py +++ b/mdagent/tools/base_tools/util_tools/search_tools.py @@ -3,6 +3,7 @@ from typing import Optional import langchain +import nest_asyncio import paperqa import paperscraper from langchain.base_language import BaseLanguageModel @@ -77,6 +78,7 @@ def __init__(self, llm, path_registry): self.path_registry = path_registry def _run(self, query) -> str: + nest_asyncio.apply() return scholar2result_llm(self.llm, query, self.path_registry) async def _arun(self, query) -> str: diff --git a/notebooks/rdf/rdf.ipynb b/notebooks/rdf/rdf.ipynb index 79258b92..dca7396e 100644 --- a/notebooks/rdf/rdf.ipynb +++ b/notebooks/rdf/rdf.ipynb @@ -99,7 +99,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/tests/test_preprocess/test_uniprot.py b/tests/test_preprocess/test_uniprot.py index e4b118bc..b82b74c5 100644 --- a/tests/test_preprocess/test_uniprot.py +++ b/tests/test_preprocess/test_uniprot.py @@ -478,7 +478,7 @@ def test_get_structure_info(query_uniprot): ) -def get_ids(query_uniprot): +def test_get_ids(query_uniprot): hg_ids = [ "P84792", "P02042", @@ -508,7 +508,8 @@ def get_ids(query_uniprot): ] all_ids = query_uniprot.get_ids("hemoglobin") single_id = query_uniprot.get_ids("hemoglobin", single_id=True) - assert single_id in hg_ids + assert single_id[0] in hg_ids + assert len(single_id) == 1 assert all(i in all_ids for i in hg_ids)