Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Caching parsed output files #273

Merged
merged 6 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion custodian/custodian.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from monty.shutil import gzip_dir
from monty.tempfile import ScratchDir

from .utils import get_execution_host_info
from .utils import get_execution_host_info, tracked_lru_cache

__author__ = "Shyue Ping Ong, William Davidson Richards"
__copyright__ = "Copyright 2012, The Materials Project"
Expand Down Expand Up @@ -683,6 +683,8 @@ def _do_check(self, handlers, terminate_func=None):
self.run_log[-1]["corrections"].extend(corrections)
# We do a dump of the run log after each check.
dumpfn(self.run_log, Custodian.LOG_FILE, cls=MontyEncoder, indent=4)
# Clear all the cached values to avoid reusing them in a subsequent check
tracked_lru_cache.tracked_cache_clear()
return len(corrections) > 0


Expand Down
38 changes: 38 additions & 0 deletions custodian/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest

from custodian.utils import tracked_lru_cache


class TrackedLruCacheTest(unittest.TestCase):
    """Exercise the caching and global clearing of ``tracked_lru_cache``."""

    def setUp(self):
        # Start from an empty registry so that values cached by other tests
        # cannot leak into the assertions below.
        tracked_lru_cache.tracked_cache_clear()

    def tearDown(self):
        # Leave nothing cached behind for whichever test runs next.
        tracked_lru_cache.tracked_cache_clear()

    def test_cache_and_clear(self):
        """Repeated calls hit the cache until the tracked caches are cleared."""
        # Every real execution of the wrapped function appends one entry, so
        # the length of this list is the number of cache misses.
        executions = []

        @tracked_lru_cache
        def identity(value):
            executions.append(value)
            return value

        # Two distinct arguments: two real executions.
        assert identity(1) == 1
        assert len(executions) == 1
        assert identity(2) == 2
        assert len(executions) == 2

        # A repeated argument is served from the cache.
        assert identity(1) == 1
        assert len(executions) == 2

        # The wrapped function has been registered exactly once.
        assert len(tracked_lru_cache.cached_functions) == 1

        tracked_lru_cache.tracked_cache_clear()

        # Clearing empties the registry...
        assert len(tracked_lru_cache.cached_functions) == 0

        # ...and drops the cached values, so the function runs again.
        assert identity(1) == 1
        assert len(executions) == 3
