Skip to content

Commit

Permalink
Explode and unnest data, closes #76 (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
justin13601 authored Jul 26, 2024
1 parent 8dd9c8b commit ff5efab
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/aces/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def main(cfg: DictConfig):
result = query.query(task_cfg, predicates_df)

if cfg.data.standard.lower() == "meds":
result = result.rename(columns={"subject_id": "patient_id"})
result = result.rename({"subject_id": "patient_id"})

# save results to parquet
os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True)
Expand Down
53 changes: 53 additions & 0 deletions src/aces/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,56 @@ def direct_load_plain_predicates(
)


def unnest_meds(data: pl.DataFrame) -> pl.DataFrame:
"""Unnest MEDS data.
Single-nested MEDS has a row per patient, with all the observations aggregated into a list of "event"
structs. The events column needs to be exploded and unnested.
Args:
data: The Polars DataFrame containing the single-nested MEDS data.
Returns:
The Polars DataFrame with the events column exploded and unnested.
Example:
>>> data = pl.DataFrame({
... "subject_id": [1, 2, 3],
... "events": [
... [{"timestamp": None, "code": "EYE_COLOR//HAZEL", "numerical_value": None},
... {"timestamp": None, "code": "HEIGHT", "numerical_value": 160},
... {"timestamp": "3/9/1978 00:00", "code": "DOB", "numerical_value": None}],
... [{"timestamp": None, "code": "EYE_COLOR//BROWN", "numerical_value": None},
... {"timestamp": None, "code": "HEIGHT", "numerical_value": 175},
... {"timestamp": "12/28/1980 00:00", "code": "DOB", "numerical_value": None}],
... [{"timestamp": None, "code": "EYE_COLOR//BROWN", "numerical_value": None},
... {"timestamp": None, "code": "HEIGHT", "numerical_value": 166},
... {"timestamp": "12/19/1988 00:00", "code": "DOB", "numerical_value": None}],
... ],
... })
>>> unnest_meds(data)
shape: (9, 4)
┌────────────┬──────────────────┬──────────────────┬─────────────────┐
│ subject_id ┆ timestamp ┆ code ┆ numerical_value │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ str ┆ i64 │
╞════════════╪══════════════════╪══════════════════╪═════════════════╡
│ 1 ┆ null ┆ EYE_COLOR//HAZEL ┆ null │
│ 1 ┆ null ┆ HEIGHT ┆ 160 │
│ 1 ┆ 3/9/1978 00:00 ┆ DOB ┆ null │
│ 2 ┆ null ┆ EYE_COLOR//BROWN ┆ null │
│ 2 ┆ null ┆ HEIGHT ┆ 175 │
│ 2 ┆ 12/28/1980 00:00 ┆ DOB ┆ null │
│ 3 ┆ null ┆ EYE_COLOR//BROWN ┆ null │
│ 3 ┆ null ┆ HEIGHT ┆ 166 │
│ 3 ┆ 12/19/1988 00:00 ┆ DOB ┆ null │
└────────────┴──────────────────┴──────────────────┴─────────────────┘
"""
logger.info("Found single-nested MEDS data, unnesting...")

return data.explode("events").unnest("events")


def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl.DataFrame:
"""Generate plain predicate columns from a MEDS dataset.
Expand Down Expand Up @@ -254,6 +304,9 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
.drop_nulls(subset=["subject_id", "timestamp"])
)

if data.columns == ["subject_id", "events"]:
data = unnest_meds(data)

# generate plain predicate columns
logger.info("Generating plain predicate columns...")
for name, plain_predicate in predicates.items():
Expand Down

0 comments on commit ff5efab

Please sign in to comment.