Skip to content

Commit

Permalink
Add support for resuming with CLI/YAML (#830)
Browse files Browse the repository at this point in the history
* should work for sams but I don't have a yaml to test

* now it works for repex

* Added test for cli resume

* remove code commented out

* Added change log entry

* comment out test while we debug GHA

* added debugging to fix GHA

* Added more debugging

* wft GHA? it is imported!

* going to remove double import

* fix typo in f-string

* added some more detail on the changelog

* forgot how weird subprocess can be

* have perses-relative respect the log level set from environ

* remove debugging info, fix issues with subprocess

* fix formatting with black

* missed a fstring fix

* lets see if this helps with the license

* just running the test I need to make debugging faster

* setting back to running all the tests

* add note about LOGLEVEL as well

* added info to changelog about n_cycles
  • Loading branch information
mikemhenry authored Aug 2, 2021
1 parent 222e1c7 commit fd6bb4e
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ The full release history can be viewed `at the GitHub yank releases page <https:

Bugfixes
^^^^^^^^
- (PR `#830 <https://github.com/choderalab/perses/pull/830>`_)
Added limited support for resuming simulations from the CLI.
Assumes simulations are only going to be resumed from the production step and not equilibration step.
To extend the simulation, change ``n_cycles`` to a larger number and re-run the CLI tool.
``LOGLEVEL`` can now be set with an environmental variable when using the CLI tool.
- (PR `#821 <https://github.com/choderalab/perses/pull/821>`_)
Added tests for the resume simulation functionality.
- (PR `#828 <https://github.com/choderalab/perses/pull/828>`_)
Expand Down
68 changes: 63 additions & 5 deletions perses/app/setup_relative_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ def filter(self, record):

fmt = logging.Formatter(fmt="%(asctime)s:(%(relative)ss):%(name)s:%(message)s")
#logging.basicConfig(level = logging.NOTSET)
LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
logging.basicConfig(
format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
level=LOGLEVEL,
datefmt='%Y-%m-%d %H:%M:%S')
_logger = logging.getLogger()
_logger.setLevel(logging.INFO)
_logger.setLevel(LOGLEVEL)
[hndl.addFilter(TimeFilter()) for hndl in _logger.handlers]
[hndl.setFormatter(fmt) for hndl in _logger.handlers]

Expand Down Expand Up @@ -298,8 +299,7 @@ def getSetupOptions(filename):
_logger.info(f"\t'softcore_v2' not specified: default to 'False'")

_logger.info(f"\tCreating '{trajectory_directory}'...")
assert (not os.path.exists(trajectory_directory)), f'Output trajectory directory "{trajectory_directory}" already exists. Refusing to overwrite'
os.makedirs(trajectory_directory)
os.makedirs(trajectory_directory, exist_ok=True)


return setup_options
Expand Down Expand Up @@ -613,7 +613,7 @@ def run_setup(setup_options, serialize_systems=True, build_samplers=True):
else:
selection_indices = None

storage_name = str(trajectory_directory)+'/'+str(trajectory_prefix)+'-'+str(phase)+'.nc'
storage_name = f"{trajectory_directory}/{trajectory_prefix}-{phase}.nc"
_logger.info(f'\tstorage_name: {storage_name}')
_logger.info(f'\tselection_indices {selection_indices}')
_logger.info(f'\tcheckpoint interval {checkpoint_interval}')
Expand Down Expand Up @@ -681,6 +681,20 @@ def run(yaml_filename=None):

_logger.info(f"Getting setup options from {yaml_filename}")
setup_options = getSetupOptions(yaml_filename)

# The name of the reporter file includes the phase name, so we need to check each
# one
for phase in setup_options['phases']:
trajectory_directory = setup_options['trajectory_directory']
trajectory_prefix = setup_options['trajectory_prefix']
reporter_file = f"{trajectory_directory}/{trajectory_prefix}-{phase}.nc"
# Once we find one, we are good to resume the simulation
if os.path.isfile(reporter_file):
_resume_run(setup_options)
# There is a loop in _resume_run for each phase so once we extend each phase
# we are done
exit()

if 'lambdas' in setup_options:
if type(setup_options['lambdas']) == int:
lambdas = {}
Expand Down Expand Up @@ -874,5 +888,49 @@ def run(yaml_filename=None):

_logger.info(f"\t\tFinished phase {phase}")

def _resume_run(setup_options):
if setup_options['fe_type'] == 'sams':
logZ = dict()
free_energies = dict()

_logger.info(f"Iterating through phases for sams...")
for phase in setup_options['phases']:
trajectory_directory = setup_options['trajectory_directory']
trajectory_prefix = setup_options['trajectory_prefix']

reporter_file = f"{trajectory_directory}/{trajectory_prefix}-{phase}.nc"
reporter = MultiStateReporter(reporter_file)
simulation = HybridSAMSSampler.from_storage(reporter)
total_steps = setup_options['n_cycles']
run_so_far = simulation.iteration
left_to_do = total_steps - run_so_far
_logger.info(f"\t\textending simulation...\n\n")
simulation.extend(n_iterations=left_to_do)
logZ[phase] = simulation._logZ[-1] - simulation._logZ[0]
free_energies[phase] = simulation._last_mbar_f_k[-1] - simulation._last_mbar_f_k[0]
_logger.info(f"\t\tFinished phase {phase}")
for phase in free_energies:
print(f"Comparing ligand {setup_options['old_ligand_index']} to {setup_options['new_ligand_index']}")
print(f"{phase} phase has a free energy of {free_energies[phase]}")

elif setup_options['fe_type'] == 'repex':
for phase in setup_options['phases']:
print(f'Running {phase} phase')
trajectory_directory = setup_options['trajectory_directory']
trajectory_prefix = setup_options['trajectory_prefix']

reporter_file = f"{trajectory_directory}/{trajectory_prefix}-{phase}.nc"
reporter = MultiStateReporter(reporter_file)
simulation = HybridRepexSampler.from_storage(reporter)
total_steps = setup_options['n_cycles']
run_so_far = simulation.iteration
left_to_do = total_steps - run_so_far
_logger.info(f"\t\textending simulation...\n\n")
simulation.extend(n_iterations=left_to_do)
_logger.info(f"\n\n")
_logger.info(f"\t\tFinished phase {phase}")
else:
raise("Can't resume")

if __name__ == "__main__":
run()
24 changes: 22 additions & 2 deletions perses/tests/test_resume.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import shutil
import subprocess
import tempfile

Expand All @@ -17,7 +18,7 @@
from perses.samplers.multistate import HybridRepexSampler


def test_cli_resume():
def test_cli_resume_repex():

with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
Expand Down Expand Up @@ -60,9 +61,28 @@ def test_cli_resume():
y_doc = yaml.load(document, Loader=yaml.UnsafeLoader)
y_doc["protein_pdb"] = protein_pdb
y_doc["ligand_file"] = ligand_file

with open("test.yml", "w") as outfile:
yaml.dump(y_doc, outfile)
subprocess.run(["perses-relative", "test.yml"])

env = os.environ.copy()
if os.environ.get('GITHUB_ACTIONS', False):
shutil.copy("/home/runner/work/perses/perses/oe_license.txt", ".")
subprocess.run("perses-relative test.yml", shell=True, check=True, env=env)

# Now we change the yaml to run longer
y_doc["n_cycles"] = 20
with open("test.yml", "w") as outfile:
yaml.dump(y_doc, outfile)
subprocess.run("perses-relative test.yml", shell=True, check=True, env=env)

# Check to see if we have a total of 20
reporter = MultiStateReporter(
"cdk2_repex_hbonds/cdk2-vacuum.nc", checkpoint_interval=10
)
simulation = HybridRepexSampler.from_storage(reporter)

assert simulation.iteration == 20


def test_resume_small_molecule(tmp_path):
Expand Down

0 comments on commit fd6bb4e

Please sign in to comment.