Skip to content

Commit

Permalink
Merge branch 'main' into samc_nb_runs
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Jun 30, 2024
2 parents 096b2fb + 22ab324 commit 0f1bf7b
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 8 deletions.
16 changes: 13 additions & 3 deletions mdagent/tools/base_tools/analysis_tools/rdf_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class RDFTool(BaseTool):
name = "RDFTool"
description = (
"Calculate the radial distribution function (RDF) of a trajectory "
"of a protein with respect to water molecules."
"of a protein with respect to water molecules using the trajectory file ID "
"(trajectory_fileid) and optionally the topology file ID (topology_fileid). "
)
args_schema = RDFToolInput
path_registry: Optional[PathRegistry]
Expand All @@ -45,6 +46,9 @@ def _run(self, **input):
elif "Invalid file extension" in str(e):
print("File Extension Not Supported in RDF tool: ", str(e))
return ("Failed. File Extension Not Supported", str(e))
elif "not in path registry" in str(e):
print("File ID not in Path Registry in RDF tool: ", str(e))
return ("Failed. File ID not in Path Registry", str(e))
else:
raise ValueError(f"Error during inputs in RDF tool {e}")

Expand Down Expand Up @@ -106,6 +110,10 @@ def _arun(self, input):
pass

def validate_input(self, input):
input = input.get("input", input)

input = input.get("action_input", input)

trajectory_id = input.get("trajectory_fileid", None)

topology_id = input.get("topology_fileid", None)
Expand All @@ -115,7 +123,9 @@ def validate_input(self, input):
atom_indices = input.get("atom_indices", None)

if not trajectory_id:
raise ValueError("Incorrect Inputs: Trajectory file ID is required")
raise ValueError(
"Incorrect Inputs: Trajectory file ID ('trajectory_fileid')is required"
)

# check if trajectory id is valid
fileids = self.path_registry.list_path_names()
Expand All @@ -131,7 +141,7 @@ def validate_input(self, input):
if not topology_id:
raise ValueError(
"Incorrect Inputs: "
"Topology file is required for trajectory "
"Topology file (topology_fileid) is required for trajectory "
"file with extension {}".format(ending)
)
if topology_id not in fileids:
Expand Down
2 changes: 2 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def plot_rad_gyration(self, pdb_id: str) -> str:
f"{self.path_registry.ckpt_figures}/{plot_name}",
description=f"Plot of radii of gyration over time for {self.pdb_id}",
)
plt.close()
plt.clf()
return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}"


Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CleaningToolFunctionInput(BaseModel):
add_hydrogens: bool = Field(
True, description="Whether to add hydrogens to the file."
)
add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.")
add_hydrogens_ph: float = Field(7.0, description="pH at which hydrogens are added.")


class CleaningToolFunction(BaseTool):
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/preprocess_tools/uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def get_ids(
entry["primaryAccession"] for entry in accession
] if accession else []
if single_id:
return all_ids.pop()
return [all_ids[0]] if all_ids else []
return list(set(all_ids))

def get_gene_names(self, query: str, primary_accession: str | None = None) -> list:
Expand Down
2 changes: 2 additions & 0 deletions mdagent/tools/base_tools/util_tools/search_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

import langchain
import nest_asyncio
import paperqa
import paperscraper
from langchain.base_language import BaseLanguageModel
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(self, llm, path_registry):
self.path_registry = path_registry

def _run(self, query) -> str:
nest_asyncio.apply()
return scholar2result_llm(self.llm, query, self.path_registry)

async def _arun(self, query) -> str:
Expand Down
2 changes: 1 addition & 1 deletion notebooks/rdf/rdf.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_preprocess/test_uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def test_get_structure_info(query_uniprot):
)


def get_ids(query_uniprot):
def test_get_ids(query_uniprot):
hg_ids = [
"P84792",
"P02042",
Expand Down Expand Up @@ -508,7 +508,8 @@ def get_ids(query_uniprot):
]
all_ids = query_uniprot.get_ids("hemoglobin")
single_id = query_uniprot.get_ids("hemoglobin", single_id=True)
assert single_id in hg_ids
assert single_id[0] in hg_ids
assert len(single_id) == 1
assert all(i in all_ids for i in hg_ids)


Expand Down

0 comments on commit 0f1bf7b

Please sign in to comment.