Skip to content

Commit

Permalink
Merge branch 'main' into recombinant_install_guide
Browse files Browse the repository at this point in the history
  • Loading branch information
dwhswenson authored Oct 2, 2023
2 parents cce4bf7 + 7f830c8 commit 0e8b9f0
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 13 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ include openfecli/tests/data/*.json
include openfecli/tests/data/*.tar.gz
recursive-include openfecli/tests/ *.sdf
recursive-include openfecli/tests/ *.pdb
include openfe/tests/data/openmm_rfe/vacuum_nocoord.nc
16 changes: 11 additions & 5 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,15 @@ def get_individual_estimates(self) -> list[tuple[unit.Quantity, unit.Quantity]]:
def get_forward_and_reverse_energy_analysis(self) -> list[dict[str, Union[npt.NDArray, unit.Quantity]]]:
"""
Get a list of forward and reverse analysis of the free energies
for each repeat using uncorrolated production samples.
for each repeat using uncorrelated production samples.
The returned dicts have keys:
'fractions' - the fraction of data used for this estimate
'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate
'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty
The 'fractions' values are a numpy array, while the other arrays are
Quantity arrays, with units attached.
Returns
-------
Expand Down Expand Up @@ -231,9 +239,7 @@ def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]:
return overlap_stats

def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]:
"""
Returns the replica lambda state transition statistics for each
repeat.
"""The replica lambda state transition statistics for each repeat.
Note
----
Expand All @@ -246,7 +252,7 @@ def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]:
A list of dictionaries containing the following:
* ``eigenvalues``: The sorted (descending) eigenvalues of the
lambda state transition matrix
* ``matrix``: The transition matrix estimate of a replica switchin
* ``matrix``: The transition matrix estimate of a replica switching
from state i to state j.
"""
try:
Expand Down
Binary file not shown.
Binary file removed openfe/tests/data/openmm_rfe/vac_results.json.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion openfe/tests/protocols/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,5 +147,5 @@ def transformation_json() -> str:
"""string of a result of quickrun"""
d = resources.files('openfe.tests.data.openmm_rfe')

with gzip.open((d / 'vac_results.json.gz').as_posix(), 'r') as f: # type: ignore
with gzip.open((d / 'Transformation-e1702a3efc0fa735d5c14fc7572b5278_results.json.gz').as_posix(), 'r') as f: # type: ignore
return f.read().decode() # type: ignore
103 changes: 98 additions & 5 deletions openfe/tests/protocols/test_openmm_equil_rfe_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,10 +1241,103 @@ def test_constraints(tyk2_xml, tyk2_reference_xml):
assert float(a.get('d')) == pytest.approx(float(b.get('d')))


def test_reload_protocol_result(transformation_json):
d = json.loads(transformation_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)
class TestProtocolResult:
@pytest.fixture()
def protocolresult(self, transformation_json):
d = json.loads(transformation_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)

pr = openmm_rfe.RelativeHybridTopologyProtocolResult.from_dict(d['protocol_result'])
pr = openfe.ProtocolResult.from_dict(d['protocol_result'])

assert pr
return pr

def test_reload_protocol_result(self, transformation_json):
d = json.loads(transformation_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)

pr = openmm_rfe.RelativeHybridTopologyProtocolResult.from_dict(d['protocol_result'])

assert pr

def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()

assert est
assert est.m == pytest.approx(-15.768768285032115)
assert isinstance(est, unit.Quantity)
assert est.is_compatible_with(unit.kilojoule_per_mole)

def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()

assert est
assert est.m == pytest.approx(0.03662634237353985)
assert isinstance(est, unit.Quantity)
assert est.is_compatible_with(unit.kilojoule_per_mole)

def test_get_individual(self, protocolresult):
inds = protocolresult.get_individual_estimates()

assert isinstance(inds, list)
assert len(inds) == 3
for e, u in inds:
assert e.is_compatible_with(unit.kilojoule_per_mole)
assert u.is_compatible_with(unit.kilojoule_per_mole)

def test_get_forwards_etc(self, protocolresult):
far = protocolresult.get_forward_and_reverse_energy_analysis()

assert isinstance(far, list)
far1 = far[0]
assert isinstance(far1, dict)
for k in ['fractions', 'forward_DGs', 'forward_dDGs',
'reverse_DGs', 'reverse_dDGs']:
assert k in far1

if k == 'fractions':
assert isinstance(far1[k], np.ndarray)
else:
assert isinstance(far1[k], unit.Quantity)
assert far1[k].is_compatible_with(unit.kilojoule_per_mole)

def test_get_overlap_matrices(self, protocolresult):
ovp = protocolresult.get_overlap_matrices()

assert isinstance(ovp, list)
assert len(ovp) == 3

ovp1 = ovp[0]
assert isinstance(ovp1['matrix'], np.ndarray)
assert ovp1['matrix'].shape == (11,11)

def test_get_replica_transition_statistics(self, protocolresult):
rpx = protocolresult.get_replica_transition_statistics()

assert isinstance(rpx, list)
assert len(rpx) == 3
rpx1 = rpx[0]
assert 'eigenvalues' in rpx1
assert 'matrix' in rpx1
assert rpx1['eigenvalues'].shape == (11,)
assert rpx1['matrix'].shape == (11, 11)

def test_get_replica_states(self, protocolresult):
rep = protocolresult.get_replica_states()

assert isinstance(rep, list)
assert len(rep) == 3
assert rep[0].shape == (6, 11)

def test_equilibration_iterations(self, protocolresult):
eq = protocolresult.equilibration_iterations()

assert isinstance(eq, list)
assert len(eq) == 3
assert all(isinstance(v, float) for v in eq)

def test_production_iterations(self, protocolresult):
prod = protocolresult.production_iterations()

assert isinstance(prod, list)
assert len(prod) == 3
assert all(isinstance(v, float) for v in prod)
6 changes: 4 additions & 2 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ def _get_ddgs(legs):
if not ((DG1_mag is None) or (DG2_mag is None)):
DDGhyd = (DG1_mag - DG2_mag).m
hyd_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))
else: # -no-cov-
raise RuntimeError(f"Unknown DDG type for {vals}")
else:
raise RuntimeError("Unable to determine type of RFE calculation "
f"for edges with labels {list(vals)} for "
f"ligands {ligpair}")

DDGs.append((*ligpair, DDGbind, bind_unc, DDGhyd, hyd_unc))

Expand Down
13 changes: 13 additions & 0 deletions openfecli/tests/commands/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import glob
from importlib import resources
import tarfile
import pathlib
import pytest

from openfecli.commands.gather import (
Expand Down Expand Up @@ -106,3 +107,15 @@ def test_gather(results_dir, report):

assert set(expected.split(b'\n')) == actual_lines


def test_missing_leg_error(results_dir):
file_to_remove = "easy_rbfe_lig_ejm_31_complex_lig_ejm_42_complex.json"
(pathlib.Path("results") / file_to_remove).unlink()

runner = CliRunner()
result = runner.invoke(gather, ['results'] + ['-o', '-'])
assert result.exit_code == 1
assert isinstance(result.exception, RuntimeError)
assert "labels ['solvent']" in str(result.exception)
assert "'lig_ejm_31'" in str(result.exception)
assert "'lig_ejm_42'" in str(result.exception)

0 comments on commit 0e8b9f0

Please sign in to comment.