Skip to content

Commit

Permalink
Merge pull request #1 from mccle/patch-1
Browse files Browse the repository at this point in the history
Update track.py
  • Loading branch information
CPBridge authored Jan 11, 2024
2 parents 1f5bd4e + f99e81a commit 3a5ac29
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 3 deletions.
45 changes: 42 additions & 3 deletions src/pycrumbs/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import random
import sys
from types import ModuleType
from typing import Any, Callable, Dict, Optional, Union, Sequence, cast
from typing import Any, Callable, Dict, Optional, Union, Sequence, cast, List
from uuid import uuid4

import git
Expand Down Expand Up @@ -337,6 +337,7 @@ def tracked(
include_package_inventory: bool = True,
create_parents: bool = False,
require_empty_directory: bool = False,
chain_records: bool = False
) -> Callable:
"""Store information about a function call to disc.
Expand Down Expand Up @@ -449,6 +450,10 @@ def tracked(
directory_parameter and subdirectory_name_parameter are not specified
so that the called decorated function is aware of the location of the
output directory. It may be used in other situations for convenience.
chain_records: bool
If True, a pre-existing record file will have a new record appended to
it within the same file. If False, a pre-existing record file will be
overwritten.
Examples
--------
Expand Down Expand Up @@ -839,7 +844,30 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
}

full_record_path = record_dir / record_name_local
write_record(full_record_path, record=record)

if not full_record_path.name.endswith(".json"):
full_record_path = full_record_path.with_name(
full_record_path.stem + ".json"
)

chaining = False
if chain_records:
if full_record_path.exists():
chaining = True
with full_record_path.open("r") as jf:
previous_record = json.load(jf)

if isinstance(previous_record, List):
out_record = previous_record + [record]

else:
out_record = [previous_record, record]
else:
out_record = record
else:
out_record = record

write_record(full_record_path, record=out_record)

# Run the function as normal
result = function(*bound_args.args, **bound_args.kwargs)
Expand All @@ -848,7 +876,18 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
end_time = datetime.datetime.now()
record['timing']['end_time'] = str(end_time)
record['timing']['run_time'] = str(end_time - start_time)
write_record(full_record_path, record=record)

if chaining:
if isinstance(previous_record, List):
out_record = previous_record + [record]

else:
out_record = [previous_record, record]

else:
out_record = record

write_record(full_record_path, record=out_record)

return result

Expand Down
39 changes: 39 additions & 0 deletions tests/test_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
import tempfile
from uuid import uuid4
from typing import Dict, List

import pycrumbs
from pycrumbs import track
Expand Down Expand Up @@ -780,3 +781,41 @@ def some_fun(x):
some_fun(0)
assert subdir.exists()
assert subdir.joinpath('some_fun_record.json').exists()


def test_record_chaining():
"""Test tracked with chain_records is True."""
with tempfile.TemporaryDirectory() as temp:
temp = Path(temp).resolve()

@pycrumbs.tracked(literal_directory=temp, chain_records=True)
def some_fun(x):
pass

saved_record = temp.joinpath('some_fun_record.json')
# If a record does not yet exist, demonstrate regular behavior.
assert not saved_record.exists()
some_fun(0)
assert saved_record.exists()
with saved_record.open('r') as jf:
record_0 = json.load(jf)
assert isinstance(record_0, Dict)

# If record already exists with 1 element, demonstrate conversion to
# list which contains the original record as its first element.
some_fun(1)
with saved_record.open('r') as jf:
record_1 = json.load(jf)
assert isinstance(record_1, List)
assert len(record_1) == 2
assert record_0 == record_1[0]

# If record already exists as a list of 2+ elements, demonstrate
# original elements match previous records and contain a new final element.
some_fun(2)
with saved_record.open('r') as jf:
record_2 = json.load(jf)
assert isinstance(record_2, List)
assert len(record_2) == 3
assert record_0 == record_1[0] == record_2[0]
assert record_1[1] == record_2[1]

0 comments on commit 3a5ac29

Please sign in to comment.