Skip to content
This repository has been archived by the owner on May 3, 2023. It is now read-only.

Commit

Permalink
Merge pull request #119 from Aarhus-Psychiatry-Research/martbern/supp…
Browse files Browse the repository at this point in the history
…ort_filtering_by_quarantine

Support filtering by quarantine dates
  • Loading branch information
MartinBernstorff authored Dec 15, 2022
2 parents 7674e4a + 4c8538e commit 73237db
Show file tree
Hide file tree
Showing 12 changed files with 329 additions and 49 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pyarrow = ">=9.0.0,<9.1.0"
psycopmlutils = ">=0.2.4, <0.3.0"
protobuf = "<=3.20.3" # Other versions give errors with pytest, super weird!
frozendict = "^2.3.4"
timeseriesflattener = "0.20.2"
timeseriesflattener = "0.21.0"

[tool.poetry.dev-dependencies]
pytest = ">=7.1.3, <7.1.5"
Expand Down
83 changes: 46 additions & 37 deletions src/application/t2d/modules/specify_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import logging

import numpy as np

from psycop_feature_generation.application_modules.project_setup import ProjectInfo
from timeseriesflattener.feature_spec_objects import (
BaseModel,
OutcomeGroupSpec,
Expand All @@ -14,6 +12,8 @@
_AnySpec,
)

from psycop_feature_generation.application_modules.project_setup import ProjectInfo

from .loaders.t2d_loaders import ( # noqa pylint: disable=unused-import
timestamp_exclusion,
)
Expand All @@ -31,7 +31,7 @@ class SpecSet(BaseModel):


