diff --git a/src/aces/aggregate.py b/src/aces/aggregate.py index 86dcbb9..769de1b 100644 --- a/src/aces/aggregate.py +++ b/src/aces/aggregate.py @@ -4,7 +4,12 @@ import polars as pl -from .types import PRED_CNT_TYPE, TemporalWindowBounds, ToEventWindowBounds +from .types import ( + EVENT_INDEX_COLUMN, + PRED_CNT_TYPE, + TemporalWindowBounds, + ToEventWindowBounds, +) def aggregate_temporal_window( @@ -79,7 +84,7 @@ def aggregate_temporal_window( ... "is_C": [1, 1, 0, 0, 1, 0], ... }) >>> aggregate_temporal_window(df, TemporalWindowBounds( - ... True, timedelta(days=7), True, None)) + ... True, timedelta(days=7), True, None)).drop("_EVENT_INDEX") shape: (6, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -94,7 +99,7 @@ def aggregate_temporal_window( │ 2 ┆ 1989-12-03 15:17:00 ┆ 1989-12-03 15:17:00 ┆ 1989-12-10 15:17:00 ┆ 0 ┆ 0 ┆ 0 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ >>> aggregate_temporal_window(df, ( - ... True, timedelta(days=1), True, timedelta(days=0))) + ... True, timedelta(days=1), True, timedelta(days=0))).drop("_EVENT_INDEX") shape: (6, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -109,7 +114,7 @@ def aggregate_temporal_window( │ 2 ┆ 1989-12-03 15:17:00 ┆ 1989-12-03 15:17:00 ┆ 1989-12-04 15:17:00 ┆ 0 ┆ 0 ┆ 0 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ >>> aggregate_temporal_window(df, ( - ... True, timedelta(days=1), False, timedelta(days=0))) + ... 
True, timedelta(days=1), False, timedelta(days=0))).drop("_EVENT_INDEX") shape: (6, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -124,7 +129,7 @@ def aggregate_temporal_window( │ 2 ┆ 1989-12-03 15:17:00 ┆ 1989-12-03 15:17:00 ┆ 1989-12-04 15:17:00 ┆ 0 ┆ 0 ┆ 0 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ >>> aggregate_temporal_window(df, ( - ... False, timedelta(days=1), False, timedelta(days=0))) + ... False, timedelta(days=1), False, timedelta(days=0))).drop("_EVENT_INDEX") shape: (6, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -139,7 +144,7 @@ def aggregate_temporal_window( │ 2 ┆ 1989-12-03 15:17:00 ┆ 1989-12-03 15:17:00 ┆ 1989-12-04 15:17:00 ┆ 0 ┆ 0 ┆ 0 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ >>> aggregate_temporal_window(df, ( - ... False, timedelta(days=-1), False, timedelta(days=0))) + ... False, timedelta(days=-1), False, timedelta(days=0))).drop("_EVENT_INDEX") shape: (6, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -154,7 +159,7 @@ def aggregate_temporal_window( │ 2 ┆ 1989-12-03 15:17:00 ┆ 1989-12-03 15:17:00 ┆ 1989-12-02 15:17:00 ┆ 0 ┆ 0 ┆ 0 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ >>> aggregate_temporal_window(df, ( - ... False, timedelta(hours=12), False, timedelta(hours=12))) + ... 
False, timedelta(hours=12), False, timedelta(hours=12))).drop("_EVENT_INDEX") shape: (6, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -173,7 +178,7 @@ def aggregate_temporal_window( >>> # the earliest event in the aggregation window, regardless of whether that is earlier than the >>> # timestamp of the row. >>> aggregate_temporal_window(df, ( - ... False, timedelta(days=-1), True, timedelta(days=1))) + ... False, timedelta(days=-1), True, timedelta(days=1))).drop("_EVENT_INDEX") shape: (6, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -188,7 +193,7 @@ def aggregate_temporal_window( │ 2 ┆ 1989-12-03 15:17:00 ┆ 1989-12-04 15:17:00 ┆ 1989-12-03 15:17:00 ┆ 0 ┆ 0 ┆ 0 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ >>> aggregate_temporal_window(df, ( - ... True, timedelta(days=-1), False, timedelta(days=1))) + ... 
True, timedelta(days=-1), False, timedelta(days=1))).drop("_EVENT_INDEX") shape: (6, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -206,7 +211,9 @@ def aggregate_temporal_window( if not isinstance(endpoint_expr, TemporalWindowBounds): endpoint_expr = TemporalWindowBounds(*endpoint_expr) - predicate_cols = [c for c in predicates_df.columns if c not in {"subject_id", "timestamp"}] + predicate_cols = [ + c for c in predicates_df.columns if c not in {"subject_id", "timestamp", EVENT_INDEX_COLUMN} + ] return ( predicates_df.rolling( @@ -216,6 +223,7 @@ def aggregate_temporal_window( ) .agg( *[pl.col(c).sum().cast(PRED_CNT_TYPE).alias(c) for c in predicate_cols], + pl.col(EVENT_INDEX_COLUMN).max(), ) .sort(by=["subject_id", "timestamp"]) .select( @@ -226,6 +234,7 @@ def aggregate_temporal_window( "timestamp_at_end" ), *predicate_cols, + EVENT_INDEX_COLUMN, ) ) @@ -302,11 +311,13 @@ def aggregate_event_bound_window( ... datetime(year=1989, month=12, day=8, hour=16, minute=22), ... datetime(year=1989, month=12, day=10, hour=3, minute=7), # HAS EVENT BOUND ... ], + ... "_EVENT_INDEX": [0, 1, 2, 0, 1, 2, 3, 4], ... "is_A": [1, 0, 1, 1, 1, 1, 0, 0], ... "is_B": [0, 1, 0, 1, 0, 1, 1, 1], ... "is_C": [0, 1, 0, 0, 0, 1, 0, 1], ... }) - >>> aggregate_event_bound_window(df, ToEventWindowBounds(True, "is_C", True, None)) + >>> aggregate_event_bound_window(df, ToEventWindowBounds(True, "is_C", True, None)).drop( + ... 
"_EVENT_INDEX") shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -322,7 +333,8 @@ def aggregate_event_bound_window( │ 2 ┆ 1989-12-08 16:22:00 ┆ 1989-12-08 16:22:00 ┆ 1989-12-10 03:07:00 ┆ 0 ┆ 2 ┆ 1 │ │ 2 ┆ 1989-12-10 03:07:00 ┆ 1989-12-10 03:07:00 ┆ 1989-12-10 03:07:00 ┆ 0 ┆ 1 ┆ 1 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ - >>> aggregate_event_bound_window(df, ToEventWindowBounds(True, "is_C", False, None)) + >>> aggregate_event_bound_window(df, ToEventWindowBounds(True, "is_C", False, None)).drop( + ... "_EVENT_INDEX") shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -338,7 +350,8 @@ def aggregate_event_bound_window( │ 2 ┆ 1989-12-08 16:22:00 ┆ 1989-12-08 16:22:00 ┆ 1989-12-10 03:07:00 ┆ 0 ┆ 1 ┆ 0 │ │ 2 ┆ 1989-12-10 03:07:00 ┆ null ┆ null ┆ 0 ┆ 0 ┆ 0 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ - >>> aggregate_event_bound_window(df, ToEventWindowBounds(False, "is_C", True, None)) + >>> aggregate_event_bound_window(df, ToEventWindowBounds(False, "is_C", True, None)).drop( + ... 
"_EVENT_INDEX") shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -354,7 +367,8 @@ def aggregate_event_bound_window( │ 2 ┆ 1989-12-08 16:22:00 ┆ 1989-12-08 16:22:00 ┆ 1989-12-10 03:07:00 ┆ 0 ┆ 1 ┆ 1 │ │ 2 ┆ 1989-12-10 03:07:00 ┆ 1989-12-10 03:07:00 ┆ 1989-12-10 03:07:00 ┆ 0 ┆ 0 ┆ 0 │ └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ - >>> aggregate_event_bound_window(df, ToEventWindowBounds(True, "is_C", True, timedelta(days=3))) + >>> aggregate_event_bound_window(df, ToEventWindowBounds( + ... True, "is_C", True, timedelta(days=3))).drop("_EVENT_INDEX") shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -495,6 +509,7 @@ def boolean_expr_bound_sum( ... datetime(year=1989, month=12, day=8, hour=16, minute=22), ... datetime(year=1989, month=12, day=10, hour=3, minute=7), # HAS EVENT BOUND ... ], + ... "_EVENT_INDEX": [0, 1, 2, 0, 1, 2, 3, 4], ... "idx": [0, 1, 2, 3, 4, 5, 6, 7], ... "is_A": [1, 0, 1, 1, 1, 1, 0, 0], ... "is_B": [0, 1, 0, 1, 0, 1, 1, 1], @@ -675,7 +690,7 @@ def boolean_expr_bound_sum( ... "bound_to_row", ... "both", ... offset = timedelta(days=3), - ... ).drop("idx") + ... ).drop(["idx", "_EVENT_INDEX"]) shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -697,7 +712,7 @@ def boolean_expr_bound_sum( ... "bound_to_row", ... "left", ... offset = timedelta(days=3), - ... ).drop("idx") + ... 
).drop(["idx", "_EVENT_INDEX"]) shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -719,7 +734,7 @@ def boolean_expr_bound_sum( ... "bound_to_row", ... "none", ... timedelta(days=-3), - ... ).drop("idx") + ... ).drop(["idx", "_EVENT_INDEX"]) shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -741,7 +756,7 @@ def boolean_expr_bound_sum( ... "bound_to_row", ... "right", ... offset = timedelta(days=-3), - ... ).drop("idx") + ... ).drop(["idx", "_EVENT_INDEX"]) shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -763,7 +778,7 @@ def boolean_expr_bound_sum( ... "row_to_bound", ... "both", ... offset = timedelta(days=3), - ... ).drop("idx") + ... ).drop(["idx", "_EVENT_INDEX"]) shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -785,7 +800,7 @@ def boolean_expr_bound_sum( ... "row_to_bound", ... "left", ... offset = timedelta(days=3), - ... ).drop("idx") + ... ).drop(["idx", "_EVENT_INDEX"]) shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -807,7 +822,7 @@ def boolean_expr_bound_sum( ... "row_to_bound", ... "none", ... offset = timedelta(days=-3), - ... ).drop("idx") + ... 
).drop(["idx", "_EVENT_INDEX"]) shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -829,7 +844,7 @@ def boolean_expr_bound_sum( ... "row_to_bound", ... "right", ... offset = timedelta(days=-3), - ... ).drop("idx") + ... ).drop(["idx", "_EVENT_INDEX"]) shape: (8, 7) ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ @@ -891,7 +906,7 @@ def boolean_expr_bound_sum( ), ) - cols = [c for c in df.columns if c not in {"subject_id", "timestamp"}] + cols = [c for c in df.columns if c not in {"subject_id", "timestamp", EVENT_INDEX_COLUMN}] cumsum_cols = {c: pl.col(c).cum_sum().over("subject_id").alias(f"{c}_cumsum_at_row") for c in cols} df = df.with_columns(*cumsum_cols.values()) @@ -1024,4 +1039,5 @@ def agg_offset_fn(c: str) -> pl.Expr: st_timestamp_expr.alias("timestamp_at_start"), end_timestamp_expr.alias("timestamp_at_end"), *(agg_offset_fn(c).cast(PRED_CNT_TYPE, strict=False).fill_null(0).alias(c) for c in cols), + EVENT_INDEX_COLUMN, ) diff --git a/src/aces/extract_subtree.py b/src/aces/extract_subtree.py index 334f197..bfb3592 100644 --- a/src/aces/extract_subtree.py +++ b/src/aces/extract_subtree.py @@ -8,6 +8,7 @@ from .aggregate import aggregate_event_bound_window, aggregate_temporal_window from .constraints import check_constraints +from .types import EVENT_INDEX_COLUMN, LAST_EVENT_INDEX_COLUMN def extract_subtree( @@ -133,6 +134,7 @@ def extract_subtree( ... datetime(year=1999, month=12, day=6, hour=15, minute=17), # Admission ... datetime(year=1999, month=12, day=6, hour=16, minute=22), # Discharge ... ], + ... "_EVENT_INDEX": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2], ... "is_admission": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0], ... 
"is_discharge": [0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1], ... "is_death": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], @@ -144,18 +146,18 @@ def extract_subtree( ... .rename({"timestamp": "subtree_anchor_timestamp"}) ... ).select("subject_id", "subtree_anchor_timestamp") >>> print(subtreee_anchor_realizations) - shape: (5, 2) - ┌────────────┬──────────────────────────┐ - │ subject_id ┆ subtree_anchor_timestamp │ - │ --- ┆ --- │ - │ i64 ┆ datetime[μs] │ - ╞════════════╪══════════════════════════╡ - │ 1 ┆ 1989-12-03 13:14:00 │ - │ 1 ┆ 1989-12-23 03:12:00 │ - │ 2 ┆ 1983-12-02 12:03:00 │ - │ 2 ┆ 1989-12-06 15:17:00 │ - │ 3 ┆ 1999-12-06 15:17:00 │ - └────────────┴──────────────────────────┘ + shape: (5, 3) + ┌────────────┬──────────────────────────┬──────────────┐ + │ subject_id ┆ subtree_anchor_timestamp ┆ _EVENT_INDEX │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 │ + ╞════════════╪══════════════════════════╪══════════════╡ + │ 1 ┆ 1989-12-03 13:14:00 ┆ 1 │ + │ 1 ┆ 1989-12-23 03:12:00 ┆ 4 │ + │ 2 ┆ 1983-12-02 12:03:00 ┆ 1 │ + │ 2 ┆ 1989-12-06 15:17:00 ┆ 3 │ + │ 3 ┆ 1999-12-06 15:17:00 ┆ 1 │ + └────────────┴──────────────────────────┴──────────────┘ >>> out = extract_subtree(root, subtreee_anchor_realizations, predicates_df, timedelta(0)) >>> out.select( ... 
"subject_id", @@ -243,7 +245,9 @@ def extract_subtree( └─────────────────────┴─────────────────────┴──────────────┴──────────────┴──────────┴─────────────┘ """ recursive_results = [] - predicate_cols = [c for c in predicates_df.columns if c not in {"subject_id", "timestamp"}] + predicate_cols = [ + c for c in predicates_df.columns if c not in {"subject_id", "timestamp", EVENT_INDEX_COLUMN} + ] if not subtree.children: return subtree_anchor_realizations @@ -326,6 +330,7 @@ def extract_subtree( pl.lit(child.name).alias("window_name"), "timestamp_at_start", "timestamp_at_end", + pl.col(EVENT_INDEX_COLUMN).alias(LAST_EVENT_INDEX_COLUMN), *predicate_cols, ).alias(f"{child.name}_summary"), ) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index c286650..7976840 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -10,6 +10,8 @@ from .types import ( ANY_EVENT_COLUMN, END_OF_RECORD_KEY, + EVENT_INDEX_COLUMN, + EVENT_INDEX_TYPE, PRED_CNT_TYPE, START_OF_RECORD_KEY, ) @@ -485,17 +487,17 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D ... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M" ... }) ... 
get_predicates_df(config, data_config) - shape: (4, 7) - ┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┐ - │ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │ - └────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┘ + shape: (4, 8) + ┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┬──────────────┐ + │ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT ┆ _EVENT_INDEX │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╪══════════════╡ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ + │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 ┆ 1 │ + └────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┴──────────────┘ >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: ... data_path = Path(f.name) ... ( @@ -505,17 +507,17 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D ... ) ... data_config = DictConfig({"path": str(data_path), "standard": "direct", "ts_format": None}) ... 
get_predicates_df(config, data_config) - shape: (4, 7) - ┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┐ - │ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │ - └────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┘ + shape: (4, 8) + ┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┬──────────────┐ + │ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT ┆ _EVENT_INDEX │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╪══════════════╡ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ + │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 ┆ 1 │ + └────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┴──────────────┘ >>> any_event_trigger = EventConfig("_ANY_EVENT") >>> adm_only_predicates = {"adm": PlainPredicateConfig("adm")} >>> st_end_windows = { @@ -636,4 +638,10 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D logger.info(f"Added predicate column '{END_OF_RECORD_KEY}'.") predicate_cols += special_predicates + # create a column for event_id + data = data.with_columns(pl.lit(1).alias(EVENT_INDEX_COLUMN)) + data = data.with_columns( + (pl.col(EVENT_INDEX_COLUMN).cum_sum().over("subject_id") - 1).cast(EVENT_INDEX_TYPE) + ) + return 
data diff --git a/src/aces/types.py b/src/aces/types.py index 20db78c..d74fc0a 100644 --- a/src/aces/types.py +++ b/src/aces/types.py @@ -12,6 +12,9 @@ # The type used for final aggregate counts of predicates. PRED_CNT_TYPE = pl.Int64 +# The type used for event indexing +EVENT_INDEX_TYPE = pl.Int64 + # The key used in the endpoint expression to indicate the window should be aggregated to the record start. START_OF_RECORD_KEY = "_RECORD_START" END_OF_RECORD_KEY = "_RECORD_END" @@ -19,6 +22,10 @@ # The key used to capture the count of events of any kind that occur in a window. ANY_EVENT_COLUMN = "_ANY_EVENT" +# The column names used for the zero-based, per-subject event index and for the last-event index reported in window summaries. +EVENT_INDEX_COLUMN = "_EVENT_INDEX" +LAST_EVENT_INDEX_COLUMN = "_LAST_EVENT_INDEX" + @dataclasses.dataclass(order=True) class TemporalWindowBounds: diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 41ce398..763feee 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -129,6 +129,7 @@ "window_name": "input.end", "timestamp_at_start": "01/27/1991 23:32", "timestamp_at_end": "01/28/1991 23:32", + "_LAST_EVENT_INDEX": 15, "admission": 0, "discharge": 0, "death": 0, @@ -139,6 +140,7 @@ "window_name": "input.end", "timestamp_at_start": "06/05/1996 00:32", "timestamp_at_end": "06/06/1996 00:32", + "_LAST_EVENT_INDEX": 7, "admission": 0, "discharge": 0, "death": 0, @@ -151,6 +153,7 @@ "window_name": "input.start", "timestamp_at_start": "12/01/1989 12:03", "timestamp_at_end": "01/28/1991 23:32", + "_LAST_EVENT_INDEX": 15, "admission": 2, "discharge": 1, "death": 0, @@ -161,6 +164,7 @@ "window_name": "input.start", "timestamp_at_start": "03/08/1996 02:24", "timestamp_at_end": "06/06/1996 00:32", + "_LAST_EVENT_INDEX": 7, "admission": 2, "discharge": 1, "death": 0, @@ -173,6 +177,7 @@ "window_name": "gap.end", "timestamp_at_start": "01/27/1991 23:32", "timestamp_at_end": "01/29/1991 23:32", + "_LAST_EVENT_INDEX": 16, "admission": 0, "discharge": 0, "death": 0, @@ -183,6 +188,7 @@ "window_name": "gap.end",
"timestamp_at_start": "06/05/1996 00:32", "timestamp_at_end": "06/07/1996 00:32", + "_LAST_EVENT_INDEX": 7, "admission": 0, "discharge": 0, "death": 0, @@ -195,6 +201,7 @@ "window_name": "target.end", "timestamp_at_start": "01/29/1991 23:32", "timestamp_at_end": "01/31/1991 02:15", + "_LAST_EVENT_INDEX": 23, "admission": 0, "discharge": 1, "death": 0, @@ -205,6 +212,7 @@ "window_name": "target.end", "timestamp_at_start": "06/07/1996 00:32", "timestamp_at_end": "06/08/1996 03:00", + "_LAST_EVENT_INDEX": 12, "admission": 0, "discharge": 0, "death": 1,