Skip to content

Commit

Permalink
adding processing table functions
Browse files Browse the repository at this point in the history
  • Loading branch information
svittoz committed Jun 20, 2024
1 parent 9cd7f1a commit e9118e2
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 0 deletions.
187 changes: 187 additions & 0 deletions eds_scikit/utils/process_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from datetime import timedelta
from typing import Dict, List, Union

import numpy as np
from loguru import logger

from eds_scikit.utils.checks import check_columns
from eds_scikit.utils.typing import DataFrame


def tag_table_by_type(
table: DataFrame,
type_groups: Union[str, Dict],
source_col: str,
target_col: str,
filter_table: bool = False,
):
"""Add tag column to table based on their value (ex : condition_occurrence -> "DIABETIC", "NOT DIABETIC)
Parameters
----------
table : DataFrame
Table (must contain columns source_col, target_col)
type_groups : Union[str, Dict]
Regex or Dict of regex to define tags and associated regex.
source_col : str
Column on which the tagging is applied.
target_col : str
Label column name
remove_other : bool
If True, remove untagged columns
Returns
-------
DataFrame
Input dataframe with tag column `target_col`
Output
-------
| person_id | condition_source_value | DIABETIC_CONDITION |
|:---------------------------:|-------------------------:|:---------------------:|
| 001 | E100 | DIABETES_TYPE_I |
| 002 | E101 | DIABETES_TYPE_I |
| 003 | E110 | DIABETES_TYPE_II |
| 004 | E113 | DIABETES_TYPE_II |
| 005 | A001 | OTHER |
"""
if isinstance(type_groups, str):
type_groups = {type_groups: type_groups}
table[target_col] = "OTHER"

for type_name, type_value in type_groups.items():

table.loc[
table[source_col]
.astype(str)
.str.contains(
type_value,
case=False,
regex=True,
na=False,
),
target_col,
] = type_name

logger.debug(
"The following {} : {} have been tagged on table.",
target_col,
type_groups,
)

table = table[table[target_col] != "OTHER"] if filter_table else table

return table


def tag_table_period_length(
table: DataFrame,
length_of_stays: List[float],
start_date_col: str = "visit_start_datetime",
end_date_col: str = "visit_end_datetime",
) -> DataFrame:
"""Tag table by length of stays (can be applied to visit_occurrence table)
Example : length_of_stays = [7, 14]
Output
-------
| person_id | visit_start_datetime | visit_end_datetime | length_of_stay |
|:---------------------------:|-------------------------:|:---------------------:|:---------------------:|
| 001 | 2020-04-01 | 2020-04-12 | "7 days - 14 days" |
| 002 | 2020-04-01 | 2020-04-03 | "<= 7 days " |
| 003 | 2020-04-01 | 2020-04-09 | ">= 7 days " |
Parameters
----------
table : DataFrame
length_of_stays : List[float]
Example : [7 , 14]
start_date_col : str, optional
by default "visit_start_datetime"
end_date_col : str, optional
by default "visit_end_datetime"
Returns
-------
DataFrame
"""
table = table.assign(
length=(table[end_date_col] - table[start_date_col])
/ np.timedelta64(timedelta(days=1))
)

# Incomplete stays
table = table.assign(length_of_stay="Not specified")
table["length_of_stay"] = table.length_of_stay.mask(
table[end_date_col].isna(),
"Incomplete stay",
)

# Complete stays
min_duration = length_of_stays[0]
max_duration = length_of_stays[-1]
table["length_of_stay"] = table["length_of_stay"].mask(
(table["length"] <= min_duration),
"<= {} days".format(min_duration),
)
table["length_of_stay"] = table["length_of_stay"].mask(
(table["length"] >= max_duration),
">= {} days".format(max_duration),
)
for min_length, max_length in zip(length_of_stays[:-1], length_of_stays[1:]):
table["length_of_stay"] = table["length_of_stay"].mask(
(table["length"] >= min_length) & (table["length"] < max_length),
"{} days - {} days".format(min_length, max_length),
)
table = table.drop(columns="length")

return table


def tag_table_with_age(
table: DataFrame, date_col: str, person: DataFrame, age_ranges: List[int] = None
):
"""Tag table with person age
Parameters
----------
table : DataFrame
must contain person_id and date_col
date_column: str
date column from table on which to compute age
person : DataFrame
must contain person_id
age_ranges : List[int]
if None, simply compute age.
example : None, [18], [18, 60]
Returns
-------
DataFrame
"""
check_columns(df=person, required_columns=["person_id", "birth_datetime"])
check_columns(df=table, required_columns=[date_col, "person_id"])

