From 8e4bc8af7eca8f0af285f14e98c10e7b04d1d30c Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Wed, 26 Jun 2024 22:10:14 -0700 Subject: [PATCH] Update secondary_structure.py --- .../analysis_tools/secondary_structure.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/mdagent/tools/base_tools/analysis_tools/secondary_structure.py b/mdagent/tools/base_tools/analysis_tools/secondary_structure.py index 92e6361e..c96f3702 100644 --- a/mdagent/tools/base_tools/analysis_tools/secondary_structure.py +++ b/mdagent/tools/base_tools/analysis_tools/secondary_structure.py @@ -502,6 +502,9 @@ class AnalyzeProteinStructure(BaseTool): "a string, separated by commas. " "The output is a dictionary " "containing the requested analyses." + "Here are the valid options: " + "atoms, residues, chains, frames, bonds" + "The tool will provide counts for each." ) path_registry: PathRegistry = PathRegistry.get_instance() @@ -511,26 +514,16 @@ def __init__(self, path_registry: PathRegistry): def analyze_protein(self, traj, requested_analyses: list): result = {} - if "n_atoms" in requested_analyses: - result["n_atoms"] = traj.n_atoms - if "n_residues" in requested_analyses: - result["n_residues"] = traj.n_residues - if "n_chains" in requested_analyses: - result["n_chains"] = traj.n_chains - if "n_frames" in requested_analyses: - result["n_frames"] = traj.n_frames - if "time" in requested_analyses: - result["time"] = traj.time - if "time_step" in requested_analyses: - result["time_step"] = traj.time_step if "atoms" in requested_analyses: - result["atoms"] = traj.topology.atoms - if "bonds" in requested_analyses: - result["bonds"] = traj.topology.bonds - if "chains" in requested_analyses: - result["chains"] = traj.topology.chains + result['n_atoms'] = traj.n_atoms if "residues" in requested_analyses: - result["residues"] = traj.topology.residues + result['n_residues'] = traj.n_residues + if "chains" in requested_analyses: + result['n_chains'] = traj.n_chains + if "frames" in requested_analyses: + result['n_frames'] = traj.n_frames + if "bonds" in requested_analyses: + result['n_bonds'] = len([bond for bond in traj.topology.bonds]) return result def _run(