45 changes: 45 additions & 0 deletions custodian/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility function and classes."""

import functools
import logging
import os
import tarfile
Expand Down Expand Up @@ -44,3 +45,47 @@ def get_execution_host_info():
except Exception:
pass
return host or "unknown", cluster or "unknown"


class tracked_lru_cache:
    """
    Decorator built on functools.lru_cache that also remembers which
    decorated functions currently may hold cached values.

    Besides the per-function ``cache_info`` and ``cache_clear`` helpers
    inherited from lru_cache, the class offers ``tracked_cache_clear`` to
    empty the caches of all the functions that have been called so far.

    Intended for caching parsed outputs in handlers/validators so that the
    same file is not parsed several times, while still letting Custodian
    drop every cached value once the checks are done.
    """

    # Registry shared by all the instances: the lru_cache wrappers that have
    # been called since the last tracked_cache_clear().
    cached_functions: set = set()

    def __init__(self, func):
        """
        Args:
            func: function to be decorated
        """
        self.func = functools.lru_cache(func)
        # Make the instance look like the decorated function
        # (name, docstring, module, ...).
        functools.update_wrapper(self, func)

        # Re-export the standard lru_cache helpers on the decorator instance.
        for helper in ("cache_info", "cache_clear"):
            setattr(self, helper, getattr(self.func, helper))

    def __call__(self, *args, **kwargs):
        """
        Call the decorated function through its cache and register it as
        holding cached values.
        """
        value = self.func(*args, **kwargs)
        # Registered only after a successful call: if the function raised,
        # this line is never reached.
        self.cached_functions.add(self.func)
        return value

    @classmethod
    def tracked_cache_clear(cls):
        """
        Clear the cache of all the registered functions and empty the registry.
        """
        registry = cls.cached_functions
        while registry:
            registry.pop().cache_clear()
27 changes: 14 additions & 13 deletions custodian/vasp/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from monty.serialization import loadfn
from pymatgen.core.structure import Structure
from pymatgen.io.vasp.inputs import Incar, Kpoints, Poscar, VaspInput
from pymatgen.io.vasp.outputs import Oszicar, Outcar, Vasprun
from pymatgen.io.vasp.outputs import Oszicar
from pymatgen.io.vasp.sets import MPScanRelaxSet
from pymatgen.transformations.standard_transformations import SupercellTransformation

Expand All @@ -31,6 +31,7 @@
from custodian.custodian import ErrorHandler
from custodian.utils import backup
from custodian.vasp.interpreter import VaspModder
from custodian.vasp.io import load_outcar, load_vasprun

__author__ = (
"Shyue Ping Ong, William Davidson Richards, Anubhav Jain, Wei Chen, "
Expand Down Expand Up @@ -214,7 +215,7 @@ def correct(self):
# error count to 1 to skip first fix
if self.error_count["brmix"] == 0:
try:
assert Outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False
assert load_outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False
except Exception:
self.error_count["brmix"] += 1

Expand Down Expand Up @@ -510,7 +511,7 @@ def correct(self):
# resources, seems to be to just increase NCORE slightly. That's what I do here.
nprocs = multiprocessing.cpu_count()
try:
nelect = Outcar("OUTCAR").nelect
nelect = load_outcar(os.path.join(os.getcwd(), "OUTCAR")).nelect
except Exception:
nelect = 1 # dummy value
if nelect < nprocs:
Expand Down Expand Up @@ -706,7 +707,7 @@ def correct(self):

if (
"lrf_comm" in self.errors
and Outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False
and load_outcar(zpath(os.path.join(os.getcwd(), "OUTCAR"))).is_stopped is False
and not vi["INCAR"].get("LPEAD")
):
actions.append({"dict": "INCAR", "action": {"_set": {"LPEAD": True}}})
Expand Down Expand Up @@ -897,7 +898,7 @@ def check(self):
self.max_drift = incar["EDIFFG"] * -1

try:
outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))
except Exception:
# Can't perform check if Outcar not valid
return False
Expand All @@ -917,7 +918,7 @@ def correct(self):
vi = VaspInput.from_directory(".")

incar = vi["INCAR"]
outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))

# Move CONTCAR to POSCAR
actions.append({"file": "CONTCAR", "action": {"_file_copy": {"dest": "POSCAR"}}})
Expand Down Expand Up @@ -988,7 +989,7 @@ def check(self):
return False

try:
v = Vasprun(self.output_vasprun)
v = load_vasprun(os.path.join(os.getcwd(), self.output_vasprun))
if v.converged:
return False
except Exception:
Expand Down Expand Up @@ -1031,7 +1032,7 @@ def __init__(self, output_filename: str = "vasprun.xml"):
def check(self):
"""Check for error."""
try:
v = Vasprun(self.output_filename)
v = load_vasprun(os.path.join(os.getcwd(), self.output_filename))
if not v.converged:
return True
except Exception:
Expand All @@ -1040,7 +1041,7 @@ def check(self):

def correct(self):
"""Perform corrections."""
v = Vasprun(self.output_filename)
v = load_vasprun(os.path.join(os.getcwd(), self.output_filename))
algo = v.incar.get("ALGO", "Normal").lower()
actions = []
if not v.converged_electronic:
Expand Down Expand Up @@ -1139,7 +1140,7 @@ def __init__(self, output_filename: str = "vasprun.xml"):
def check(self):
"""Check for error."""
try:
v = Vasprun(self.output_filename)
v = load_vasprun(os.path.join(os.getcwd(), self.output_filename))
# check whether bandgap is zero, tetrahedron smearing was used
# and relaxation is performed.
if v.eigenvalue_band_properties[0] == 0 and v.incar.get("ISMEAR", 1) < -3 and v.incar.get("NSW", 0) > 1:
Expand Down Expand Up @@ -1186,7 +1187,7 @@ def __init__(self, output_filename: str = "vasprun.xml"):
def check(self):
"""Check for error."""
try:
v = Vasprun(self.output_filename)
v = load_vasprun(os.path.join(os.getcwd(), self.output_filename))
# check whether bandgap is zero and KSPACING is too large
# using 0 as fallback value for KSPACING so that this handler does not trigger if KSPACING is not set
if v.eigenvalue_band_properties[0] == 0 and v.incar.get("KSPACING", 0) > 0.22:
Expand Down Expand Up @@ -1244,7 +1245,7 @@ def check(self):
"""Check for error."""
incar = Incar.from_file("INCAR")
try:
outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))
except Exception:
# Can't perform check if Outcar not valid
return False
Expand Down Expand Up @@ -1601,7 +1602,7 @@ def check(self):
if self.wall_time:
run_time = datetime.datetime.now() - self.start_time
total_secs = run_time.total_seconds()
outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))
if not self.electronic_step_stop:
# Determine max time per ionic step.
outcar.read_pattern({"timings": r"LOOP\+.+real time(.+)"}, postprocess=float)
Expand Down
38 changes: 38 additions & 0 deletions custodian/vasp/io.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about instead of a new io module, we just define

CachedOutcar = tracked_lru_cache(Outcar)
CachedVasprun = tracked_lru_cache(Vasprun)

directly in custodian/vasp/handlers.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be done, but since these are also used in the validators module, the validators would then have to import them from handlers. I think that having a generic io module could be fine, but I have no problem moving them to handlers.py.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I guess no need to move to handlers.py.

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Helper functions for dealing with vasp files.
"""

