diff --git a/src/aces/config.py b/src/aces/config.py index ba45181..f1af2f9 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -75,7 +75,7 @@ def MEDS_eval_expr(self) -> pl.Expr: return pl.all_horizontal(criteria) def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr: - """Returns a Polars expression that evaluates this predicate for a MEDS formatted dataset. + """Returns a Polars expression that evaluates this predicate for a ESGPT formatted dataset. Note: The output syntax for the following examples is dependent on the polars version used. The expected outputs have been validated on polars version 0.20.30. @@ -101,11 +101,19 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr: >>> expr = PlainPredicateConfig("event_type//ADMISSION").ESGPT_eval_expr() >>> print(expr) # doctest: +NORMALIZE_WHITESPACE col("event_type").strict_cast(String).str.split([String(&)]).list.contains([String(ADMISSION)]) + >>> expr = PlainPredicateConfig("BP//diastolic//atrial").ESGPT_eval_expr() + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + [(col("BP")) == (String(diastolic//atrial))] """ code_is_in_parts = "//" in self.code if code_is_in_parts: - measurement_name, code = self.code.split("//") + codes = self.code.split("//") + measurement_name = codes.pop(0) + if len(codes) > 1: + code = "//".join(codes) + else: + code = codes[0] if measurement_name.lower() == "event_type": criteria = [pl.col("event_type").cast(pl.Utf8).str.split("&").list.contains(code)] else: