Skip to content

Commit

Permalink
Merge branch 'main' into 62_improve_testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 24, 2024
2 parents 5c2283b + 3ec94ea commit 4cba9b6
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 61 deletions.
17 changes: 17 additions & 0 deletions src/aces/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ def check_constraints(
│ 2 ┆ 1989-12-01 13:14:00 ┆ 3 ┆ 10 ┆ 1 │
│ 2 ┆ 1989-12-03 15:17:00 ┆ 3 ┆ 2 ┆ 1 │
└────────────┴─────────────────────┴──────┴──────┴──────┘
>>> predicates_df = pl.DataFrame({
... "subject_id": [1, 1, 3],
... "timestamp": [datetime(1980, 12, 28), datetime(2010, 6, 20), datetime(2010, 5, 11)],
... "A": [False, False, False],
... "_ANY_EVENT": [True, True, True],
... })
>>> check_constraints({"_ANY_EVENT": (1, None)}, predicates_df)
shape: (3, 4)
┌────────────┬─────────────────────┬───────┬────────────┐
│ subject_id ┆ timestamp ┆ A ┆ _ANY_EVENT │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ bool ┆ bool │
╞════════════╪═════════════════════╪═══════╪════════════╡
│ 1 ┆ 1980-12-28 00:00:00 ┆ false ┆ true │
│ 1 ┆ 2010-06-20 00:00:00 ┆ false ┆ true │
│ 3 ┆ 2010-05-11 00:00:00 ┆ false ┆ true │
└────────────┴─────────────────────┴───────┴────────────┘
"""

should_drop = pl.lit(False)
Expand Down
13 changes: 6 additions & 7 deletions src/aces/extract_subtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,11 @@ def extract_subtree(
... "_ANY_EVENT": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
... })
>>> subtree_anchor_realizations = (
... predicates_df.filter(pl.col("is_admission") > 0)
... predicates_df
... .filter(pl.col("is_admission") > 0)
... .rename({"timestamp": "subtree_anchor_timestamp"})
... ).select("subject_id", "subtree_anchor_timestamp")
... .select("subject_id", "subtree_anchor_timestamp")
... )
>>> print(subtree_anchor_realizations)
shape: (5, 2)
┌────────────┬──────────────────────────┐
Expand All @@ -157,10 +159,7 @@ def extract_subtree(
│ 3 ┆ 1999-12-06 15:17:00 │
└────────────┴──────────────────────────┘
>>> out = extract_subtree(root, subtree_anchor_realizations, predicates_df, timedelta(0))
>>> out.select(
... "subject_id",
... "subtree_anchor_timestamp",
... )
>>> out.select("subject_id", "subtree_anchor_timestamp")
shape: (1, 2)
┌────────────┬──────────────────────────┐
│ subject_id ┆ subtree_anchor_timestamp │
Expand Down Expand Up @@ -296,7 +295,7 @@ def extract_subtree(
child_anchor_realizations = window_summary_df.select(
"subject_id",
pl.col("child_anchor_timestamp").alias("subtree_anchor_timestamp"),
)
).unique()

# Step 5: Recurse
recursive_result = extract_subtree(
Expand Down
53 changes: 0 additions & 53 deletions src/aces/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,56 +223,6 @@ def direct_load_plain_predicates(
)


def unnest_meds(data: pl.DataFrame) -> pl.DataFrame:
"""Unnest MEDS data.
Single-nested MEDS has a row per patient, with all the observations aggregated into a list of "event"
structs. The events column needs to be exploded and unnested.
Args:
data: The Polars DataFrame containing the single-nested MEDS data.
Returns:
The Polars DataFrame with the events column exploded and unnested.
Example:
>>> data = pl.DataFrame({
... "subject_id": [1, 2, 3],
... "events": [
... [{"timestamp": None, "code": "EYE_COLOR//HAZEL", "numerical_value": None},
... {"timestamp": None, "code": "HEIGHT", "numerical_value": 160},
... {"timestamp": "3/9/1978 00:00", "code": "DOB", "numerical_value": None}],
... [{"timestamp": None, "code": "EYE_COLOR//BROWN", "numerical_value": None},
... {"timestamp": None, "code": "HEIGHT", "numerical_value": 175},
... {"timestamp": "12/28/1980 00:00", "code": "DOB", "numerical_value": None}],
... [{"timestamp": None, "code": "EYE_COLOR//BROWN", "numerical_value": None},
... {"timestamp": None, "code": "HEIGHT", "numerical_value": 166},
... {"timestamp": "12/19/1988 00:00", "code": "DOB", "numerical_value": None}],
... ],
... })
>>> unnest_meds(data)
shape: (9, 4)
┌────────────┬──────────────────┬──────────────────┬─────────────────┐
│ subject_id ┆ timestamp ┆ code ┆ numerical_value │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ str ┆ i64 │
╞════════════╪══════════════════╪══════════════════╪═════════════════╡
│ 1 ┆ null ┆ EYE_COLOR//HAZEL ┆ null │
│ 1 ┆ null ┆ HEIGHT ┆ 160 │
│ 1 ┆ 3/9/1978 00:00 ┆ DOB ┆ null │
│ 2 ┆ null ┆ EYE_COLOR//BROWN ┆ null │
│ 2 ┆ null ┆ HEIGHT ┆ 175 │
│ 2 ┆ 12/28/1980 00:00 ┆ DOB ┆ null │
│ 3 ┆ null ┆ EYE_COLOR//BROWN ┆ null │
│ 3 ┆ null ┆ HEIGHT ┆ 166 │
│ 3 ┆ 12/19/1988 00:00 ┆ DOB ┆ null │
└────────────┴──────────────────┴──────────────────┴─────────────────┘
"""
logger.info("Found single-nested MEDS data, unnesting...")

return data.explode("events").unnest("events")


def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl.DataFrame:
"""Generate plain predicate columns from a MEDS dataset.
Expand Down Expand Up @@ -318,9 +268,6 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
logger.info("Loading MEDS data...")
data = pl.read_parquet(data_path).rename({"patient_id": "subject_id", "time": "timestamp"})

if data.columns == ["subject_id", "events"]:
data = unnest_meds(data)

# generate plain predicate columns
logger.info("Generating plain predicate columns...")
for name, plain_predicate in predicates.items():
Expand Down
29 changes: 28 additions & 1 deletion src/aces/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame
ValueError: If the (subject_id, timestamp) columns are not unique.
Examples:
These examples just show the error cases for now; see the `tests` directory for full examples.
These examples are limited for now; see the `tests` directory for full examples.
>>> cfg = None # This is obviously invalid, but we're just testing the error case.
>>> predicates_df = {"subject_id": [1, 1], "timestamp": [1, 1]}
>>> query(cfg, predicates_df)
Expand All @@ -42,6 +42,33 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame
Traceback (most recent call last):
...
ValueError: The (subject_id, timestamp) columns must be unique.
>>> from datetime import datetime
>>> from .config import PlainPredicateConfig, WindowConfig, EventConfig
>>> cfg = TaskExtractorConfig(
... predicates={"A": PlainPredicateConfig("A")},
... trigger=EventConfig("_ANY_EVENT"),
... windows={
... "pre": WindowConfig(None, "trigger", True, False),
... "post": WindowConfig("pre.end", None, True, True),
... },
... )
>>> predicates_df = pl.DataFrame({
... "subject_id": [1, 1, 3],
... "timestamp": [datetime(1980, 12, 28), datetime(2010, 6, 20), datetime(2010, 5, 11)],
... "A": [False, False, False],
... "_ANY_EVENT": [True, True, True],
... })
>>> query(cfg, predicates_df).select("subject_id", "trigger")
shape: (3, 2)
┌────────────┬─────────────────────┐
│ subject_id ┆ trigger │
│ --- ┆ --- │
│ i64 ┆ datetime[μs] │
╞════════════╪═════════════════════╡
│ 1 ┆ 1980-12-28 00:00:00 │
│ 1 ┆ 2010-06-20 00:00:00 │
│ 3 ┆ 2010-05-11 00:00:00 │
└────────────┴─────────────────────┘
"""
if not isinstance(predicates_df, pl.DataFrame):
raise TypeError(f"Predicates dataframe type must be a polars.DataFrame. Got: {type(predicates_df)}.")
Expand Down

0 comments on commit 4cba9b6

Please sign in to comment.