from pymatgen.io.vasp.outputs import Outcar, Vasprun

from custodian.utils import tracked_lru_cache


@tracked_lru_cache
def load_vasprun(filepath, **vasprun_kwargs):
    """
    Parse a vasprun.xml file into a Vasprun object.
    The call goes through tracked_lru_cache, so calling it again with the
    same arguments returns the cached object instead of parsing again.

    Args:
        filepath: path to the vasprun.xml file.
        **vasprun_kwargs: keyword arguments forwarded to the Vasprun init.

    Returns:
        The Vasprun object
    """
    vasprun = Vasprun(filepath, **vasprun_kwargs)
    return vasprun


@tracked_lru_cache
def load_outcar(filepath):
    """
    Parse an OUTCAR file into an Outcar object.
    The call goes through tracked_lru_cache, so calling it again with the
    same path returns the cached object instead of parsing again.

    Args:
        filepath: path to the OUTCAR file.

    Returns:
        The Outcar object
    """
    outcar = Outcar(filepath)
    return outcar
10 changes: 10 additions & 0 deletions custodian/vasp/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@ def _patch_get_potential_energy(monkeypatch):
Monkeypatch the multiprocessing.cpu_count() function to always return 64
"""
monkeypatch.setattr(multiprocessing, "cpu_count", lambda *args, **kwargs: 64)


@pytest.fixture(autouse=True)
def _clear_tracked_cache():
    """
    Empty the caches of all the tracked functions before a test runs, so
    that parsed outputs cached by one test are not reused by the next.
    """
    # Imported inside the fixture, as in the original.
    import custodian.utils as custodian_utils

    custodian_utils.tracked_lru_cache.tracked_cache_clear()
Comment on lines +17 to +24
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this fixture currently runs on every test, not just the handler tests, which would be unnecessary work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can move it to custodian/vasp/handlers.py but keep autouse=True to only auto apply it to those tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is also used in the validators, I think the fixture should either be duplicated in the handlers and validators files, or be added explicitly to the tests that need it, removing the autouse.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, probably duplicating it in both files makes it more explicit too. People are more likely to see it and become aware this is happening.

6 changes: 6 additions & 0 deletions custodian/vasp/tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pymatgen.io.vasp.inputs import Incar, Kpoints, Structure, VaspInput
from pymatgen.util.testing import PymatgenTest

from custodian.utils import tracked_lru_cache
from custodian.vasp.handlers import (
AliasingErrorHandler,
DriftErrorHandler,
Expand Down Expand Up @@ -599,34 +600,39 @@ def test_check_correct_electronic(self):
"actions": [{"action": {"_set": {"ALGO": "Normal"}}, "dict": "INCAR"}],
"errors": ["Unconverged"],
}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_veryfast", "vasprun.xml")
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert dct == {"actions": [{"action": {"_set": {"ALGO": "Fast"}}, "dict": "INCAR"}], "errors": ["Unconverged"]}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_normal", "vasprun.xml")
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_metagga_fast", "vasprun.xml")
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_hybrid_fast", "vasprun.xml")
handler = UnconvergedErrorHandler()
assert handler.check()
dct = handler.correct()
assert dct["errors"] == ["Unconverged"]
assert dct == {"actions": [{"action": {"_set": {"ALGO": "All"}}, "dict": "INCAR"}], "errors": ["Unconverged"]}
tracked_lru_cache.tracked_cache_clear()

shutil.copy("vasprun.xml.electronic_hybrid_all", "vasprun.xml")
handler = UnconvergedErrorHandler()
Expand Down
27 changes: 27 additions & 0 deletions custodian/vasp/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import unittest

from custodian.utils import tracked_lru_cache
from custodian.vasp.io import load_outcar, load_vasprun

test_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "test_files")


class IOTest(unittest.TestCase):
    """Check that the cached loaders return and reuse the parsed objects."""

    def setUp(self):
        # The tests below assert on the size of the global registry of cached
        # functions, so they must start from an empty one regardless of which
        # tests ran before or of any external fixture.
        tracked_lru_cache.tracked_cache_clear()

    def tearDown(self):
        # Do not leave parsed objects cached for the following tests.
        tracked_lru_cache.tracked_cache_clear()

    def test_load_outcar(self):
        """Loading the same OUTCAR twice returns the very same object."""
        outcar_path = os.path.join(test_dir, "large_sigma", "OUTCAR")

        outcar = load_outcar(outcar_path)
        assert outcar is not None
        outcar2 = load_outcar(outcar_path)

        # Identity, not equality: the second call must come from the cache.
        assert outcar is outcar2

        # Only load_outcar has been called since the registry was cleared.
        assert len(tracked_lru_cache.cached_functions) == 1

    def test_load_vasprun(self):
        """Loading the same vasprun.xml twice returns the very same object."""
        vasprun_path = os.path.join(test_dir, "large_sigma", "vasprun.xml")

        vr = load_vasprun(vasprun_path)
        assert vr is not None
        vr2 = load_vasprun(vasprun_path)

        # Identity, not equality: the second call must come from the cache.
        assert vr is vr2

        # Only load_vasprun has been called since the registry was cleared.
        assert len(tracked_lru_cache.cached_functions) == 1
7 changes: 4 additions & 3 deletions custodian/vasp/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import os
from collections import deque

from pymatgen.io.vasp import Chgcar, Incar, Outcar, Vasprun
from pymatgen.io.vasp import Chgcar, Incar

from custodian.custodian import Validator
from custodian.vasp.io import load_outcar, load_vasprun


class VasprunXMLValidator(Validator):
Expand All @@ -27,7 +28,7 @@ def __init__(self, output_file="vasp.out", stderr_file="std_err.txt"):
def check(self):
"""Check for error."""
try:
Vasprun("vasprun.xml")
load_vasprun(os.path.join(os.getcwd(), "vasprun.xml"))
except Exception:
exception_context = {}

Expand Down Expand Up @@ -88,7 +89,7 @@ def check(self):
if not is_npt:
return False

outcar = Outcar("OUTCAR")
outcar = load_outcar(os.path.join(os.getcwd(), "OUTCAR"))
patterns = {"MDALGO": r"MDALGO\s+=\s+([\d]+)"}
outcar.read_pattern(patterns=patterns)
if outcar.data["MDALGO"] == [["3"]]:
Expand Down