Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First commit for event index #67

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/aces/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def main(cfg: DictConfig):
result = query.query(task_cfg, predicates_df)

if cfg.data.standard.lower() == "meds":
result = result.rename(columns={"subject_id": "patient_id"})
result = result.rename({"subject_id": "patient_id"})

# save results to parquet
os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True)
Expand Down
832 changes: 425 additions & 407 deletions src/aces/aggregate.py

Large diffs are not rendered by default.

33 changes: 19 additions & 14 deletions src/aces/extract_subtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .aggregate import aggregate_event_bound_window, aggregate_temporal_window
from .constraints import check_constraints
from .types import EVENT_INDEX_COLUMN, LAST_EVENT_INDEX_COLUMN


def extract_subtree(
Expand Down Expand Up @@ -133,6 +134,7 @@ def extract_subtree(
... datetime(year=1999, month=12, day=6, hour=15, minute=17), # Admission
... datetime(year=1999, month=12, day=6, hour=16, minute=22), # Discharge
... ],
... "_EVENT_INDEX": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2],
... "is_admission": [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0],
... "is_discharge": [0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1],
... "is_death": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
Expand All @@ -142,20 +144,20 @@ def extract_subtree(
>>> subtreee_anchor_realizations = (
... predicates_df.filter(pl.col("is_admission") > 0)
... .rename({"timestamp": "subtree_anchor_timestamp"})
... ).select("subject_id", "subtree_anchor_timestamp")
... ).select("subject_id", "subtree_anchor_timestamp", "_EVENT_INDEX")
>>> print(subtreee_anchor_realizations)
shape: (5, 2)
┌────────────┬──────────────────────────┐
│ subject_id ┆ subtree_anchor_timestamp │
│ --- ┆ --- │
│ i64 ┆ datetime[μs] │
╞════════════╪══════════════════════════╡
│ 1 ┆ 1989-12-03 13:14:00 │
│ 1 ┆ 1989-12-23 03:12:00 │
│ 2 ┆ 1983-12-02 12:03:00 │
│ 2 ┆ 1989-12-06 15:17:00 │
│ 3 ┆ 1999-12-06 15:17:00 │
└────────────┴──────────────────────────┘
shape: (5, 3)
┌────────────┬──────────────────────────┬──────────────┐
│ subject_id ┆ subtree_anchor_timestamp ┆ _EVENT_INDEX │
│ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 │
╞════════════╪══════════════════════════╪══════════════╡
│ 1 ┆ 1989-12-03 13:14:00 ┆ 1 │
│ 1 ┆ 1989-12-23 03:12:00 ┆ 4 │
│ 2 ┆ 1983-12-02 12:03:00 ┆ 1 │
│ 2 ┆ 1989-12-06 15:17:00 ┆ 3 │
│ 3 ┆ 1999-12-06 15:17:00 ┆ 1 │
└────────────┴──────────────────────────┴──────────────┘
>>> out = extract_subtree(root, subtreee_anchor_realizations, predicates_df, timedelta(0))
>>> out.select(
... "subject_id",
Expand Down Expand Up @@ -243,7 +245,9 @@ def extract_subtree(
└─────────────────────┴─────────────────────┴──────────────┴──────────────┴──────────┴─────────────┘
"""
recursive_results = []
predicate_cols = [c for c in predicates_df.columns if c not in {"subject_id", "timestamp"}]
predicate_cols = [
c for c in predicates_df.columns if c not in {"subject_id", "timestamp", EVENT_INDEX_COLUMN}
]

if not subtree.children:
return subtree_anchor_realizations
Expand Down Expand Up @@ -326,6 +330,7 @@ def extract_subtree(
pl.lit(child.name).alias("window_name"),
"timestamp_at_start",
"timestamp_at_end",
pl.col(EVENT_INDEX_COLUMN).alias(LAST_EVENT_INDEX_COLUMN),
*predicate_cols,
).alias(f"{child.name}_summary"),
)
Expand Down
72 changes: 39 additions & 33 deletions src/aces/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .types import (
ANY_EVENT_COLUMN,
END_OF_RECORD_KEY,
EVENT_INDEX_COLUMN,
EVENT_INDEX_TYPE,
PRED_CNT_TYPE,
START_OF_RECORD_KEY,
)
Expand Down Expand Up @@ -485,17 +487,17 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M"
... })
... get_predicates_df(config, data_config)
shape: (4, 7)
┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┐
│ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╡
│ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │
│ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │
└────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┘
shape: (4, 8)
┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┬──────────────┐
│ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT ┆ _EVENT_INDEX │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ u64 │
╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╪══════════════╡
│ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │
│ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 ┆ 1 │
└────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┴──────────────┘
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f:
... data_path = Path(f.name)
... (
Expand All @@ -505,17 +507,17 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
... )
... data_config = DictConfig({"path": str(data_path), "standard": "direct", "ts_format": None})
... get_predicates_df(config, data_config)
shape: (4, 7)
┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┐
│ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╡
│ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │
│ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │
└────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┘
shape: (4, 8)
┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┬──────────────┐
│ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT ┆ _EVENT_INDEX │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ u64 │
╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╪══════════════╡
│ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │
│ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 ┆ 1 │
└────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┴──────────────┘
>>> any_event_trigger = EventConfig("_ANY_EVENT")
>>> adm_only_predicates = {"adm": PlainPredicateConfig("adm")}
>>> st_end_windows = {
Expand All @@ -540,17 +542,17 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M"
... })
... get_predicates_df(st_end_config, data_config)
shape: (4, 6)
┌────────────┬─────────────────────┬─────┬────────────┬───────────────┬─────────────┐
│ subject_id ┆ timestamp ┆ adm ┆ _ANY_EVENT ┆ _RECORD_START ┆ _RECORD_END │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞════════════╪═════════════════════╪═════╪════════════╪═══════════════╪═════════════╡
│ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 0 │
│ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 0 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │
└────────────┴─────────────────────┴─────┴────────────┴───────────────┴─────────────┘
shape: (4, 7)
┌────────────┬─────────────────────┬─────┬────────────┬───────────────┬─────────────┬──────────────┐
│ subject_id ┆ timestamp ┆ adm ┆ _ANY_EVENT ┆ _RECORD_START ┆ _RECORD_END ┆ _EVENT_INDEX │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ u64 │
╞════════════╪═════════════════════╪═════╪════════════╪═══════════════╪═════════════╪══════════════╡
│ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 0 ┆ 0 │
│ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 0 ┆ 0 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │
└────────────┴─────────────────────┴─────┴────────────┴───────────────┴─────────────┴──────────────┘
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f:
... data_path = Path(f.name)
... data.write_csv(data_path)
Expand Down Expand Up @@ -636,4 +638,8 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
logger.info(f"Added predicate column '{END_OF_RECORD_KEY}'.")
predicate_cols += special_predicates

# Add the event index column: a zero-based row counter that restarts within each subject_id.
data = data.with_columns(
justin13601 marked this conversation as resolved.
Show resolved Hide resolved
pl.int_range(pl.len()).over("subject_id").cast(EVENT_INDEX_TYPE).alias(EVENT_INDEX_COLUMN)
)
return data
7 changes: 7 additions & 0 deletions src/aces/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,20 @@
# The type used for final aggregate counts of predicates.
PRED_CNT_TYPE = pl.Int64

# The type used for event indexing
EVENT_INDEX_TYPE = pl.UInt64

# The keys used in the endpoint expression to indicate the window should be aggregated to the record start
# or the record end, respectively.
START_OF_RECORD_KEY = "_RECORD_START"
END_OF_RECORD_KEY = "_RECORD_END"

# The key used to capture the count of events of any kind that occur in a window.
ANY_EVENT_COLUMN = "_ANY_EVENT"

# The column names used for the event index in the predicates dataframe and for the event index carried
# into each window summary.
EVENT_INDEX_COLUMN = "_EVENT_INDEX"
LAST_EVENT_INDEX_COLUMN = "_LAST_EVENT_INDEX"


@dataclasses.dataclass(order=True)
class TemporalWindowBounds:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
"window_name": "input.end",
"timestamp_at_start": "01/27/1991 23:32",
"timestamp_at_end": "01/28/1991 23:32",
"_LAST_EVENT_INDEX": 15,
"admission": 0,
"discharge": 0,
"death": 0,
Expand All @@ -139,6 +140,7 @@
"window_name": "input.end",
"timestamp_at_start": "06/05/1996 00:32",
"timestamp_at_end": "06/06/1996 00:32",
"_LAST_EVENT_INDEX": 7,
"admission": 0,
"discharge": 0,
"death": 0,
Expand All @@ -151,6 +153,7 @@
"window_name": "input.start",
"timestamp_at_start": "12/01/1989 12:03",
"timestamp_at_end": "01/28/1991 23:32",
"_LAST_EVENT_INDEX": 15,
"admission": 2,
"discharge": 1,
"death": 0,
Expand All @@ -161,6 +164,7 @@
"window_name": "input.start",
"timestamp_at_start": "03/08/1996 02:24",
"timestamp_at_end": "06/06/1996 00:32",
"_LAST_EVENT_INDEX": 7,
"admission": 2,
"discharge": 1,
"death": 0,
Expand All @@ -173,6 +177,7 @@
"window_name": "gap.end",
"timestamp_at_start": "01/27/1991 23:32",
"timestamp_at_end": "01/29/1991 23:32",
"_LAST_EVENT_INDEX": 16,
"admission": 0,
"discharge": 0,
"death": 0,
Expand All @@ -183,6 +188,7 @@
"window_name": "gap.end",
"timestamp_at_start": "06/05/1996 00:32",
"timestamp_at_end": "06/07/1996 00:32",
"_LAST_EVENT_INDEX": 7,
"admission": 0,
"discharge": 0,
"death": 0,
Expand All @@ -195,6 +201,7 @@
"window_name": "target.end",
"timestamp_at_start": "01/29/1991 23:32",
"timestamp_at_end": "01/31/1991 02:15",
"_LAST_EVENT_INDEX": 23,
"admission": 0,
"discharge": 1,
"death": 0,
Expand All @@ -205,6 +212,7 @@
"window_name": "target.end",
"timestamp_at_start": "06/07/1996 00:32",
"timestamp_at_end": "06/08/1996 03:00",
"_LAST_EVENT_INDEX": 12,
"admission": 0,
"discharge": 0,
"death": 1,
Expand Down
Loading