table = table.merge(person[["person_id", "birth_datetime"]], on="person_id")

table["age"] = (table[date_col] - table["birth_datetime"]) / (
np.timedelta64(timedelta(days=1)) * 356
)
table["age"] = table["age"].astype(int)

table["age_range"] = "Not specified"
if age_ranges:
age_ranges.sort()
table.loc[table.age <= age_ranges[0], "age_range"] = f"age <= {age_ranges[0]}"

for age_min, age_max in zip(age_ranges[:-1], age_ranges[1:]):
in_range = (table.age > age_min) & (table.age <= age_max)
table.loc[in_range, "age_range"] = f"{age_min} < age <= {age_max}"

table.loc[table.age > age_ranges[-1], "age_range"] = f"age > {age_ranges[-1]}"

return table
63 changes: 63 additions & 0 deletions tests/test_process_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pandas as pd
import numpy as np
import pytest

from eds_scikit.utils import framework
from eds_scikit.utils.process_table import tag_table_by_type, tag_table_period_length, tag_table_with_age

# Generate random data for the first dataframe
num_rows = 1000
table = {
"condition_source_value": ["E100", "E101", "E110", "A001", "B002"],
"visit_start_datetime": ["2021-05-16", "2018-08-16", "2023-03-14", "2023-05-09", "2022-07-17"],
"visit_end_datetime": ["2021-05-26", "2018-09-16", "2023-03-15", "2023-10-10", "2022-07-18"],
"person_id": [0, 1, 2, 3, 4]
}

table = pd.DataFrame(table)
table["visit_start_datetime"] = pd.to_datetime(table["visit_start_datetime"] )
table["visit_end_datetime"] = pd.to_datetime(table["visit_end_datetime"] )

# Generate random data for the second dataframe
person = {
"person_id": [0, 1, 2, 3, 4],
"birth_datetime": ["2000-03-29", "1990-04-08", "1975-09-28", "1970-04-28", "1975-10-03"]
}
person["birth_datetime"] = pd.to_datetime(person["birth_datetime"] )

person = pd.DataFrame(person)

@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_tag_table_with_age(module):

person_fr = framework.to(module, person)
table_fr = framework.to(module, table)

table_with_age = tag_table_with_age(table_fr, "visit_start_datetime", person_fr, age_ranges=[24, 30, 40])
table_with_age = framework.to("pandas", table_with_age)
assert (table_with_age["age_range"] == pd.Series(["age <= 24", "24 < age <= 30", "age > 40", "age > 40", "age > 40"], name="age_range")).all()

table_with_age = tag_table_with_age(table_fr, "visit_start_datetime", person_fr, age_ranges=None)
table_with_age = framework.to("pandas", table_with_age)
assert (table_with_age["age"] == pd.Series([21, 29, 48, 54, 48], name="age")).all()

@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_table_by_type(module):

table_fr = framework.to(module, table)

table_by_type = tag_table_by_type(table_fr, type_groups={"DIABETES_TYPE_I" : r"^E10", "DIABETES_TYPE_II" : r"^E11"}, source_col="condition_source_value", target_col="tag")
table_by_type = framework.to("pandas", table_by_type)
assert (table_by_type["tag"] == pd.Series(["DIABETES_TYPE_I", "DIABETES_TYPE_I", "DIABETES_TYPE_II", "OTHER", "OTHER"], name="tag")).all()
table_by_type = tag_table_by_type(table_fr, type_groups={"DIABETES_TYPE_I" : r"^E10", "DIABETES_TYPE_II" : r"^E11"}, source_col="condition_source_value", target_col="tag", filter_table=True)
table_by_type = framework.to("pandas", table_by_type)
assert (table_by_type["tag"] == pd.Series(["DIABETES_TYPE_I", "DIABETES_TYPE_I", "DIABETES_TYPE_II"], name="tag")).all()

@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_tag_table_period_length(module):

table_fr = framework.to(module, table)

table_period_length = tag_table_period_length(table_fr, length_of_stays=[7, 14])
table_period_length = framework.to("pandas", table_period_length)
assert (table_period_length["length_of_stay"] == pd.Series(["7 days - 14 days", ">= 14 days", "<= 7 days", ">= 14 days", "<= 7 days"], name="tag")).all()

0 comments on commit e9118e2

Please sign in to comment.