Skip to content

Commit

Permalink
Add check for staged git changes before running checkout in `payu che…
Browse files Browse the repository at this point in the history
…ckout`

- Add metadata config for model name (if it is different from model driver name)
  • Loading branch information
Jo Basevi committed Dec 20, 2023
1 parent 228e427 commit 77940de
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
8 changes: 8 additions & 0 deletions payu/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def checkout_branch(self,
new_branch: bool = False,
start_point: Optional[str] = None) -> None:
"""Checkout branch and create branch if specified"""
# First check for staged changes
if self.repo.is_dirty(index=True, working_tree=False):
raise PayuBranchError(
"There are staged git changes. Please stash or commit them "
"before running the checkout command again.\n"
"To see what files are staged, run: git status"
)

# Existing branches
local_branches = self.local_branches_dict().keys()
remote_branches = self.remote_branches_dict()
Expand Down
28 changes: 12 additions & 16 deletions payu/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import warnings
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Union
from typing import Optional, Union

from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
Expand Down Expand Up @@ -67,17 +67,17 @@ def __init__(self,
config_path: Optional[Path] = None,
branch: Optional[str] = None,
control_path: Optional[Path] = None) -> None:
self.lab_archive_path = laboratory_archive_path
self.config = read_config(config_path)
self.metadata_config = self.config.get('metadata', {})

if control_path is None:
control_path = Path(self.config.get("control_path"))
self.control_path = control_path
self.filepath = self.control_path / METADATA_FILENAME
self.lab_archive_path = laboratory_archive_path

# Check for metadata flag - Default True
metadata_config = self.config.get('metadata', {})
self.enabled = metadata_config.get('enable', True)
# Config flag to disable creating metadata files and UUIDs
self.enabled = self.metadata_config.get('enable', True)

if self.enabled:
self.repo = GitRepository(self.control_path, catch_error=True)
Expand Down Expand Up @@ -121,9 +121,8 @@ def setup(self,
if not self.enabled:
# Set experiment name only - either configured or includes branch
self.set_experiment_name(ignore_uuid=True)
return

if self.uuid is not None and (keep_uuid or not is_new_experiment):
elif self.uuid is not None and (keep_uuid or not is_new_experiment):
self.set_experiment_name(keep_uuid=keep_uuid,
is_new_experiment=is_new_experiment)
else:
Expand All @@ -135,7 +134,7 @@ def setup(self,

self.archive_path = self.lab_archive_path / self.experiment_name

def new_experiment_name(self, ignore_uuid: bool = False) -> Path:
def new_experiment_name(self, ignore_uuid: bool = False) -> str:
"""Generate a new experiment name"""
if self.branch is None:
self.branch = self.repo.get_branch_name()
Expand Down Expand Up @@ -291,11 +290,9 @@ def update_file(self,

def get_model_name(self) -> str:
"""Get model name from config file"""
model_name = self.config.get('model')
if model_name == 'access':
# TODO: Is access used for anything other than ACCESS-ESM1-5?
# If so, won't set anything here.
model_name = 'ACCESS-ESM1-5'
# Use model name unless specific model is specified in metadata config
default_model_name = self.config.get('model')
model_name = self.metadata_config.get('model', default_model_name)
return model_name.upper()

def get_parent_experiment(self, prior_restart_path: Path) -> None:
Expand Down Expand Up @@ -326,9 +323,8 @@ def commit_file(self) -> None:

def copy_to_archive(self) -> None:
"""Copy metadata file to archive"""
archive_path = self.lab_archive_path / self.experiment_name
mkdir_p(archive_path)
shutil.copy(self.filepath, archive_path / METADATA_FILENAME)
mkdir_p(self.archive_path)
shutil.copy(self.filepath, self.archive_path / METADATA_FILENAME)
# Note: The existence of an archive is used for determining
# experiment names and whether to generate a new UUID

Expand Down
22 changes: 20 additions & 2 deletions test/test_git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_git_checkout_new_branch_existing():
repo = create_new_repo(repo_path)
existing_branch = repo.active_branch

# Test create branch with existing branch
# Test checkout branch with existing branch
repo = GitRepository(repo_path)
with pytest.raises(PayuBranchError):
repo.checkout_branch(str(existing_branch),
Expand All @@ -153,12 +153,30 @@ def test_git_checkout_non_existent_branch():
repo_path = tmpdir / 'remoteRepo'
create_new_repo(repo_path)

# Test create branch with non-existent branch
# Test checkout branch with non-existent branch
repo = GitRepository(repo_path)
with pytest.raises(PayuBranchError):
repo.checkout_branch("Gibberish")


def test_git_checkout_staged_changes():
# Setup
repo_path = tmpdir / 'remoteRepo'
create_new_repo(repo_path)

repo = GitRepository(repo_path)
file_path = repo_path / 'newTestFile.txt'
file_path.touch()

# Test checkout branch works with untracked files
repo.checkout_branch(new_branch=True, branch_name="NewBranch")

# Test checkout raises error with staged changes
repo.repo.index.add([file_path])
with pytest.raises(PayuBranchError):
repo.checkout_branch(new_branch=True, branch_name="NewBranch2")


def test_git_checkout_existing_branch():
# Setup
remote_repo_path = tmpdir / 'remoteRepo'
Expand Down

0 comments on commit 77940de

Please sign in to comment.