This repository has been archived by the owner on Apr 11, 2024. It is now read-only.

Commit

tmp outputs dir at tests
BobaZooba committed Oct 17, 2023
1 parent 77b8cfa commit b152c19
Showing 4 changed files with 13 additions and 4 deletions.
10 changes: 8 additions & 2 deletions tests/conftest.py
@@ -110,9 +110,9 @@ def llama_lm_collator(llama_tokenizer: PreTrainedTokenizer) -> LMCollator:


 @pytest.fixture(scope="session")
-def training_arguments() -> TrainingArguments:
+def training_arguments(path_to_outputs: str) -> TrainingArguments:
     arguments = TrainingArguments(
-        output_dir="./outputs/",
+        output_dir=path_to_outputs,
         per_device_train_batch_size=2,
         gradient_accumulation_steps=2,
         warmup_steps=50,
@@ -191,3 +191,9 @@ def path_to_fused_model_local_path(tmp_path_factory: TempPathFactory) -> str:
 def path_to_download_result(tmp_path_factory: TempPathFactory) -> str:
     path = tmp_path_factory.mktemp("tmp") / "data.jsonl"
     return os.path.abspath(path)
+
+
+@pytest.fixture(scope="session")
+def path_to_outputs(tmp_path_factory: TempPathFactory) -> str:
+    path = tmp_path_factory.mktemp("tmp") / "outputs/"
+    return os.path.abspath(path)
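The substantive change: `training_arguments` previously wrote checkpoints to a hard-coded `./outputs/` in the working tree; it now receives a session-scoped temporary directory. `tmp_path_factory.mktemp("tmp")` returns a fresh, unique directory under pytest's per-session temp root, so test runs no longer leave an `outputs/` folder behind in the repository. A minimal self-contained sketch of the pattern; the consuming test below is hypothetical and not part of this commit:

import os

import pytest
from pytest import TempPathFactory  # public export since pytest 7.0


@pytest.fixture(scope="session")
def path_to_outputs(tmp_path_factory: TempPathFactory) -> str:
    # mktemp("tmp") yields a unique directory (tmp0, tmp1, ...) under
    # the session's base temporary path.
    path = tmp_path_factory.mktemp("tmp") / "outputs"
    return os.path.abspath(path)


def test_outputs_dir_is_isolated(path_to_outputs: str) -> None:
    # Hypothetical consumer: the path is absolute and lives under the
    # pytest temp root rather than the repository's working tree.
    assert os.path.isabs(path_to_outputs)

Because the fixture is session-scoped, every test that requests `path_to_outputs`, directly or through `training_arguments`, shares one directory for the entire run.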
3 changes: 2 additions & 1 deletion tests/unit/experiments/test_base.py
@@ -19,7 +19,7 @@ def test_base_experiment_init(monkeypatch: MonkeyPatch, path_to_train_dummy_data
         Experiment(config=config)


-def test_base_experiment_train(monkeypatch: MonkeyPatch, path_to_train_prepared_dummy_data: str):
+def test_base_experiment_train(monkeypatch: MonkeyPatch, path_to_train_prepared_dummy_data: str, path_to_outputs: str):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     config = HuggingFaceConfig(
         push_to_hub=False,
@@ -29,6 +29,7 @@ def test_base_experiment_train(monkeypatch: MonkeyPatch, path_to_train_prepared_
         save_total_limit=0,
         max_steps=2,
         tokenizer_name_or_path=LLAMA_TOKENIZER_DIR,
+        output_dir=path_to_outputs,
     )

     with patch_from_pretrained_auto_causal_lm(monkeypatch=monkeypatch):
3 changes: 2 additions & 1 deletion tests/unit/run/test_train.py
@@ -6,7 +6,7 @@
 from tests.helpers.patches import patch_from_pretrained_auto_causal_lm, patch_trainer_train


-def test_train(monkeypatch: MonkeyPatch, path_to_train_prepared_dummy_data: str):
+def test_train(monkeypatch: MonkeyPatch, path_to_train_prepared_dummy_data: str, path_to_outputs: str):
     config = HuggingFaceConfig(
         push_to_hub=False,
         deepspeed_stage=0,
@@ -15,6 +15,7 @@ def test_train(monkeypatch: MonkeyPatch, path_to_train_prepared_dummy_data: str)
         save_total_limit=0,
         max_steps=2,
         tokenizer_name_or_path=LLAMA_TOKENIZER_DIR,
+        output_dir=path_to_outputs,
     )
     with patch_from_pretrained_auto_causal_lm(monkeypatch=monkeypatch):
         with patch_trainer_train(monkeypatch=monkeypatch):
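Both test files above wrap the run in `patch_from_pretrained_auto_causal_lm` and `patch_trainer_train` from tests/helpers/patches. Their implementations are not part of this diff; the following is only a hypothetical sketch, assuming the usual monkeypatch-based shape of such helpers:

from contextlib import contextmanager
from typing import Iterator

from pytest import MonkeyPatch
from transformers import Trainer


@contextmanager
def patch_trainer_train(monkeypatch: MonkeyPatch) -> Iterator[None]:
    # Assumed behavior: swap the expensive real training loop for a no-op
    # so the unit test exercises wiring only, not actual optimization.
    monkeypatch.setattr(Trainer, "train", lambda self, *args, **kwargs: None)
    yield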
1 change: 1 addition & 0 deletions tests/unit/trainers/test_registry.py
@@ -15,6 +15,7 @@ def test_get_lm_trainer(
     training_arguments: TrainingArguments,
     llama_lm_collator: LMCollator,
     soda_dataset: SodaDataset,
+    path_to_outputs: str,
 ):
     trainer_cls = trainers_registry.get(key=enums.Trainers.lm)
     trainer = trainer_cls(
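`trainers_registry.get(key=enums.Trainers.lm)` returns a trainer class that the test instantiates with the session fixtures, including the `TrainingArguments` built from `path_to_outputs`. The registry itself is not shown in this diff; a minimal illustrative sketch of the lookup pattern, with hypothetical names, might look like:

from typing import Dict, Type

from transformers import Trainer


class TrainersRegistry:
    # Hypothetical minimal registry mapping string keys to trainer classes.

    def __init__(self) -> None:
        self._trainers: Dict[str, Type[Trainer]] = {}

    def add(self, key: str, value: Type[Trainer]) -> None:
        self._trainers[key] = value

    def get(self, key: str) -> Type[Trainer]:
        # The test then instantiates the returned class with model,
        # arguments, collator, and dataset fixtures.
        return self._trainers[key]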
