diff --git a/tests/deeprvat/test_train.py b/tests/deeprvat/test_train.py index 7fbcbbfa..ad0e77cc 100644 --- a/tests/deeprvat/test_train.py +++ b/tests/deeprvat/test_train.py @@ -4,7 +4,6 @@ from deeprvat.data import DenseGTDataset import yaml from typing import Dict, Tuple -import pandas as pd import itertools from torch.utils.data import DataLoader from tqdm import tqdm @@ -36,10 +35,12 @@ # 9. Different min_variant_counts script_dir = Path(__file__).resolve().parent +repo_base_dir = script_dir.parent.parent tests_data_dir = script_dir / "test_data" / "training" example_data_dir = script_dir.parent / "example" test_config_file = tests_data_dir / "config.yaml" + with open(tests_data_dir / "phenotypes.txt", "r") as f: phenotypes = f.read().strip().split("\n") @@ -127,7 +128,7 @@ def test_multiphenodataset(multipheno_data, cache_tensors: bool, batch_size: int list(zip(phenotypes, [0, 1, 2])), ) def test_make_dataset(phenotype: str, min_variant_count: int, tmp_path: Path): - # os.chdir(example_data_dir) + os.chdir(repo_base_dir) with open(test_config_file, "r") as f: config = yaml.safe_load(f)