Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #308 from Aarhus-Psychiatry-Research/martbern/save…
Browse files Browse the repository at this point in the history
…_eval_dataset_as_parquet

Add exclusion criteria to eval
  • Loading branch information
MartinBernstorff authored Nov 1, 2022
2 parents a95af0f + 1b471ce commit ff5dd16
Show file tree
Hide file tree
Showing 9 changed files with 1,068 additions and 1,042 deletions.
3 changes: 2 additions & 1 deletion src/psycopt2d/config/data/synth_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ data:
outcome_timestamp: timestamp_outcome
id: citizen_ids
age: pred_age
exclusion_timestamp: timestamp_exclusion
custom:
n_hba1c: hba1c_within_9999_days_count_nan

# Looking ahead
drop_patient_if_outcome_before_date: null
drop_patient_if_exclusion_before_date: 1971-01-01

# Looking behind
lookbehind_combination: [30, 60, 100]
Expand Down
5 changes: 3 additions & 2 deletions src/psycopt2d/config/data/t2d_parquet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ data:
suffix: parquet

# Patient exclusion criteria
drop_patient_if_outcome_before_date: 2013-01-01
drop_patient_if_exclusion_before_date: 2013-01-01

# Prediction time exclusion criteria
min_prediction_time_date: 2013-01-01
Expand All @@ -18,9 +18,10 @@ data:

col_name:
pred_timestamp: timestamp
outcome_timestamp: _timestamp_first_t2d
outcome_timestamp: timestamp_first_t2d_hba1c
id: dw_ek_borger
age: pred_age_in_years
exclusion_timestamp: timestamp_exclusion
custom:
n_hba1c: hba1c_within_9999_days_count_fallback_0

Expand Down
3 changes: 2 additions & 1 deletion src/psycopt2d/evaluation_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class EvalDataset(BaseModel):
y: pd.Series
y_hat_probs: pd.Series
y_hat_int: pd.Series
age: Optional[pd.Series]
age: Optional[pd.Series] = None
exclusion_timestamps: Optional[pd.Series] = None
custom: Optional[CustomColumns] = CustomColumns(n_hba1c=None)

def __init__(self, **kwargs):
Expand Down
35 changes: 25 additions & 10 deletions src/psycopt2d/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self,
cfg: FullConfigSchema,
):
self.cfg = cfg
self.cfg: FullConfigSchema = cfg

