Skip to content

Commit

Permalink
Fix break with // in esgpt
Browse files Browse the repository at this point in the history
  • Loading branch information
justin13601 committed Jul 3, 2024
1 parent e068b97 commit a2488c5
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/aces/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def MEDS_eval_expr(self) -> pl.Expr:
return pl.all_horizontal(criteria)

def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr:
"""Returns a Polars expression that evaluates this predicate for a MEDS formatted dataset.
"""Returns a Polars expression that evaluates this predicate for a ESGPT formatted dataset.
Note: The output syntax for the following examples is dependent on the polars version used. The
expected outputs have been validated on polars version 0.20.30.
Expand All @@ -101,11 +101,19 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr:
>>> expr = PlainPredicateConfig("event_type//ADMISSION").ESGPT_eval_expr()
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
col("event_type").strict_cast(String).str.split([String(&)]).list.contains([String(ADMISSION)])
>>> expr = PlainPredicateConfig("BP//diastolic//atrial").ESGPT_eval_expr()
>>> print(expr) # doctest: +NORMALIZE_WHITESPACE
[(col("BP")) == (String(diastolic//atrial))]
"""
code_is_in_parts = "//" in self.code

if code_is_in_parts:
measurement_name, code = self.code.split("//")
codes = self.code.split("//")
measurement_name = codes.pop(0)
if len(codes) > 1:
code = "//".join(codes)
else:
code = codes[0]
if measurement_name.lower() == "event_type":
criteria = [pl.col("event_type").cast(pl.Utf8).str.split("&").list.contains(code)]
else:
Expand Down

0 comments on commit a2488c5

Please sign in to comment.