Skip to content

Commit

Permalink
Merge pull request #93 from justin13601/more-flexible-predicates
Browse files (browse the repository at this point in the history)
TODO: need to add support for these fields (#87, #90)
  • Loading branch information
mmcdermott authored Aug 10, 2024
2 parents 74d0abf + 95f518c commit a1159b4
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 21 deletions.
3 changes: 2 additions & 1 deletion sample_configs/inhospital_mortality.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Task: 24-hour In-hospital Mortality Prediction
predicates:
admission:
code: event_type//ADMISSION
code:
any: ["event_type//ADMISSION", "event_type//DISCHARGE"]
discharge:
code: event_type//DISCHARGE
death:
Expand Down
129 changes: 109 additions & 20 deletions src/aces/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@

@dataclasses.dataclass
class PlainPredicateConfig:
code: str
code: str | dict
value_min: float | None = None
value_max: float | None = None
value_min_inclusive: bool | None = None
value_max_inclusive: bool | None = None
static: bool = False
other_cols: dict[str, str] = field(default_factory=dict)

def MEDS_eval_expr(self) -> pl.Expr:
"""Returns a Polars expression that evaluates this predicate for a MEDS formatted dataset.
Expand All @@ -60,19 +61,92 @@ def MEDS_eval_expr(self) -> pl.Expr:
>>> expr = cfg.MEDS_eval_expr()
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
[(col("code")) == (String(BP//diastolic))]
>>> cfg = PlainPredicateConfig("BP//diastolic", other_cols={"chamber": "atrial"})
>>> expr = cfg.MEDS_eval_expr()
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
[(col("code")) == (String(BP//diastolic))].all_horizontal([[(col("chamber")) ==
(String(atrial))]])
>>> cfg = PlainPredicateConfig(code={'regex': None, 'any': None})
>>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
ValueError: Only one of 'regex' or 'any' can be specified in the code field!
Got: ['regex', 'any'].
>>> cfg = PlainPredicateConfig(code={'foo': None})
>>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
ValueError: Invalid specification in the code field! Got: {'foo': None}.
Expected one of 'regex', 'any'.
>>> cfg = PlainPredicateConfig(code={'regex': ''})
>>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
ValueError: Invalid specification in the code field! Got: {'regex': ''}.
Expected a non-empty string for 'regex'.
>>> cfg = PlainPredicateConfig(code={'any': []})
>>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
ValueError: Invalid specification in the code field! Got: {'any': []}.
Expected a list of strings for 'any'.
>>> cfg = PlainPredicateConfig(code={'regex': '^foo.*'})
>>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
col("code").str.contains([String(^foo.*)])
>>> cfg = PlainPredicateConfig(code={'any': ['foo', 'bar']})
>>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
col("code").is_in([Series])
"""
criteria = [pl.col("code") == self.code]
criteria = []
if isinstance(self.code, dict):
if len(self.code) > 1:
raise ValueError(
"Only one of 'regex' or 'any' can be specified in the code field! "
f"Got: {list(self.code.keys())}."
)

if self.value_min is not None:
if self.value_min_inclusive:
criteria.append(pl.col("numerical_value") >= self.value_min)
else:
criteria.append(pl.col("numerical_value") > self.value_min)
if self.value_max is not None:
if self.value_max_inclusive:
criteria.append(pl.col("numerical_value") <= self.value_max)
if "regex" in self.code:
if not self.code["regex"] or not isinstance(self.code["regex"], str):
raise ValueError(
"Invalid specification in the code field! "
f"Got: {self.code}. "
"Expected a non-empty string for 'regex'."
)
criteria.append(pl.col("code").str.contains(self.code["regex"]))
elif "any" in self.code:
if not self.code["any"] or not isinstance(self.code["any"], list):
raise ValueError(
"Invalid specification in the code field! "
f"Got: {self.code}. "
f"Expected a list of strings for 'any'."
)
criteria.append(pl.Expr.is_in(pl.col("code"), self.code["any"]))
else:
criteria.append(pl.col("numerical_value") < self.value_max)
raise ValueError(
"Invalid specification in the code field! "
f"Got: {self.code}. "
"Expected one of 'regex', 'any'."
)
else:
criteria.append(pl.col("code") == self.code)

if self.value_min is not None:
if self.value_min_inclusive:
criteria.append(pl.col("numerical_value") >= self.value_min)
else:
criteria.append(pl.col("numerical_value") > self.value_min)
if self.value_max is not None:
if self.value_max_inclusive:
criteria.append(pl.col("numerical_value") <= self.value_max)
else:
criteria.append(pl.col("numerical_value") < self.value_max)

if self.other_cols:
criteria.extend([pl.col(col) == value for col, value in self.other_cols.items()])

if len(criteria) == 1:
return criteria[0]
Expand Down Expand Up @@ -120,6 +194,9 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr:
>>> expr = PlainPredicateConfig("BP").ESGPT_eval_expr()
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
col("BP").is_not_null()
>>> expr = PlainPredicateConfig("BP//systole", other_cols={"chamber": "atrial"}).ESGPT_eval_expr()
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
[(col("BP")) == (String(systole))].all_horizontal([[(col("chamber")) == (String(atrial))]])
>>> expr = PlainPredicateConfig("BP//systolic", value_min=120).ESGPT_eval_expr()
Traceback (most recent call last):
...
Expand Down Expand Up @@ -167,6 +244,9 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr:
else:
criteria.append(pl.col(values_column) < self.value_max)

if self.other_cols:
criteria.extend([pl.col(col) == value for col, value in self.other_cols.items()])

if len(criteria) == 1:
return criteria[0]
else:
Expand Down Expand Up @@ -860,19 +940,22 @@ class TaskExtractorConfig:
value_max=None,
value_min_inclusive=None,
value_max_inclusive=None,
static=False),
static=False,
other_cols={}),
'discharge': PlainPredicateConfig(code='discharge',
value_min=None,
value_max=None,
value_min_inclusive=None,
value_max_inclusive=None,
static=False),
static=False,
other_cols={}),
'death': PlainPredicateConfig(code='death',
value_min=None,
value_max=None,
value_min_inclusive=None,
value_max_inclusive=None,
static=False)}
static=False,
other_cols={})}
>>> print(config.label_window) # doctest: +NORMALIZE_WHITESPACE
target
Expand Down Expand Up @@ -1084,17 +1167,21 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
raise ValueError(f"Unrecognized keys in configuration file: '{', '.join(loaded_dict.keys())}'")

