Commit

added thorough tests for checking the actual output windows when storing window stats
Oufattole committed Aug 28, 2024
1 parent 353d8a5 commit 41f8bde
Showing 1 changed file with 237 additions and 67 deletions.
304 changes: 237 additions & 67 deletions tests/test_meds.py
@@ -368,6 +368,175 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
"""
)

WANT_EMPTY_WINDOW_SCHEMA = {"patient_id": pl.Int64}
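# Shards with no qualifying windows keep only patient_id (above); non-empty
# shards add the prediction time, label, trigger time, and one summary struct
# per window boundary (below).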
WANT_NON_EMPTY_WINDOW_SCHEMA = {
"patient_id": pl.UInt32,
"prediction_time": pl.Datetime,
"boolean_value": pl.Int64,
"trigger": pl.Datetime,
"input.end_summary": pl.Struct(
[
pl.Field("window_name", pl.Utf8),
pl.Field("timestamp_at_start", pl.Datetime),
pl.Field("timestamp_at_end", pl.Datetime),
pl.Field("admission", pl.Int64),
pl.Field("discharge", pl.Int64),
pl.Field("death", pl.Int64),
pl.Field("discharge_or_death", pl.Int64),
pl.Field("_ANY_EVENT", pl.Int64),
]
),
"input.start_summary": pl.Struct(
[
pl.Field("window_name", pl.Utf8),
pl.Field("timestamp_at_start", pl.Datetime),
pl.Field("timestamp_at_end", pl.Datetime),
pl.Field("admission", pl.Int64),
pl.Field("discharge", pl.Int64),
pl.Field("death", pl.Int64),
pl.Field("discharge_or_death", pl.Int64),
pl.Field("_ANY_EVENT", pl.Int64),
]
),
"gap.end_summary": pl.Struct(
[
pl.Field("window_name", pl.Utf8),
pl.Field("timestamp_at_start", pl.Datetime),
pl.Field("timestamp_at_end", pl.Datetime),
pl.Field("admission", pl.Int64),
pl.Field("discharge", pl.Int64),
pl.Field("death", pl.Int64),
pl.Field("discharge_or_death", pl.Int64),
pl.Field("_ANY_EVENT", pl.Int64),
]
),
"target.end_summary": pl.Struct(
[
pl.Field("window_name", pl.Utf8),
pl.Field("timestamp_at_start", pl.Datetime),
pl.Field("timestamp_at_end", pl.Datetime),
pl.Field("admission", pl.Int64),
pl.Field("discharge", pl.Int64),
pl.Field("death", pl.Int64),
pl.Field("discharge_or_death", pl.Int64),
pl.Field("_ANY_EVENT", pl.Int64),
]
),
}
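
# Expected window-stats rows for the single non-empty train shard, written as
# JSON so the expected values are easy to read.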

WANT_TRAIN_WINDOW_DATA = """
[
{
"patient_id": 4,
"prediction_time": "1991-01-28 23:32:00",
"boolean_value": 0,
"trigger": "1991-01-27 23:32:00",
"input.end_summary": {
"window_name": "input.end",
"timestamp_at_start": "1991-01-27 23:32:00",
"timestamp_at_end": "1991-01-28 23:32:00",
"admission": 0,
"discharge": 0,
"death": 0,
"discharge_or_death": 0,
"_ANY_EVENT": 4
},
"input.start_summary": {
"window_name": "input.start",
"timestamp_at_start": "1989-12-01 12:03:00",
"timestamp_at_end": "1991-01-28 23:32:00",
"admission": 2,
"discharge": 1,
"death": 0,
"discharge_or_death": 1,
"_ANY_EVENT": 16
},
"gap.end_summary": {
"window_name": "gap.end",
"timestamp_at_start": "1991-01-27 23:32:00",
"timestamp_at_end": "1991-01-29 23:32:00",
"admission": 0,
"discharge": 0,
"death": 0,
"discharge_or_death": 0,
"_ANY_EVENT": 5
},
"target.end_summary": {
"window_name": "target.end",
"timestamp_at_start": "1991-01-29 23:32:00",
"timestamp_at_end": "1991-01-31 02:15:00",
"admission": 0,
"discharge": 1,
"death": 0,
"discharge_or_death": 1,
"_ANY_EVENT": 7
}
}
]
"""

WANT_HELD_OUT_WINDOW_DATA = """
[
{
"patient_id": 1,
"prediction_time": "1991-01-28 23:32:00",
"boolean_value": 0,
"trigger": "1991-01-27 23:32:00",
"input.end_summary": {
"window_name": "input.end",
"timestamp_at_start": "1991-01-27 23:32:00",
"timestamp_at_end": "1991-01-28 23:32:00",
"admission": 0,
"discharge": 0,
"death": 0,
"discharge_or_death": 0,
"_ANY_EVENT": 4
},
"input.start_summary": {
"window_name": "input.start",
"timestamp_at_start": "1989-12-01 12:03:00",
"timestamp_at_end": "1991-01-28 23:32:00",
"admission": 2,
"discharge": 1,
"death": 0,
"discharge_or_death": 1,
"_ANY_EVENT": 16
},
"gap.end_summary": {
"window_name": "gap.end",
"timestamp_at_start": "1991-01-27 23:32:00",
"timestamp_at_end": "1991-01-29 23:32:00",
"admission": 0,
"discharge": 0,
"death": 0,
"discharge_or_death": 0,
"_ANY_EVENT": 5
},
"target.end_summary": {
"window_name": "target.end",
"timestamp_at_start": "1991-01-29 23:32:00",
"timestamp_at_end": "1991-01-31 02:15:00",
"admission": 0,
"discharge": 1,
"death": 0,
"discharge_or_death": 1,
"_ANY_EVENT": 7
}
}
]
"""


WANT_WINDOW_SHARDS = {
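# One entry per output shard: empty shards expect the minimal patient_id-only
# frame; non-empty shards are parsed from the JSON expectations above.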
"train/0.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
"train/1.parquet": pl.read_json(StringIO(WANT_TRAIN_WINDOW_DATA), schema=WANT_NON_EMPTY_WINDOW_SCHEMA),
"held_out/0/0.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
"empty_shard.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
"held_out.parquet": pl.read_json(
StringIO(WANT_HELD_OUT_WINDOW_DATA), schema=WANT_NON_EMPTY_WINDOW_SCHEMA
),
}


def test_meds():
cli_test(
@@ -380,7 +549,7 @@ def test_meds():

def test_meds_window_storage():
input_files = MEDS_SHARDS
task_configs = {TASK_NAME: TASK_CFG}
task = TASK_NAME
want_outputs_by_task = {TASK_NAME: WANT_SHARDS}
data_standard = "meds"

@@ -394,73 +563,74 @@ def test_meds_window_storage():
sharded = True
command = "aces-cli --multirun"

wrote_configs = write_task_configs(cohort_dir, task_configs)
wrote_configs = write_task_configs(cohort_dir, {TASK_NAME: TASK_CFG})
if len(wrote_configs) == 0:
raise ValueError("No task configs were written.")

for task in task_configs:
want_outputs = {
cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
}
window_dir = Path(cohort_dir / "window_stats")
want_window_output_files = [
window_dir / task / f"{n}.parquet" for n in want_outputs_by_task[task]
]

extraction_config_kwargs = {
"cohort_dir": str(cohort_dir.resolve()),
"cohort_name": task,
"hydra.verbose": True,
"data.standard": data_standard,
"window_stats_dir": str(window_dir.resolve()),
}

if len(wrote_files) > 1:
extraction_config_kwargs["data"] = "sharded"
extraction_config_kwargs["data.root"] = str(data_dir.resolve())
extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"'
else:
extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve())

stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}")

try:
if sharded:
out_dir = cohort_dir / task
all_out_fps = list(out_dir.glob("**/*.parquet"))
all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps)
if len(all_out_fps) == 0 and len(want_outputs) > 0:
all_directory_contents = ", ".join(
str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*")
)

raise AssertionError(
f"No output files found for task '{task}'. Found files: {all_directory_contents}"
)

assert len(all_out_fps) == len(
want_outputs
), f"Expected {len(want_outputs)} outputs, got {len(all_out_fps)}: {all_out_fps_str}"

for want_fp, want_df in want_outputs.items():
out_shard = want_fp.relative_to(cohort_dir)
assert want_fp.is_file(), f"Expected {out_shard} to exist."

got_df = pl.read_parquet(want_fp)
assert_df_equal(
want_df, got_df, f"Data mismatch for shard '{out_shard}':\n{want_df}\n{got_df}"
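# Expected outputs, keyed by resolved file path: label shards and window-stats shards.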
want_outputs = {
cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
}
window_dir = Path(cohort_dir / "window_stats")
want_window_outputs = {
window_dir / task / filename: want_df for filename, want_df in WANT_WINDOW_SHARDS.items()
}

extraction_config_kwargs = {
"cohort_dir": str(cohort_dir.resolve()),
"cohort_name": task,
"hydra.verbose": True,
"data.standard": data_standard,
"window_stats_dir": str(window_dir.resolve()),
}

if len(wrote_files) > 1:
extraction_config_kwargs["data"] = "sharded"
extraction_config_kwargs["data.root"] = str(data_dir.resolve())
extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"'
else:
extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve())

stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}")

try:
if sharded:
out_dir = cohort_dir / task
all_out_fps = list(out_dir.glob("**/*.parquet"))
all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps)
if len(all_out_fps) == 0 and len(want_outputs) > 0:
all_directory_contents = ", ".join(
str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*")
)

raise AssertionError(
f"No output files found for task '{task}'. Found files: {all_directory_contents}"
)
assert window_dir.exists(), f"Expected window stats directory {window_dir} to exist."
out_fps = list(window_dir.glob("**/*.parquet"))
assert len(out_fps) == len(
want_window_output_files
), f"Expected {len(want_window_output_files)} window output files, got {len(out_fps)}"

for want_fp in want_window_output_files:
out_shard = want_fp.relative_to(window_dir)
assert want_fp.is_file(), f"Expected {out_shard} to exist."
got_df = pl.read_parquet(want_fp)
assert "patient_id" in got_df.columns, f"Expected 'patient_id' column in {out_shard}."
except AssertionError as e:
logger.error(f"{stderr}\n{stdout}")
raise AssertionError(f"Error running task '{task}': {e}") from e

assert len(all_out_fps) == len(
want_outputs
), f"Expected {len(want_outputs)} outputs, got {len(all_out_fps)}: {all_out_fps_str}"

for want_fp, want_df in want_outputs.items():
out_shard = want_fp.relative_to(cohort_dir)
assert want_fp.is_file(), f"Expected {out_shard} to exist."

got_df = pl.read_parquet(want_fp)
assert_df_equal(
want_df, got_df, f"Data mismatch for shard '{out_shard}':\n{want_df}\n{got_df}"
)
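# With the label shards verified, compare the stored window stats
# value-for-value instead of only checking that a patient_id column exists.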
assert window_dir.exists(), f"Expected window stats directory {window_dir} to exist."
out_fps = list(window_dir.glob("**/*.parquet"))
assert len(out_fps) == len(
want_window_outputs
), f"Expected {len(want_window_outputs)} window output files, got {len(out_fps)}"

for want_fp, want_df in want_window_outputs.items():
out_shard = want_fp.relative_to(window_dir)
assert want_fp.is_file(), f"Expected {out_shard} to exist."
got_df = pl.read_parquet(want_fp)
assert_df_equal(
want_df, got_df, f"Data mismatch for window shard '{out_shard}':\n{want_df}\n{got_df}"
)
except AssertionError as e:
logger.error(f"{stderr}\n{stdout}")
raise AssertionError(f"Error running task '{task}': {e}") from e
