Skip to content

Commit

Permalink
Caching parsed output files (#273)
Browse files Browse the repository at this point in the history
* caching parsed output files

* switch to absolute paths

* add tests for caching

* rename cache clearing method to avoid clashes

* fix lint
  • Loading branch information
gpetretto authored Nov 6, 2023
1 parent e36e48a commit 1d31ba5
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 17 deletions.
4 changes: 3 additions & 1 deletion custodian/custodian.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from monty.shutil import gzip_dir
from monty.tempfile import ScratchDir

from .utils import get_execution_host_info
from .utils import get_execution_host_info, tracked_lru_cache

__author__ = "Shyue Ping Ong, William Davidson Richards"
__copyright__ = "Copyright 2012, The Materials Project"
Expand Down Expand Up @@ -683,6 +683,8 @@ def _do_check(self, handlers, terminate_func=None):
self.run_log[-1]["corrections"].extend(corrections)
# We do a dump of the run log after each check.
dumpfn(self.run_log, Custodian.LOG_FILE, cls=MontyEncoder, indent=4)
# Clear all the cached values to avoid reusing them in a subsequent check
tracked_lru_cache.tracked_cache_clear()
return len(corrections) > 0


Expand Down
38 changes: 38 additions & 0 deletions custodian/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest

from custodian.utils import tracked_lru_cache


class TrackedLruCacheTest(unittest.TestCase):
    """Exercise caching and global clearing of ``tracked_lru_cache``."""

    def setUp(self):
        # Start from an empty tracking set so caches populated by other
        # tests cannot influence the assertions below.
        tracked_lru_cache.tracked_cache_clear()

    def tearDown(self):
        # Leave no cached values behind for whichever test runs next.
        tracked_lru_cache.tracked_cache_clear()

    def test_cache_and_clear(self):
        """Cached calls are not re-executed until the tracked cache is cleared."""
        # Every real execution of the wrapped function appends its argument.
        executed = []

        @tracked_lru_cache
        def identity(x):
            executed.append(x)
            return x

        # two distinct arguments -> two real executions
        assert identity(1) == 1
        assert len(executed) == 1
        assert identity(2) == 2
        assert len(executed) == 2

        # a repeated argument is served from the cache
        assert identity(1) == 1
        assert len(executed) == 2

        assert len(tracked_lru_cache.cached_functions) == 1

        tracked_lru_cache.tracked_cache_clear()

        assert len(tracked_lru_cache.cached_functions) == 0

        # after clearing, the function body has to run again
        assert identity(1) == 1
        assert len(executed) == 3
45 changes: 45 additions & 0 deletions custodian/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility function and classes."""

import functools
import logging
import os
import tarfile
Expand Down Expand Up @@ -44,3 +45,47 @@ def get_execution_host_info():
except Exception:
pass
return host or "unknown", cluster or "unknown"


class tracked_lru_cache:
    """
    Decorator built on functools.lru_cache that also remembers which
    decorated functions have been called.

    All the remembered caches can be emptied at once through
    ``tracked_cache_clear``. It is used to cache the parsed outputs in
    handlers/validators so that the same file is not parsed several times,
    while still letting Custodian drop the cached values once the checks
    have been performed.
    """

    # lru_cache wrappers that have been called since the last global clear;
    # defined on the class, so it is shared by all the decorated functions.
    cached_functions: set = set()

    def __init__(self, func):
        """
        Args:
            func: function to be decorated
        """
        self.func = functools.lru_cache(func)
        # copy name, docstring, etc. of the decorated function onto this object
        functools.update_wrapper(self, func)

        # re-export the usual lru_cache helpers of the wrapped function
        self.cache_info = self.func.cache_info
        self.cache_clear = self.func.cache_clear

    def __call__(self, *args, **kwargs):
        """
        Call the decorated function through its cache and record it as tracked.
        """
        value = self.func(*args, **kwargs)
        # registered only after the call returned without raising
        self.cached_functions.add(self.func)
        return value

    @classmethod
    def tracked_cache_clear(cls):
        """
        Empty the cache of every tracked function and stop tracking it.
        """
        tracked = cls.cached_functions
        while tracked:
            tracked.pop().cache_clear()
27 changes: 14 additions & 13 deletions custodian/vasp/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from monty.serialization import loadfn
from pymatgen.core.structure import Structure
from pymatgen.io.vasp.inputs import Incar, Kpoints, Poscar, VaspInput
from pymatgen.io.vasp.outputs import Oszicar, Outcar, Vasprun
from pymatgen.io.vasp.outputs import Oszicar
from pymatgen.io.vasp.sets import MPScanRelaxSet
from pymatgen.transformations.standard_transformations import SupercellTransformation

Expand All @@ -31,6 +31,7 @@
from custodian.custodian import ErrorHandler
from custodian.utils import backup
from custodian.vasp.interpreter import VaspModder
from custodian.vasp.io import load_outcar, load_vasprun

__author__ = (
"Shyue Ping Ong, William Davidson Richards, Anubhav Jain, Wei Chen, "
Expand Down Expand Up @@ -214,7 +215,7 @@ def correct(self):
# error count to 1 to skip first fix
if self.error_count["brmix"] == 0:
try:
assert Outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False
assert load_outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False
except Exception:
self.error_count["brmix"] += 1

Expand Down Expand Up @@ -510,7 +511,7 @@ def correct(self):
# resources, seems to be to just increase NCORE slightly. That's what I do here.
nprocs = multiprocessing.cpu_count()
try:
nelect = Outcar("OUTCAR").nelect
nelect = load_outcar(os.path.join(os.getcwd(), "OUTCAR")).nelect
except Exception:
nelect = 1 # dummy value
if nelect < nprocs:
Expand Down Expand Up @@ -706,7 +707,7 @@ def correct(self):

if (
"lrf_comm" in self.errors
and Outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False
and load_outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False
and not vi["INCAR"].get("LPEAD")
):
actions.append({"dict": "INCAR", "action": {"_set": {"LPEAD": True}}})
Expand Down Expand Up @@ -897,7 +898,7 @@ def check(self):
self.max_drift = incar["EDIFFG"] * -1

try:
outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))
except Exception:
# Can't perform check if Outcar not valid
return False
Expand All @@ -917,7 +918,7 @@ def correct(self):
vi = VaspInput.from_directory(".")

incar = vi["INCAR"]
outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))

# Move CONTCAR to POSCAR
actions.append({"file": "CONTCAR", "action": {"_file_copy": {"dest": "POSCAR"}}})
Expand Down Expand Up @@ -988,7 +989,7 @@ def check(self):
return False

try:
v = Vasprun(self.output_vasprun)
v = load_vasprun(os.path.join(os.getcwd(), self.output_vasprun))
if v.converged:
return False
except Exception:
Expand Down Expand Up @@ -1031,7 +1032,7 @@ def __init__(self, output_filename: str = "vasprun.xml"):
def check(self):
"""Check for error."""
try:
v = Vasprun(self.output_filename)
v = load_vasprun(os.path.join(os.getcwd(), self.output_filename))
if not v.converged:
return True
except Exception:
Expand All @@ -1040,7 +1041,7 @@ def check(self):

def correct(self):
"""Perform corrections."""
v = Vasprun(self.output_filename)
v = load_vasprun(os.path.join(os.getcwd(), self.output_filename))
algo = v.incar.get("ALGO", "Normal").lower()
actions = []
if not v.converged_electronic:
Expand Down Expand Up @@ -1139,7 +1140,7 @@ def __init__(self, output_filename: str = "vasprun.xml"):
def check(self):
"""Check for error."""
try:
v = Vasprun(self.output_filename)
v = load_vasprun(os.path.join(os.getcwd(), self.output_filename))
# check whether bandgap is zero, tetrahedron smearing was used
# and relaxation is performed.
if v.eigenvalue_band_properties[0] == 0 and v.incar.get("ISMEAR", 1) < -3 and v.incar.get("NSW", 0) > 1:
Expand Down Expand Up @@ -1186,7 +1187,7 @@ def __init__(self, output_filename: str = "vasprun.xml"):
def check(self):
"""Check for error."""
try:
v = Vasprun(self.output_filename)
v = load_vasprun(os.path.join(os.getcwd(), self.output_filename))
# check whether bandgap is zero and KSPACING is too large
# using 0 as fallback value for KSPACING so that this handler does not trigger if KSPACING is not set
if v.eigenvalue_band_properties[0] == 0 and v.incar.get("KSPACING", 0) > 0.22:
Expand Down Expand Up @@ -1244,7 +1245,7 @@ def check(self):
"""Check for error."""
incar = Incar.from_file("INCAR")
try:
outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))
except Exception:
# Can't perform check if Outcar not valid
return False
Expand Down Expand Up @@ -1601,7 +1602,7 @@ def check(self):
if self.wall_time:
run_time = datetime.datetime.now() - self.start_time
total_secs = run_time.total_seconds()
outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))
if not self.electronic_step_stop:
# Determine max time per ionic step.
outcar.read_pattern({"timings": r"LOOP\+.+real time(.+)"}, postprocess=float)
Expand Down
38 changes: 38 additions & 0 deletions custodian/vasp/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Helper functions for dealing with vasp files.
"""

from pymatgen.io.vasp.outputs import Outcar, Vasprun

from custodian.utils import tracked_lru_cache


@tracked_lru_cache
def load_vasprun(filepath, **vasprun_kwargs):
    """
    Parse a vasprun.xml file into a Vasprun object.

    The result is cached by the decorator, so calling again with the same
    arguments returns the already parsed object instead of parsing the
    file a second time.

    Args:
        filepath: path to the vasprun.xml file.
        **vasprun_kwargs: keyword arguments forwarded to the Vasprun init.

    Returns:
        The Vasprun object
    """
    parsed_vasprun = Vasprun(filepath, **vasprun_kwargs)
    return parsed_vasprun


@tracked_lru_cache
def load_outcar(filepath):
    """
    Load Outcar object from file path.

    The result is cached by the decorator, so calling again with the same
    path returns the already parsed object instead of parsing the file a
    second time.

    Args:
        filepath: path to the OUTCAR file.

    Returns:
        The Outcar object
    """
    return Outcar(filepath)
10 changes: 10 additions & 0 deletions custodian/vasp/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@ def _patch_get_potential_energy(monkeypatch):
Monkeypatch the multiprocessing.cpu_count() function to always return 64
"""
monkeypatch.setattr(multiprocessing, "cpu_count", lambda *args, **kwargs: 64)


@pytest.fixture(autouse=True)
def _clear_tracked_cache():
    """
    Empty all the tracked caches so that values stored by one test
    are not reused by the next one.

    The fixture has no teardown part: the clearing happens when the
    fixture is set up for a test.
    """
    # imported inside the fixture body rather than at module level
    import custodian.utils

    custodian.utils.tracked_lru_cache.tracked_cache_clear()
6 changes: 6 additions & 0 deletions custodian/vasp/tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pymatgen.io.vasp.inputs import Incar, Kpoints, Structure, VaspInput
from pymatgen.util.testing import PymatgenTest

from custodian.utils import tracked_lru_cache
from custodian.vasp.handlers import (
AliasingErrorHandler,
DriftErrorHandler,
Expand Down Expand Up @@ -599,34 +600,39 @@ def test_check_correct_electronic(self):
"actions": [{"action": {"_set": {"ALGO": "Normal"}}, "dict": "INCAR"}],
"errors": ["Unconverged"],
}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_veryfast", "vasprun.xml")
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert dct == {"actions": [{"action": {"_set": {"ALGO": "Fast"}}, "dict": "INCAR"}], "errors": ["Unconverged"]}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_normal", "vasprun.xml")
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_metagga_fast", "vasprun.xml")
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_hybrid_fast", "vasprun.xml")
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_hybrid_all", "vasprun.xml")
handler = UnconvergedErrorHandler()
Expand Down
27 changes: 27 additions & 0 deletions custodian/vasp/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import unittest

from custodian.utils import tracked_lru_cache
from custodian.vasp.io import load_outcar, load_vasprun

test_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "test_files")


class IOTest(unittest.TestCase):
    """Check that the cached loaders return the same object on repeated calls."""

    def test_load_outcar(self):
        """Loading the same OUTCAR path twice yields the identical object."""
        outcar_path = os.path.join(test_dir, "large_sigma", "OUTCAR")

        first = load_outcar(outcar_path)
        assert first is not None

        second = load_outcar(outcar_path)
        # identity, not just equality: the second call must come from the cache
        assert first is second

        # exactly one function is expected in the tracking set at this point
        assert len(tracked_lru_cache.cached_functions) == 1

    def test_load_vasprun(self):
        """Loading the same vasprun.xml path twice yields the identical object."""
        vasprun_path = os.path.join(test_dir, "large_sigma", "vasprun.xml")

        first = load_vasprun(vasprun_path)
        assert first is not None

        second = load_vasprun(vasprun_path)
        # identity, not just equality: the second call must come from the cache
        assert first is second

        # exactly one function is expected in the tracking set at this point
        assert len(tracked_lru_cache.cached_functions) == 1
7 changes: 4 additions & 3 deletions custodian/vasp/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import os
from collections import deque

from pymatgen.io.vasp import Chgcar, Incar, Outcar, Vasprun
from pymatgen.io.vasp import Chgcar, Incar

from custodian.custodian import Validator
from custodian.vasp.io import load_outcar, load_vasprun


class VasprunXMLValidator(Validator):
Expand All @@ -27,7 +28,7 @@ def __init__(self, output_file="vasp.out", stderr_file="std_err.txt"):
def check(self):
"""Check for error."""
try:
Vasprun("vasprun.xml")
load_vasprun(os.path.join(os.getcwd(), "vasprun.xml"))
except Exception:
exception_context = {}

Expand Down Expand Up @@ -88,7 +89,7 @@ def check(self):
if not is_npt:
return False

outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))
patterns = {"MDALGO": r"MDALGO\s+=\s+([\d]+)"}
outcar.read_pattern(patterns=patterns)
if outcar.data["MDALGO"] == [["3"]]:
Expand Down

0 comments on commit 1d31ba5

Please sign in to comment.