def get_static_predictor_specs(project_info: ProjectInfo):
"""Get static predictor specs"""
"""Get static predictor specs."""
return [
StaticSpec(
values_loader="sex_female",
Expand All @@ -42,7 +42,7 @@ def get_static_predictor_specs(project_info: ProjectInfo):


def get_metadata_specs(project_info: ProjectInfo) -> list[_AnySpec]:
"""Get metadata specs"""
"""Get metadata specs."""
log.info("–––––––– Generating metadata specs ––––––––")

return [
Expand Down Expand Up @@ -70,7 +70,7 @@ def get_metadata_specs(project_info: ProjectInfo) -> list[_AnySpec]:


def get_outcome_specs(project_info: ProjectInfo):
"""Get outcome specs"""
"""Get outcome specs."""
log.info("–––––––– Generating outcome specs ––––––––")

return OutcomeGroupSpec(
Expand All @@ -84,39 +84,8 @@ def get_outcome_specs(project_info: ProjectInfo):
).create_combinations()


def get_temporal_predictor_specs(project_info: ProjectInfo) -> list[PredictorSpec]:
    """Generate predictor spec list.

    Combines lab result, medication, diagnosis and demographic predictor
    specs, all sharing the same aggregation functions, lookbehind windows
    and NaN tolerance.

    Args:
        project_info (ProjectInfo): Project info; only the predictor prefix
            is read here (`project_info.prefix.predictor`).

    Returns:
        list[PredictorSpec]: Concatenated predictor specs from all groups.
    """
    log.info("–––––––– Generating temporal predictor specs ––––––––")

    # Shared hyperparameters applied to every temporal predictor group below.
    resolve_multiple = ["max", "min", "mean", "latest", "count"]
    interval_days = [30, 90, 180, 365, 730]
    allowed_nan_value_prop = [0]

    lab_results = get_lab_result_specs(
        resolve_multiple, interval_days, allowed_nan_value_prop
    )

    diagnoses = get_diagnoses_specs(
        resolve_multiple, interval_days, allowed_nan_value_prop
    )

    medications = get_medication_specs(
        resolve_multiple, interval_days, allowed_nan_value_prop
    )

    # Demographics only use "latest" — e.g. the most recent weight is the
    # relevant one, not the mean over the lookbehind window.
    demographics = PredictorGroupSpec(
        values_loader=["weight_in_kg", "height_in_cm", "bmi"],
        lookbehind_days=interval_days,
        resolve_multiple_fn=["latest"],
        fallback=[np.nan],
        allowed_nan_value_prop=allowed_nan_value_prop,
        prefix=project_info.prefix.predictor,
    ).create_combinations()

    return lab_results + medications + diagnoses + demographics


def get_medication_specs(resolve_multiple, interval_days, allowed_nan_value_prop):
"""Get medication specs."""
log.info("–––––––– Generating medication specs ––––––––")

psychiatric_medications = PredictorGroupSpec(
Expand Down Expand Up @@ -158,6 +127,7 @@ def get_medication_specs(resolve_multiple, interval_days, allowed_nan_value_prop


def get_diagnoses_specs(resolve_multiple, interval_days, allowed_nan_value_prop):
"""Get diagnoses specs."""
log.info("–––––––– Generating diagnoses specs ––––––––")

lifestyle_diagnoses = PredictorGroupSpec(
Expand Down Expand Up @@ -197,6 +167,7 @@ def get_diagnoses_specs(resolve_multiple, interval_days, allowed_nan_value_prop)


def get_lab_result_specs(resolve_multiple, interval_days, allowed_nan_value_prop):
"""Get lab result specs."""
log.info("–––––––– Generating lab result specs ––––––––")

general_lab_results = PredictorGroupSpec(
Expand Down Expand Up @@ -231,6 +202,44 @@ def get_lab_result_specs(resolve_multiple, interval_days, allowed_nan_value_prop
return general_lab_results + diabetes_lab_results


def get_temporal_predictor_specs(project_info: ProjectInfo) -> list[PredictorSpec]:
    """Build the full list of temporal predictor specs.

    Gathers lab result, diagnosis, medication and demographic predictor
    groups, which all share the same aggregation functions, lookbehind
    windows and NaN tolerance.

    Args:
        project_info (ProjectInfo): Project info; supplies the predictor
            column prefix.

    Returns:
        list[PredictorSpec]: Concatenated specs from every predictor group.
    """
    log.info("–––––––– Generating temporal predictor specs ––––––––")

    # Hyperparameters shared by every temporal predictor group.
    resolve_multiple = ["max", "min", "mean", "latest", "count"]
    interval_days = [30, 90, 180, 365, 730]
    allowed_nan_value_prop = [0]
    shared_args = (resolve_multiple, interval_days, allowed_nan_value_prop)

    lab_results = get_lab_result_specs(*shared_args)
    diagnoses = get_diagnoses_specs(*shared_args)
    medications = get_medication_specs(*shared_args)

    # Demographics only aggregate with "latest": the most recent measurement
    # is the meaningful one for weight/height/BMI.
    demographics = PredictorGroupSpec(
        values_loader=["weight_in_kg", "height_in_cm", "bmi"],
        lookbehind_days=interval_days,
        resolve_multiple_fn=["latest"],
        fallback=[np.nan],
        allowed_nan_value_prop=allowed_nan_value_prop,
        prefix=project_info.prefix.predictor,
    ).create_combinations()

    return lab_results + medications + diagnoses + demographics


def get_feature_specs(project_info: ProjectInfo) -> list[_AnySpec]:
"""Get a spec set."""

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Class for filtering prediction times before they are used for feature
generation."""
from typing import Optional

import pandas as pd


class PredictionTimeFilterer:
    """Filter prediction times before they are used for feature generation."""

    def __init__(
        self,
        prediction_times_df: pd.DataFrame,
        entity_id_col_name: str,
        quarantine_timestamps_df: Optional[pd.DataFrame] = None,
        quarantine_interval_days: Optional[int] = None,
    ):
        """Initialize PredictionTimeFilterer.

        Args:
            prediction_times_df (pd.DataFrame): Prediction times dataframe.
                Should contain entity_id, timestamp and pred_time_uuid columns
                with col_names matching those in project_info.col_names.
            entity_id_col_name (str): Name of the entity id column, used as the
                merge key against the quarantine dataframe.
            quarantine_timestamps_df (pd.DataFrame, optional): A dataframe with a
                "timestamp" column from which to start the quarantine. Any
                prediction times within quarantine_interval_days after such a
                timestamp will be dropped.
            quarantine_interval_days (int, optional): Number of days to quarantine.
        """
        self.prediction_times_df = prediction_times_df
        self.quarantine_df = quarantine_timestamps_df
        self.quarantine_days = quarantine_interval_days
        self.entity_id_col_name = entity_id_col_name

    def _filter_prediction_times_by_quarantine_period(self):
        """Drop prediction times that fall inside any quarantine window.

        A prediction time is dropped if it occurs strictly after, and fewer
        than ``quarantine_days`` days after, any quarantine timestamp for the
        same entity.
        """
        # We need to check if ANY quarantine date hits each prediction time,
        # so create all (prediction time, quarantine date) combinations.
        df = self.prediction_times_df.merge(
            self.quarantine_df,
            on=self.entity_id_col_name,
            how="left",
            suffixes=("_pred", "_quarantine"),
        )

        df["days_since_quarantine"] = (
            df["timestamp_pred"] - df["timestamp_quarantine"]
        ).dt.days

        # Flag combinations hit by a quarantine date. NaN comparisons (entities
        # without quarantine dates) are False, so those rows are kept. Computing
        # the full boolean column (rather than .loc-assigning True) guarantees
        # the column exists and is bool-typed even when nothing matches.
        df["hit_by_quarantine"] = (
            df["days_since_quarantine"] < self.quarantine_days
        ) & (df["days_since_quarantine"] > 0)

        # If any of the combinations for a UUID is hit by a quarantine date,
        # drop all rows for that UUID.
        df["hit_by_quarantine"] = df.groupby("pred_time_uuid")[
            "hit_by_quarantine"
        ].transform("max")

        df = df.loc[~df["hit_by_quarantine"]]

        # Collapse the combinations back to one row per prediction time.
        df = df.drop_duplicates(subset="pred_time_uuid")

        # Drop the helper columns we added.
        df = df.drop(
            columns=[
                "days_since_quarantine",
                "hit_by_quarantine",
                "timestamp_quarantine",
            ],
        )

        # Restore the original timestamp column name.
        df = df.rename(columns={"timestamp_pred": "timestamp"})

        return df

    def filter(self):
        """Run filters based on the provided parameters.

        Returns:
            pd.DataFrame: The (possibly) filtered prediction times.

        Raises:
            ValueError: If only one of quarantine_timestamps_df and
                quarantine_interval_days was provided.
        """
        df = self.prediction_times_df

        if self.quarantine_df is not None or self.quarantine_days is not None:
            # BUGFIX: the original checked all(...) here, which can never be
            # true inside this branch, so providing only one of the two args
            # silently skipped the intended validation. `any` catches the
            # half-specified case.
            if any(v is None for v in (self.quarantine_days, self.quarantine_df)):
                raise ValueError(
                    "If either of quarantine_df and quarantine_days are provided, both must be provided.",
                )

            df = self._filter_prediction_times_by_quarantine_period()

        return df
Original file line number Diff line number Diff line change
@@ -1,15 +1,50 @@
"""Flatten the dataset."""
from typing import Optional

import pandas as pd
import psutil
from timeseriesflattener.feature_cache.cache_to_disk import DiskCache
from timeseriesflattener.feature_spec_objects import _AnySpec
from timeseriesflattener.flattened_dataset import TimeseriesFlattener

from psycop_feature_generation.application_modules.filter_prediction_times import (
PredictionTimeFilterer,
)
from psycop_feature_generation.application_modules.project_setup import ProjectInfo
from psycop_feature_generation.application_modules.utils import print_df_dimensions_diff
from psycop_feature_generation.application_modules.wandb_utils import (
wandb_alert_on_exception,
)
from psycop_feature_generation.loaders.raw.load_demographic import birthdays
from timeseriesflattener.feature_cache.cache_to_disk import DiskCache
from timeseriesflattener.feature_spec_objects import _AnySpec
from timeseriesflattener.flattened_dataset import TimeseriesFlattener


@print_df_dimensions_diff
def filter_prediction_times(
    prediction_times_df: pd.DataFrame,
    project_info: ProjectInfo,
    quarantine_df: Optional[pd.DataFrame] = None,
    quarantine_days: Optional[int] = None,
) -> pd.DataFrame:
    """Filter prediction times through a PredictionTimeFilterer.

    Args:
        prediction_times_df (pd.DataFrame): Prediction times dataframe.
            Should contain entity_id and timestamp columns with col_names matching those in project_info.col_names.
        project_info (ProjectInfo): Project info; supplies the entity id column name.
        quarantine_df (pd.DataFrame, optional): Quarantine dataframe with "timestamp" and "project_info.col_names.id" columns.
        quarantine_days (int, optional): Number of days to quarantine.

    Returns:
        pd.DataFrame: Filtered prediction times dataframe.
    """
    return PredictionTimeFilterer(
        prediction_times_df=prediction_times_df,
        entity_id_col_name=project_info.col_names.id,
        quarantine_timestamps_df=quarantine_df,
        quarantine_interval_days=quarantine_days,
    ).filter()


@wandb_alert_on_exception
Expand All @@ -18,6 +53,8 @@ def create_flattened_dataset(
prediction_times_df: pd.DataFrame,
drop_pred_times_with_insufficient_look_distance: bool,
project_info: ProjectInfo,
quarantine_df: Optional[pd.DataFrame] = None,
quarantine_days: Optional[int] = None,
) -> pd.DataFrame:
"""Create flattened dataset.
Expand All @@ -28,13 +65,21 @@ def create_flattened_dataset(
Should contain entity_id and timestamp columns with col_names matching those in project_info.col_names.
drop_pred_times_with_insufficient_look_distance (bool): Whether to drop prediction times with insufficient look distance.
See timeseriesflattener tutorial for more info.
quarantine_df (pd.DataFrame, optional): Quarantine dataframe with "timestamp" and "project_info.col_names.id" columns.
quarantine_days (int, optional): Number of days to quarantine.
Returns:
FlattenedDataset: Flattened dataset.
"""
filtered_prediction_times_df = filter_prediction_times(
prediction_times_df=prediction_times_df,
project_info=project_info,
quarantine_df=quarantine_df,
quarantine_days=quarantine_days,
)

flattened_dataset = TimeseriesFlattener(
prediction_times_df=prediction_times_df,
prediction_times_df=filtered_prediction_times_df,
n_workers=min(
len(feature_specs),
psutil.cpu_count(logical=True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from typing import Literal

import wandb

from psycop_feature_generation.utils import RELATIVE_PROJECT_ROOT, SHARED_RESOURCES_PATH
from timeseriesflattener.feature_spec_objects import ( # pylint: disable=no-name-in-module
BaseModel,
)

from psycop_feature_generation.utils import RELATIVE_PROJECT_ROOT, SHARED_RESOURCES_PATH

log = logging.getLogger(__name__)


Expand Down
Loading

0 comments on commit 73237db

Please sign in to comment.