logger.info("Parsing predicates...")
predicates = {
n: DerivedPredicateConfig(**p) if "expr" in p else PlainPredicateConfig(**p)
for n, p in predicates.items()
}
predicate_objs = {}
for n, p in predicates.items():
if "expr" in p:
predicate_objs[n] = DerivedPredicateConfig(**p)
else:
config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__}
other_cols = {k: v for k, v in p.items() if k not in config_data}
predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols)

if patient_demographics:
logger.info("Parsing patient demographics...")
patient_demographics = {
n: PlainPredicateConfig(**p, static=True) for n, p in patient_demographics.items()
}
predicates.update(patient_demographics)
predicate_objs.update(patient_demographics)

logger.info("Parsing trigger event...")
trigger = EventConfig(trigger)
Expand All @@ -1108,7 +1195,9 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
else:
windows = {n: WindowConfig(**w) for n, w in windows.items()}

return cls(predicates=predicates, trigger=trigger, windows=windows)
print(predicate_objs)

return cls(predicates=predicate_objs, trigger=trigger, windows=windows)

def save(self, config_path: str | Path, do_overwrite: bool = False):
"""Save this configuration to the given path.
Expand Down
1 change: 1 addition & 0 deletions src/aces/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
# generate plain predicate columns
logger.info("Generating plain predicate columns...")
for name, plain_predicate in predicates.items():
data = data.with_columns(data["code"].cast(pl.Utf8).alias("code")) # may remove after MEDS v0.3
data = data.with_columns(plain_predicate.MEDS_eval_expr().cast(PRED_CNT_TYPE).alias(name))
logger.info(f"Added predicate column '{name}'.")

Expand Down
4 changes: 4 additions & 0 deletions src/aces/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame
logger.info("No static variable criteria specified, removing all rows with null timestamps...")
predicates_df = predicates_df.drop_nulls(subset=["subject_id", "timestamp"])

if predicates_df.is_empty():
logger.warning("No valid rows found after filtering patient demographics. Exiting.")
return pl.DataFrame()

logger.info("Identifying possible trigger nodes based on the specified trigger event...")
prospective_root_anchors = check_constraints({cfg.trigger.predicate: (1, None)}, predicates_df).select(
"subject_id", pl.col("timestamp").alias("subtree_anchor_timestamp")
Expand Down

0 comments on commit a1159b4

Please sign in to comment.