Skip to content

Commit

Permalink
Overriding predicates and demographics using separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
justin13601 committed Aug 30, 2024
1 parent a33efb6 commit 17318bb
Showing 1 changed file with 18 additions and 27 deletions.
45 changes: 18 additions & 27 deletions src/aces/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
)

overriding_predicates = {}
overriding_demographics = {}
if predicates_path:
if isinstance(predicates_path, str):
predicates_path = Path(predicates_path)
Expand All @@ -1233,7 +1234,8 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
# in the YAML
_ = predicates_dict.pop("description", None)
_ = predicates_dict.pop("metadata", None)
overriding_predicates = predicates_dict.pop("predicates")
overriding_predicates = predicates_dict.pop("predicates", {})
overriding_demographics = predicates_dict.pop("patient_demographics", {})

if predicates_dict:
raise ValueError(
Expand All @@ -1245,31 +1247,17 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
_ = loaded_dict.pop("description", None)
_ = loaded_dict.pop("metadata", None)

patient_demographics = loaded_dict.pop("patient_demographics", None)
trigger = loaded_dict.pop("trigger")
windows = loaded_dict.pop("windows", None)

predicates = loaded_dict.pop("predicates")
predicates = loaded_dict.pop("predicates", {})
patient_demographics = loaded_dict.pop("patient_demographics", {})

if loaded_dict:
raise ValueError(f"Unrecognized keys in configuration file: '{', '.join(loaded_dict.keys())}'")

predicates_to_override = {}
for key, value in predicates.items():
if value.get("code") == "???" or value.get("expr") == "???":
predicates_to_override[key] = value

if predicates_to_override:
if not set(predicates_to_override.keys()).issubset(set(overriding_predicates.keys())):
raise ValueError(
"Unresolved predicates found in configuration file: "
f"{', '.join(predicates_to_override.keys())}. "
"Please provide a predicates file to resolve these predicates or specify them explicitly "
"in the task configuration file (with a code that is not '???')."
)

# update predicates with overrides only if the predicate is in the list of predicates to override
predicates.update({k: v for k, v in overriding_predicates.items() if k in predicates_to_override})
final_predicates = {**predicates, **overriding_predicates}
final_demographics = {**patient_demographics, **overriding_demographics}

logger.info("Parsing windows...")
if windows is None:
Expand All @@ -1288,15 +1276,17 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
current_predicates = set(referenced_predicates)
special_predicates = {ANY_EVENT_COLUMN, START_OF_RECORD_KEY, END_OF_RECORD_KEY}
for pred in current_predicates - special_predicates:
if pred not in predicates:
if pred not in final_predicates:
raise KeyError(
f"Something referenced predicate {pred} that wasn't defined in the configuration."
)
if "expr" in predicates[pred]:
referenced_predicates.update(DerivedPredicateConfig(**predicates[pred]).input_predicates)
if "expr" in final_predicates[pred]:
referenced_predicates.update(
DerivedPredicateConfig(**final_predicates[pred]).input_predicates
)

logger.info("Parsing predicates...")
predicates_to_parse = {k: v for k, v in predicates.items() if k in referenced_predicates}
predicates_to_parse = {k: v for k, v in final_predicates.items() if k in referenced_predicates}
predicate_objs = {}
for n, p in predicates_to_parse.items():
if "expr" in p:
Expand All @@ -1306,13 +1296,14 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
other_cols = {k: v for k, v in p.items() if k not in config_data}
predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols)

if patient_demographics:
if final_demographics:
logger.info("Parsing patient demographics...")
patient_demographics = {
n: PlainPredicateConfig(**p, static=True) for n, p in patient_demographics.items()
final_demographics = {
n: PlainPredicateConfig(**p, static=True) for n, p in final_demographics.items()
}
predicate_objs.update(patient_demographics)
predicate_objs.update(final_demographics)

print(predicate_objs)
return cls(predicates=predicate_objs, trigger=trigger, windows=windows)

def _initialize_predicates(self):
Expand Down

0 comments on commit 17318bb

Please sign in to comment.