Commit

made shard specific file names for window stats storage and made tests check file directories
Oufattole committed Aug 27, 2024
1 parent 410468d commit 76a1ca2
Showing 3 changed files with 91 additions and 79 deletions.
5 changes: 3 additions & 2 deletions src/aces/__main__.py
@@ -173,8 +173,9 @@ def main(cfg: DictConfig):
             logger.warning("Output dataframe is empty; adding an empty patient ID column.")
             result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias("patient_id"))
             result = result.head(0)
-        if cfg.window_stats_filepath:
-            result.write_parquet(cfg.window_stats_filepath)
+        if cfg.window_stats_dir:
+            Path(cfg.window_stats_filename).parent.mkdir(exist_ok=True, parents=True)
+            result.write_parquet(cfg.window_stats_filename)
         result = get_and_validate_label_schema(result)
         pq.write_table(result, cfg.output_filepath)
     else:
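The hunk above replaces the single window_stats_filepath write with a directory-based scheme: the write is gated on window_stats_dir, and the shard-specific parent directory of window_stats_filename is created before writing. A minimal sketch of that behavior, not the ACES source itself; the function name and the paths in the usage line are hypothetical:

from pathlib import Path

import polars as pl


def write_window_stats(result: pl.DataFrame, window_stats_dir: str | None, window_stats_filename: str) -> None:
    # Mirrors the guarded write added above: do nothing when window_stats_dir
    # is unset; otherwise create the (possibly shard-specific) parent
    # directory before writing the parquet file.
    if window_stats_dir:
        Path(window_stats_filename).parent.mkdir(exist_ok=True, parents=True)
        result.write_parquet(window_stats_filename)


write_window_stats(
    pl.DataFrame({"patient_id": [1, 2]}),
    "/tmp/window_stats",
    "/tmp/window_stats/mortality/shard_0.parquet",
)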
3 changes: 2 additions & 1 deletion src/aces/configs/aces.yaml
@@ -13,7 +13,8 @@ config_path: ${cohort_dir}/${cohort_name}.yaml
 # sharded data mode.
 output_filepath: ${cohort_dir}/${cohort_name}${data._prefix}.parquet
 # Optional path to store the output file with the raw window data.
-window_stats_filepath: null
+window_stats_dir: null
+window_stats_filename: ${window_stats_dir}/${cohort_name}${data._prefix}.parquet
 
 log_dir: ${cohort_dir}/${cohort_name}/.logs
 
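Only window_stats_dir is meant to be set by the user (null disables window-stats output, per the gate in __main__.py above); window_stats_filename is then derived by OmegaConf interpolation, so in sharded data mode each shard resolves to its own parquet path. A hedged illustration using OmegaConf directly; the cohort name and shard prefix values are hypothetical:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "cohort_dir": "/data/sample_cohort",
        "cohort_name": "mortality",
        "data": {"_prefix": "/train/0"},  # set per shard in sharded data mode
        "window_stats_dir": "${cohort_dir}/window_stats",
        "window_stats_filename": "${window_stats_dir}/${cohort_name}${data._prefix}.parquet",
    }
)

# Accessing the key resolves the nested interpolations.
print(cfg.window_stats_filename)
# /data/sample_cohort/window_stats/mortality/train/0.parquet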
162 changes: 86 additions & 76 deletions tests/test_meds.py
@@ -371,88 +371,98 @@ def test_meds():


 def test_meds_window_storage():
     import tempfile
     from pathlib import Path
-    from tests.utils import write_input_files, write_task_configs, run_command, assert_df_equal
+
+    from tests.utils import (
+        assert_df_equal,
+        run_command,
+        write_input_files,
+        write_task_configs,
+    )
 
     input_files = MEDS_SHARDS
     task_configs = {TASK_NAME: TASK_CFG}
     want_outputs_by_task = {TASK_NAME: WANT_SHARDS}
     data_standard = "meds"
 
     with tempfile.TemporaryDirectory() as root_dir:
         root_dir = Path(root_dir)
         data_dir = root_dir / "sample_data" / "data"
         cohort_dir = root_dir / "sample_cohort"
 
         wrote_files = write_input_files(data_dir, input_files)
-        if len(wrote_files) == 0:
-            raise ValueError("No input files were written.")
-        elif len(wrote_files) > 1:
-            sharded = True
-            command = "aces-cli --multirun"
-        else:
-            sharded = False
-            command = "aces-cli"
+        assert len(wrote_files) > 1, "No input files were written."
+        sharded = True
+        command = "aces-cli --multirun"
 
         wrote_configs = write_task_configs(cohort_dir, task_configs)
         if len(wrote_configs) == 0:
             raise ValueError("No task configs were written.")
 
         for task in task_configs:
-            if sharded:
-                want_outputs = {
-                    cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
-                }
-            else:
-                want_outputs = {cohort_dir / f"{task}.parquet": want_outputs_by_task[task]}
-            window_stats_filepath = cohort_dir / f"{task}_window_stats.parquet"
-            want_window_fp = window_stats_filepath
+            want_outputs = {
+                cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
+            }
+            window_dir = Path(cohort_dir / "window_stats")
+            want_window_output_files = [
+                window_dir / task / f"{n}.parquet" for n in want_outputs_by_task[task].keys()
+            ]
 
             extraction_config_kwargs = {
                 "cohort_dir": str(cohort_dir.resolve()),
                 "cohort_name": task,
                 "hydra.verbose": True,
                 "data.standard": data_standard,
-                "window_stats_filepath": str(window_stats_filepath.resolve()),
+                "window_stats_dir": str(window_dir.resolve()),
             }
 
             if len(wrote_files) > 1:
                 extraction_config_kwargs["data"] = "sharded"
                 extraction_config_kwargs["data.root"] = str(data_dir.resolve())
                 extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"'
             else:
                 extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve())
 
             stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}")
 
             try:
                 if sharded:
                     out_dir = cohort_dir / task
                     all_out_fps = list(out_dir.glob("**/*.parquet"))
                     all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps)
                     if len(all_out_fps) == 0 and len(want_outputs) > 0:
                         all_directory_contents = ", ".join(
                             str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*")
                         )
 
                         raise AssertionError(
                             f"No output files found for task '{task}'. Found files: {all_directory_contents}"
                         )
 
                     assert len(all_out_fps) == len(
                         want_outputs
                     ), f"Expected {len(want_outputs)} outputs, got {len(all_out_fps)}: {all_out_fps_str}"
 
                 for want_fp, want_df in want_outputs.items():
                     out_shard = want_fp.relative_to(cohort_dir)
                     assert want_fp.is_file(), f"Expected {out_shard} to exist."
 
                     got_df = pl.read_parquet(want_fp)
                     assert_df_equal(
                         want_df, got_df, f"Data mismatch for shard '{out_shard}':\n{want_df}\n{got_df}"
                     )
-                assert want_window_fp.is_file(), f"Expected window stats file {want_window_fp} to exist."
+
+                assert window_dir.exists(), f"Expected window stats directory {window_dir} to exist."
+                out_fps = list(window_dir.glob("**/*.parquet"))
+                assert len(out_fps) == len(
+                    want_window_output_files
+                ), f"Expected {len(want_window_output_files)} window output files, got {len(out_fps)}"
+
+                for want_fp in want_window_output_files:
+                    out_shard = want_fp.relative_to(window_dir)
+                    assert want_fp.is_file(), f"Expected {out_shard} to exist."
+                    got_df = pl.read_parquet(want_fp)
+                    assert "patient_id" in got_df.columns, f"Expected 'patient_id' column in {out_shard}."
             except AssertionError as e:
                 logger.error(f"{stderr}\n{stdout}")
                 raise AssertionError(f"Error running task '{task}': {e}") from e
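The rewritten test checks directory contents rather than a single flat file: label shards land under the cohort directory and window stats under a parallel window_stats/ tree, one parquet file per shard, each carrying a patient_id column. A sketch of the layout those assertions imply, assuming a hypothetical task name and shard names:

from pathlib import Path

cohort_dir = Path("/tmp/sample_cohort")
task = "mortality"
shard_names = ["train/0", "held_out/0"]

# Task label shards, one parquet file per shard:
label_files = [cohort_dir / task / f"{n}.parquet" for n in shard_names]
# Window stats shards, mirrored under a dedicated window_stats/ directory:
window_stats_files = [cohort_dir / "window_stats" / task / f"{n}.parquet" for n in shard_names]

for fp in label_files + window_stats_files:
    print(fp)
# /tmp/sample_cohort/mortality/train/0.parquet
# /tmp/sample_cohort/mortality/held_out/0.parquet
# /tmp/sample_cohort/window_stats/mortality/train/0.parquet
# /tmp/sample_cohort/window_stats/mortality/held_out/0.parquet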