# File handling
self.dir_path = Path(cfg.data.dir)
Expand Down Expand Up @@ -160,21 +160,24 @@ def _drop_rows_if_datasets_ends_within_days(

return dataset

def drop_patient_if_outcome_before_date(
def _drop_patient_if_excluded(
self,
dataset: pd.DataFrame,
) -> pd.DataFrame:
"""Drop patients within washin period."""
"""Drop patients that have an exclusion event within the washin
period."""

n_rows_before_modification = dataset.shape[0]

outcome_before_date = (
dataset[self.cfg.data.col_name.outcome_timestamp]
< self.cfg.data.drop_patient_if_outcome_before_date
dataset[self.cfg.data.col_name.exclusion_timestamp]
< self.cfg.data.drop_patient_if_exclusion_before_date
)

patients_to_drop = set(dataset["dw_ek_borger"][outcome_before_date].unique())
dataset = dataset[~dataset["dw_ek_borger"].isin(patients_to_drop)]
patients_to_drop = set(
dataset[self.cfg.data.col_name.id][outcome_before_date].unique(),
)
dataset = dataset[~dataset[self.cfg.data.col_name.id].isin(patients_to_drop)]

n_rows_after_modification = dataset.shape[0]

Expand All @@ -187,6 +190,8 @@ def drop_patient_if_outcome_before_date(
msg.info(
f"Dropped {n_rows_before_modification - n_rows_after_modification} ({percent_dropped}%) rows because patients had diabetes in the washin period.",
)
else:
msg.info("No patients met exclusion criteria. Didn't drop any.")

return dataset

Expand Down Expand Up @@ -385,10 +390,18 @@ def _keep_unique_outcome_col_with_lookahead_days_matching_conf(

return df

def n_outcome_col_names(self, df: pd.DataFrame):
def n_outcome_col_names(self, df: pd.DataFrame) -> int:
    """Count the outcome columns that `infer_outcome_col_name` finds in df.

    Multiple matches are allowed, so the result is the length of whatever
    collection the helper returns for this dataframe.
    """
    outcome_col_names = infer_outcome_col_name(df=df, allow_multiple=True)
    return len(outcome_col_names)

def _drop_rows_after_event_time(self, dataset: pd.DataFrame) -> pd.DataFrame:
    """Drop prediction times that do not precede the row's outcome.

    A row is kept when its prediction timestamp is strictly before its
    outcome timestamp, or when it has no outcome timestamp at all (NaT/NaN).
    Rows whose prediction timestamp is at or after the outcome timestamp
    are dropped.

    Args:
        dataset: Dataframe containing the prediction-timestamp and
            outcome-timestamp columns named in ``self.cfg.data.col_name``.

    Returns:
        The filtered dataframe.
    """
    col_names = self.cfg.data.col_name
    outcome_timestamps = dataset[col_names.outcome_timestamp]

    # Any comparison against NaT evaluates to False, so rows without an
    # outcome have to be kept explicitly; otherwise they are all dropped.
    has_no_outcome = outcome_timestamps.isna()
    pred_is_before_outcome = dataset[col_names.pred_timestamp] < outcome_timestamps

    return dataset[has_no_outcome | pred_is_before_outcome]

def _process_dataset(self, dataset: pd.DataFrame) -> pd.DataFrame:
"""Process dataset, namely:
Expand All @@ -402,8 +415,10 @@ def _process_dataset(self, dataset: pd.DataFrame) -> pd.DataFrame:
"""
dataset = self.convert_timestamp_dtype_and_nat(dataset)

if self.cfg.data.drop_patient_if_outcome_before_date:
dataset = self.drop_patient_if_outcome_before_date(dataset=dataset)
dataset = self._drop_rows_after_event_time(dataset=dataset)

if self.cfg.data.drop_patient_if_exclusion_before_date:
dataset = self._drop_patient_if_excluded(dataset=dataset)

# Drop if later than min prediction time date
if self.cfg.data.min_prediction_time_date:
Expand Down
49 changes: 24 additions & 25 deletions src/psycopt2d/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,28 @@ def stratified_cross_validation( # pylint: disable=too-many-locals
return train_df


def create_eval_dataset(cfg: FullConfigSchema, outcome_col_name: str, df: pd.DataFrame):
    """Build an EvalDataset from a dataframe of predictions.

    Column names for ids, timestamps, age and exclusion timestamps are read
    from ``cfg.data.col_name``. Predicted probabilities are taken from the
    "y_hat_prob" column, and their rounded values are passed as ``y_hat_int``.
    If a custom n_hba1c column name is configured, that column is attached to
    ``eval_dataset.custom.n_hba1c``.

    Args:
        cfg: Full config; only ``cfg.data.col_name`` is read here.
        outcome_col_name: Name of the column in ``df`` passed as ``y``.
        df: Dataframe containing the columns named above.

    Returns:
        The populated EvalDataset.
    """
    col_names = cfg.data.col_name

    eval_dataset = EvalDataset(
        ids=df[col_names.id],
        y=df[outcome_col_name],
        y_hat_probs=df["y_hat_prob"],
        y_hat_int=df["y_hat_prob"].round(),
        pred_timestamps=df[col_names.pred_timestamp],
        outcome_timestamps=df[col_names.outcome_timestamp],
        age=df[col_names.age],
        exclusion_timestamps=df[col_names.exclusion_timestamp],
    )

    # Custom columns are optional; return early when none is configured.
    custom_col_names = col_names.custom
    if not custom_col_names or not custom_col_names.n_hba1c:
        return eval_dataset

    eval_dataset.custom.n_hba1c = df[custom_col_names.n_hba1c]
    return eval_dataset


def train_and_eval_on_crossvalidation(
cfg: FullConfigSchema,
train: pd.DataFrame,
Expand Down Expand Up @@ -194,20 +216,7 @@ def train_and_eval_on_crossvalidation(

df.rename(columns={"oof_y_hat": "y_hat_prob"}, inplace=True)

eval_dataset = EvalDataset(
ids=df[cfg.data.col_name.id],
y=df[outcome_col_name],
y_hat_probs=df["y_hat_prob"],
y_hat_int=df["y_hat_prob"].round(),
pred_timestamps=df[cfg.data.col_name.pred_timestamp],
outcome_timestamps=df[cfg.data.col_name.outcome_timestamp],
age=df[cfg.data.col_name.age],
)

if cfg.data.col_name.custom.n_hba1c:
eval_dataset.custom.n_hba1c = df[cfg.data.col_name.custom.n_hba1c]

return eval_dataset
return create_eval_dataset(cfg=cfg, outcome_col_name=outcome_col_name, df=df)


def train_and_eval_on_val_split(
Expand Down Expand Up @@ -249,17 +258,7 @@ def train_and_eval_on_val_split(
df = val
df["y_hat_prob"] = y_val_hat_prob

eval_dataset = EvalDataset(
ids=df[cfg.data.col_name.id],
y=df[outcome_col_name],
y_hat_probs=df["y_hat_prob"],
y_hat_int=df["y_hat_prob"].round(),
pred_timestamps=df[cfg.data.col_name.pred_timestamp],
outcome_timestamps=df[cfg.data.col_name.outcome_timestamp],
age=df[cfg.data.col_name.age],
)

return eval_dataset
return create_eval_dataset(cfg=cfg, outcome_col_name=outcome_col_name, df=df)


def train_and_get_model_eval_df(
Expand Down
4 changes: 3 additions & 1 deletion src/psycopt2d/utils/config_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class ColumnNames(BaseModel):
outcome_timestamp: str # Column name for outcome timestamps
id: str # Citizen colnames
age: str # Name of the age column
exclusion_timestamp: str # Name of the exclusion timestamps column.
# Patients with an exclusion timestamp before drop_patient_if_exclusion_before_date have all their visits dropped.

custom: Optional[CustomColNames] = None
# Column names that are custom to the given prediction problem.
Expand All @@ -101,7 +103,7 @@ class DataConf(BaseModel):
min_lookahead_days: int
# Drop all prediction times where (max timestamp in the dataset) - (current timestamp) is less than min_lookahead_days

drop_patient_if_outcome_before_date: Optional[Union[str, datetime]]
drop_patient_if_exclusion_before_date: Optional[Union[str, datetime]]
# Drop all visits from a patient if the patient has an exclusion timestamp before this date. If None, no patients are dropped.

min_prediction_time_date: Optional[Union[str, datetime]]
Expand Down
Loading

0 comments on commit ff5dd16

Please sign in to comment.