diff --git a/src/aces/constraints.py b/src/aces/constraints.py index 0f09bb2..5d47907 100644 --- a/src/aces/constraints.py +++ b/src/aces/constraints.py @@ -71,6 +71,23 @@ def check_constraints( │ 2 ┆ 1989-12-01 13:14:00 ┆ 3 ┆ 10 ┆ 1 │ │ 2 ┆ 1989-12-03 15:17:00 ┆ 3 ┆ 2 ┆ 1 │ └────────────┴─────────────────────┴──────┴──────┴──────┘ + >>> predicates_df = pl.DataFrame({ + ... "subject_id": [1, 1, 3], + ... "timestamp": [datetime(1980, 12, 28), datetime(2010, 6, 20), datetime(2010, 5, 11)], + ... "A": [False, False, False], + ... "_ANY_EVENT": [True, True, True], + ... }) + >>> check_constraints({"_ANY_EVENT": (1, None)}, predicates_df) + shape: (3, 4) + ┌────────────┬─────────────────────┬───────┬────────────┐ + │ subject_id ┆ timestamp ┆ A ┆ _ANY_EVENT │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ bool ┆ bool │ + ╞════════════╪═════════════════════╪═══════╪════════════╡ + │ 1 ┆ 1980-12-28 00:00:00 ┆ false ┆ true │ + │ 1 ┆ 2010-06-20 00:00:00 ┆ false ┆ true │ + │ 3 ┆ 2010-05-11 00:00:00 ┆ false ┆ true │ + └────────────┴─────────────────────┴───────┴────────────┘ """ should_drop = pl.lit(False) diff --git a/src/aces/extract_subtree.py b/src/aces/extract_subtree.py index 334f197..78f9d1d 100644 --- a/src/aces/extract_subtree.py +++ b/src/aces/extract_subtree.py @@ -140,9 +140,11 @@ def extract_subtree( ... "_ANY_EVENT": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], ... }) >>> subtreee_anchor_realizations = ( - ... predicates_df.filter(pl.col("is_admission") > 0) + ... predicates_df + ... .filter(pl.col("is_admission") > 0) ... .rename({"timestamp": "subtree_anchor_timestamp"}) - ... ).select("subject_id", "subtree_anchor_timestamp") + ... .select("subject_id", "subtree_anchor_timestamp") + ... ) >>> print(subtreee_anchor_realizations) shape: (5, 2) ┌────────────┬──────────────────────────┐ @@ -157,10 +159,7 @@ def extract_subtree( │ 3 ┆ 1999-12-06 15:17:00 │ └────────────┴──────────────────────────┘ >>> out = extract_subtree(root, subtreee_anchor_realizations, predicates_df, timedelta(0)) - >>> out.select( - ... "subject_id", - ... "subtree_anchor_timestamp", - ... ) + >>> out.select("subject_id", "subtree_anchor_timestamp") shape: (1, 2) ┌────────────┬──────────────────────────┐ │ subject_id ┆ subtree_anchor_timestamp │ @@ -296,7 +295,7 @@ def extract_subtree( child_anchor_realizations = window_summary_df.select( "subject_id", pl.col("child_anchor_timestamp").alias("subtree_anchor_timestamp"), - ) + ).unique() # Step 5: Recurse recursive_result = extract_subtree( diff --git a/src/aces/query.py b/src/aces/query.py index 25805e0..701bf8f 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -31,7 +31,7 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame ValueError: If the (subject_id, timestamp) columns are not unique. Examples: - These examples just show the error cases for now; see the `tests` directory for full examples. + These examples are limited for now; see the `tests` directory for full examples. >>> cfg = None # This is obviously invalid, but we're just testing the error case. >>> predicates_df = {"subject_id": [1, 1], "timestamp": [1, 1]} >>> query(cfg, predicates_df) @@ -42,6 +42,33 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame Traceback (most recent call last): ... ValueError: The (subject_id, timestamp) columns must be unique. + >>> from datetime import datetime + >>> from .config import PlainPredicateConfig, WindowConfig, EventConfig + >>> cfg = TaskExtractorConfig( + ... predicates={"A": PlainPredicateConfig("A")}, + ... trigger=EventConfig("_ANY_EVENT"), + ... windows={ + ... "pre": WindowConfig(None, "trigger", True, False), + ... "post": WindowConfig("pre.end", None, True, True), + ... }, + ... ) + >>> predicates_df = pl.DataFrame({ + ... "subject_id": [1, 1, 3], + ... "timestamp": [datetime(1980, 12, 28), datetime(2010, 6, 20), datetime(2010, 5, 11)], + ... "A": [False, False, False], + ... "_ANY_EVENT": [True, True, True], + ... }) + >>> query(cfg, predicates_df).select("subject_id", "trigger") + shape: (3, 2) + ┌────────────┬─────────────────────┐ + │ subject_id ┆ trigger │ + │ --- ┆ --- │ + │ i64 ┆ datetime[μs] │ + ╞════════════╪═════════════════════╡ + │ 1 ┆ 1980-12-28 00:00:00 │ + │ 1 ┆ 2010-06-20 00:00:00 │ + │ 3 ┆ 2010-05-11 00:00:00 │ + └────────────┴─────────────────────┘ """ if not isinstance(predicates_df, pl.DataFrame): raise TypeError(f"Predicates dataframe type must be a polars.DataFrame. Got: {type(predicates_df)}.")