Skip to content

Commit

Permalink
add tests for caching
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetretto committed Jul 21, 2023
1 parent 856e650 commit 58aca7f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
38 changes: 38 additions & 0 deletions custodian/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest

from custodian.utils import tracked_lru_cache


class TrackedLruCacheTest(unittest.TestCase):
def setUp(self):
# clear cache before and after each test to avoid
# unexpected caching from other tests
tracked_lru_cache.cache_clear()

def test_cache_and_clear(self):
n_calls = 0

@tracked_lru_cache
def some_func(x):
nonlocal n_calls
n_calls += 1
return x

assert some_func(1) == 1
assert n_calls == 1
assert some_func(2) == 2
assert n_calls == 2
assert some_func(1) == 1
assert n_calls == 2

assert len(tracked_lru_cache.cached_functions) == 1

tracked_lru_cache.cache_clear()

assert len(tracked_lru_cache.cached_functions) == 0

assert some_func(1) == 1
assert n_calls == 3

def tearDown(self):
tracked_lru_cache.cache_clear()
4 changes: 4 additions & 0 deletions custodian/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def __init__(self, func):
self.func = functools.lru_cache(func)
functools.update_wrapper(self, func)

# expose standard lru_cache functions
self.cache_info = self.func.cache_info
self.cache_clear = self.func.cache_clear

def __call__(self, *args, **kwargs):
"""
Call the decorated function
Expand Down
27 changes: 27 additions & 0 deletions custodian/vasp/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import unittest

from custodian.utils import tracked_lru_cache
from custodian.vasp.io import load_outcar, load_vasprun

test_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "test_files")


class IOTest(unittest.TestCase):
def test_load_outcar(self):
outcar = load_outcar(os.path.join(test_dir, "large_sigma", "OUTCAR"))
assert outcar is not None
outcar2 = load_outcar(os.path.join(test_dir, "large_sigma", "OUTCAR"))

assert outcar is outcar2

assert len(tracked_lru_cache.cached_functions) == 1

def test_load_vasprun(self):
vr = load_vasprun(os.path.join(test_dir, "large_sigma", "vasprun.xml"))
assert vr is not None
vr2 = load_vasprun(os.path.join(test_dir, "large_sigma", "vasprun.xml"))

assert vr is vr2

assert len(tracked_lru_cache.cached_functions) == 1

0 comments on commit 58aca7f

Please sign in to comment.