diff --git a/src/aces/__main__.py b/src/aces/__main__.py index b0a4a9e..80f94ab 100644 --- a/src/aces/__main__.py +++ b/src/aces/__main__.py @@ -46,7 +46,7 @@ def main(cfg: DictConfig): result = query.query(task_cfg, predicates_df) if cfg.data.standard.lower() == "meds": - result = result.rename(columns={"subject_id": "patient_id"}) + result = result.rename({"subject_id": "patient_id"}) # save results to parquet os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index c286650..bac82d2 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -207,6 +207,56 @@ def direct_load_plain_predicates( ) +def unnest_meds(data: pl.DataFrame) -> pl.DataFrame: + """Unnest MEDS data. + + Single-nested MEDS has a row per patient, with all the observations aggregated into a list of "event" + structs. The events column needs to be exploded and unnested. + + Args: + data: The Polars DataFrame containing the single-nested MEDS data. + + Returns: + The Polars DataFrame with the events column exploded and unnested. + + Example: + >>> data = pl.DataFrame({ + ... "subject_id": [1, 2, 3], + ... "events": [ + ... [{"timestamp": None, "code": "EYE_COLOR//HAZEL", "numerical_value": None}, + ... {"timestamp": None, "code": "HEIGHT", "numerical_value": 160}, + ... {"timestamp": "3/9/1978 00:00", "code": "DOB", "numerical_value": None}], + ... [{"timestamp": None, "code": "EYE_COLOR//BROWN", "numerical_value": None}, + ... {"timestamp": None, "code": "HEIGHT", "numerical_value": 175}, + ... {"timestamp": "12/28/1980 00:00", "code": "DOB", "numerical_value": None}], + ... [{"timestamp": None, "code": "EYE_COLOR//BROWN", "numerical_value": None}, + ... {"timestamp": None, "code": "HEIGHT", "numerical_value": 166}, + ... {"timestamp": "12/19/1988 00:00", "code": "DOB", "numerical_value": None}], + ... ], + ... }) + >>> unnest_meds(data) + shape: (9, 4) + ┌────────────┬──────────────────┬──────────────────┬─────────────────┐ + │ subject_id ┆ timestamp ┆ code ┆ numerical_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ str ┆ i64 │ + ╞════════════╪══════════════════╪══════════════════╪═════════════════╡ + │ 1 ┆ null ┆ EYE_COLOR//HAZEL ┆ null │ + │ 1 ┆ null ┆ HEIGHT ┆ 160 │ + │ 1 ┆ 3/9/1978 00:00 ┆ DOB ┆ null │ + │ 2 ┆ null ┆ EYE_COLOR//BROWN ┆ null │ + │ 2 ┆ null ┆ HEIGHT ┆ 175 │ + │ 2 ┆ 12/28/1980 00:00 ┆ DOB ┆ null │ + │ 3 ┆ null ┆ EYE_COLOR//BROWN ┆ null │ + │ 3 ┆ null ┆ HEIGHT ┆ 166 │ + │ 3 ┆ 12/19/1988 00:00 ┆ DOB ┆ null │ + └────────────┴──────────────────┴──────────────────┴─────────────────┘ + """ + logger.info("Found single-nested MEDS data, unnesting...") + + return data.explode("events").unnest("events") + + def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl.DataFrame: """Generate plain predicate columns from a MEDS dataset. @@ -254,6 +304,9 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl .drop_nulls(subset=["subject_id", "timestamp"]) ) + if data.columns == ["subject_id", "events"]: + data = unnest_meds(data) + # generate plain predicate columns logger.info("Generating plain predicate columns...") for name, plain_predicate in predicates.items():