Skip to content

Commit

Permalink
various fixes (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 authored Nov 4, 2024
1 parent 0f7baec commit 320a853
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 24 deletions.
2 changes: 0 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,5 @@
# OpenAI API Key
OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE # pragma: allowlist secret

# PQA API Key to use LiteratureSearch tool (optional) -- it also requires OpenAI key
PQA_API_KEY=YOUR_PQA_API_KEY_GOES_HERE # pragma: allowlist secret

# Optional: add TogetherAI, Fireworks, or Anthropic API key here to use their models
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@ repos:
- id: mixed-line-ending
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.270"
rev: "v0.7.1"
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
- repo: https://github.com/psf/black
rev: "23.3.0"
rev: "24.10.0"
hooks:
- id: black
language_version: python3
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.3.0"
rev: "v1.13.0"
hooks:
- id: mypy
args: [--pretty, --ignore-missing-imports]
additional_dependencies: [types-requests]
- repo: https://github.com/PyCQA/isort
rev: "5.12.0"
rev: "5.13.2"
hooks:
- id: isort
args: [--profile=black]
- repo: https://github.com/Yelp/detect-secrets
rev: v1.0.3
rev: v1.5.0
hooks:
- id: detect-secrets
args: [--exclude-files, ".github/workflows/"]
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _run(self, file_id: str) -> str:
plotting_tools._find_file(file_id)
plotting_tools.process_csv()
plot_result = plotting_tools.plot_data()
if type(plot_result) == str:
if isinstance(plot_result, str):
return "Succeeded. IDs of figures created: " + plot_result
else:
return "Failed. No figures created."
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/analysis_tools/rdf_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def validate_input(self, input):
)

if stride:
if type(stride) != int:
if not isinstance(stride, int):
try:
stride = int(stride)
if stride <= 0:
Expand Down
8 changes: 5 additions & 3 deletions mdagent/tools/base_tools/preprocess_tools/uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,9 +693,11 @@ def get_ids(
if include_uniprotkbids:
all_ids + [entry["uniProtkbId"] for entry in ids_] if ids_ else []
accession = self.get_data(query, desired_field="accession")
all_ids + [
entry["primaryAccession"] for entry in accession
] if accession else []
(
all_ids + [entry["primaryAccession"] for entry in accession]
if accession
else []
)
if single_id:
return [all_ids[0]] if all_ids else []
return list(set(all_ids))
Expand Down
30 changes: 19 additions & 11 deletions mdagent/tools/base_tools/simulation_tools/setup_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,22 @@ def setup_system(self):
PME,
]:
if self.sim_params["Ensemble"] == "NPT":
self.system.addForce(
MonteCarloBarostat(
self.int_params["Pressure"],
self.int_params["Temperature"],
self.sim_params.get("barostatInterval", 25),
)
pressure = self.int_params.get("Pressure", 1.0)

if "Pressure" not in self.int_params:
print(
"Warning: 'Pressure' not provided. ",
"Using default pressure of 1.0 atm.",
)

self.system.addForce(
MonteCarloBarostat(
pressure,
self.int_params["Temperature"],
self.sim_params.get("barostatInterval", 25),
)
)

def setup_integrator(self):
print("Setting up integrator...")
int_params = self.int_params
Expand Down Expand Up @@ -1219,7 +1227,7 @@ def _process_parameters(self, user_params, param_type="system_params"):
)
if key == "constraints":
try:
if type(value) == str:
if isinstance(value, str):
if value == "None":
processed_params[key] = None
elif value == "HBonds":
Expand All @@ -1243,7 +1251,7 @@ def _process_parameters(self, user_params, param_type="system_params"):
"part of the parameters.\n"
)
if key == "rigidWater" or key == "rigidwater":
if type(value) == bool:
if isinstance(value, bool):
processed_params[key] = value
elif value == "True":
processed_params[key] = True
Expand All @@ -1268,7 +1276,7 @@ def _process_parameters(self, user_params, param_type="system_params"):
)
if key == "solvate":
try:
if type(value) == bool:
if isinstance(value, bool):
processed_params[key] = value
elif value == "True":
processed_params[key] = True
Expand Down Expand Up @@ -1480,7 +1488,7 @@ def check_system_params(cls, values):

# forcefield
forcefield_files = values.get("forcefield_files")
if forcefield_files is None or forcefield_files is []:
if forcefield_files is None or forcefield_files == []:
print("Setting default forcefields")
forcefield_files = ["amber14-all.xml", "amber14/tip3pfb.xml"]
elif len(forcefield_files) == 0:
Expand All @@ -1492,7 +1500,7 @@ def check_system_params(cls, values):
error_msg += "The forcefield file is not present"

save = values.get("save", True)
if type(save) != bool:
if not isinstance(save, bool):
error_msg += "save must be a boolean value"

if error_msg != "":
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def make_all_tools(
all_tools += [
ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm),
]
if "OPENAI_API_KEY" in os.environ and "PQA_API_KEY" in os.environ:
if path_instance.ckpt_papers:
all_tools += [Scholar2ResultLLM(llm=llm, path_registry=path_instance)]
if human:
all_tools += [agents.load_tools(["human"], llm)[0]]
Expand Down

0 comments on commit 320a853

Please sign in to comment.