Skip to content

Commit

Permalink
Merge pull request #10 from VectorInstitute/update_scripts
Browse files Browse the repository at this point in the history
Update scripts
  • Loading branch information
amrit110 authored Mar 11, 2024
2 parents 199c612 + 2117c37 commit 54d829e
Show file tree
Hide file tree
Showing 18 changed files with 1,360 additions and 1,258 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: check-toml

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.2.2'
rev: 'v0.3.1'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
6 changes: 3 additions & 3 deletions data/bigbird_data/DataChecker.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,9 @@
" # Combining and shuffling patient IDs\n",
" finetune_patients = pos_patients_ids + neg_patients_ids\n",
" random.shuffle(finetune_patients)\n",
" patient_ids_dict[\"valid\"][\"few_shot\"][\n",
" f\"{each_finetune_size}_patients\"\n",
" ] = finetune_patients\n",
" patient_ids_dict[\"valid\"][\"few_shot\"][f\"{each_finetune_size}_patients\"] = (\n",
" finetune_patients\n",
" )\n",
"\n",
" # Performing stratified k-fold split\n",
" skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=23)\n",
Expand Down
114 changes: 58 additions & 56 deletions data/collect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Collect data from the FHIR database and save to csv files."""

import json
import os
from typing import List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -29,7 +32,12 @@ def __init__(
self.save_dir = save_dir
self.buffer_size = buffer_size

self.vocab_dir = os.path.join(self.save_dir, "vocab")
self.csv_dir = os.path.join(self.save_dir, "csv_files")

os.makedirs(self.save_dir, exist_ok=True)
os.makedirs(self.vocab_dir, exist_ok=True)
os.makedirs(self.csv_dir, exist_ok=True)

def get_patient_data(self) -> None:
"""Get patient data from the database and save to a csv file."""
Expand All @@ -47,7 +55,7 @@ def get_patient_data(self) -> None:
"deceasedBoolean",
"deceasedDateTime",
]
save_path = os.path.join(self.save_dir, "patients.csv")
save_path = os.path.join(self.csv_dir, "patients.csv")
buffer = []

with self.engine.connect() as connection:
Expand Down Expand Up @@ -88,7 +96,7 @@ def get_patient_data(self) -> None:
def get_encounter_data(self) -> None:
"""Get encounter data from the database and save to a csv file."""
try:
patients = pd.read_csv(self.save_dir + "/patients.csv")
patients = pd.read_csv(self.csv_dir + "/patients.csv")
except FileNotFoundError:
print("Patients file not found. Please run get_patient_data() first.")
return
Expand All @@ -101,7 +109,7 @@ def get_encounter_data(self) -> None:
)

encounter_cols = ["patient_id", "length", "encounter_ids", "starts", "ends"]
save_path = os.path.join(self.save_dir, "encounters.csv")
save_path = os.path.join(self.csv_dir, "encounters.csv")
buffer = []
outpatient_ids = []

Expand Down Expand Up @@ -163,12 +171,12 @@ def get_encounter_data(self) -> None:
)

patients = patients[~patients["patient_id"].isin(outpatient_ids)]
patients.to_csv(self.save_dir + "/inpatient.csv", index=False)
patients.to_csv(self.csv_dir + "/inpatient.csv", index=False)

def get_procedure_data(self) -> None:
"""Get procedure data from the database and save to a csv file."""
try:
patients = pd.read_csv(self.save_dir + "/inpatient.csv")
patients = pd.read_csv(self.csv_dir + "/inpatient.csv")
except FileNotFoundError:
print("Patients file not found. Please run get_encounter_data() first.")
return
Expand All @@ -187,7 +195,7 @@ def get_procedure_data(self) -> None:
"proc_dates",
"encounter_ids",
]
save_path = os.path.join(self.save_dir, "procedures.csv")
save_path = os.path.join(self.csv_dir, "procedures.csv")
procedure_vocab = set()
buffer = []

Expand Down Expand Up @@ -251,13 +259,13 @@ def get_procedure_data(self) -> None:
index=False,
)

with open(self.save_dir + "/procedure_vocab.json", "w") as f:
with open(self.vocab_dir + "/procedure_vocab.json", "w") as f:
json.dump(list(procedure_vocab), f)

def get_medication_data(self) -> None:
"""Get medication data from the database and save to a csv file."""
try:
patients = pd.read_csv(self.save_dir + "/inpatient.csv")
patients = pd.read_csv(self.csv_dir + "/inpatient.csv")
except FileNotFoundError:
print("Patients file not found. Please run get_encounter_data() first.")
return
Expand All @@ -282,7 +290,7 @@ def get_medication_data(self) -> None:
"med_dates",
"encounter_ids",
]
save_path = os.path.join(self.save_dir, "med_requests.csv")
save_path = os.path.join(self.csv_dir, "med_requests.csv")
med_vocab = set()
buffer = []

Expand All @@ -303,36 +311,27 @@ def get_medication_data(self) -> None:
data = row[0]
if "medicationCodeableConcept" in data:
# includes messy text data, so we skip it
# should not be of type list
continue
# if isinstance(data['medicationCodeableConcept'], list):
# data['medicationCodeableConcept'] = \
# data['medicationCodeableConcept'][0]
# med_req = MedicationRequest(data)
# if med_req.authoredOn is None or med_req.encounter is None:
# continue
# med_codes.append(med_req.medicationCodeableConcept.coding[0].code)
# med_dates.append(med_req.authoredOn.isostring)
# encounters.append(med_req.encounter.reference.split('/')[-1])
else:
med_req = MedicationRequest(data)
if med_req.authoredOn is None or med_req.encounter is None:
med_req = MedicationRequest(data)
if med_req.authoredOn is None or med_req.encounter is None:
continue
med_query = select(medication_table.c.fhir).where(
medication_table.c.id
== med_req.medicationReference.reference.split("/")[-1],
)
med_result = connection.execute(med_query).fetchone()
med_result = Medication(med_result[0]) if med_result else None
if med_result is not None:
code = med_result.code.coding[0].code
if not code.isdigit():
continue
med_query = select(medication_table.c.fhir).where(
medication_table.c.id
== med_req.medicationReference.reference.split("/")[-1],
med_vocab.add(code)
med_codes.append(code)
med_dates.append(med_req.authoredOn.isostring)
encounters.append(
med_req.encounter.reference.split("/")[-1],
)
med_result = connection.execute(med_query).fetchone()
med_result = Medication(med_result[0]) if med_result else None
if med_result is not None:
code = med_result.code.coding[0].code
if not code.isdigit():
continue
med_vocab.add(code)
med_codes.append(code)
med_dates.append(med_req.authoredOn.isostring)
encounters.append(
med_req.encounter.reference.split("/")[-1],
)

assert len(med_codes) == len(
med_dates,
Expand Down Expand Up @@ -364,13 +363,13 @@ def get_medication_data(self) -> None:
header=(not os.path.exists(save_path)),
index=False,
)
with open(self.save_dir + "/med_vocab.json", "w") as f:
with open(self.vocab_dir + "/med_vocab.json", "w") as f:
json.dump(list(med_vocab), f)

def get_lab_data(self) -> None:
"""Get lab data from the database and save to a csv file."""
try:
patients = pd.read_csv(self.save_dir + "/inpatient.csv")
patients = pd.read_csv(self.csv_dir + "/inpatient.csv")
except FileNotFoundError:
print("Patients file not found. Please run get_encounter_data() first.")
return
Expand All @@ -391,7 +390,7 @@ def get_lab_data(self) -> None:
"lab_dates",
"encounter_ids",
]
save_path = os.path.join(self.save_dir, "labs.csv")
save_path = os.path.join(self.csv_dir, "labs.csv")
lab_vocab = set()
all_units = {}
buffer = []
Expand Down Expand Up @@ -433,7 +432,7 @@ def get_lab_data(self) -> None:
encounters.append(event.encounter.reference.split("/")[-1])

if code not in all_units:
all_units[code] = set([event.valueQuantity.unit])
all_units[code] = set(event.valueQuantity.unit)
else:
all_units[code].add(event.valueQuantity.unit)

Expand Down Expand Up @@ -471,21 +470,23 @@ def get_lab_data(self) -> None:
index=False,
)

with open(self.save_dir + "/lab_vocab.json", "w") as f:
with open(self.vocab_dir + "/lab_vocab.json", "w") as f:
json.dump(list(lab_vocab), f)

all_units = {k: list(v) for k, v in all_units.items()}
with open(self.save_dir + "/lab_units.json", "w") as f:
with open(self.vocab_dir + "/lab_units.json", "w") as f:
json.dump(all_units, f)

def filter_lab_data(
self,
) -> None:
        """Filter out lab codes that have more than one unit."""
try:
labs = pd.read_csv(self.save_dir + "/labs.csv")
lab_vocab = json.load(open(self.save_dir + "/lab_vocab.json", "r"))
lab_units = json.load(open(self.save_dir + "/lab_units.json", "r"))
labs = pd.read_csv(self.csv_dir + "/labs.csv")
with open(self.vocab_dir + "/lab_vocab.json", "r") as f:
lab_vocab = json.load(f)
with open(self.vocab_dir + "/lab_units.json", "r") as f:
lab_units = json.load(f)
except FileNotFoundError:
print("Labs file not found. Please run get_lab_data() first.")
return
Expand All @@ -494,7 +495,7 @@ def filter_lab_data(
if len(units) > 1:
lab_vocab.remove(code)

def filter_codes(row, vocab):
def filter_codes(row: pd.Series, vocab: List[str]) -> pd.Series:
for col in [
"lab_codes",
"lab_values",
Expand All @@ -519,11 +520,11 @@ def filter_codes(row, vocab):

labs = labs.apply(lambda x: filter_codes(x, lab_vocab), axis=1)

labs.to_csv(self.save_dir + "/filtered_labs.csv", index=False)
with open(self.save_dir + "/lab_vocab.json", "w") as f:
labs.to_csv(self.csv_dir + "/filtered_labs.csv", index=False)
with open(self.vocab_dir + "/lab_vocab.json", "w") as f:
json.dump(list(lab_vocab), f)

def process_lab_values(self, num_bins: int = 5):
def process_lab_values(self, num_bins: int = 5) -> None:
"""Bin lab values into discrete values.
Parameters
Expand All @@ -532,18 +533,19 @@ def process_lab_values(self, num_bins: int = 5):
number of bins, by default 5
"""
try:
labs = pd.read_csv(self.save_dir + "/filtered_labs.csv")
lab_vocab = json.load(open(self.save_dir + "/lab_vocab.json", "r"))
labs = pd.read_csv(self.csv_dir + "/filtered_labs.csv")
with open(self.vocab_dir + "/lab_vocab.json", "r") as f:
lab_vocab = json.load(f)
except FileNotFoundError:
print("Labs file not found. Please run get_lab_data() first.")
return

def apply_eval(row):
def apply_eval(row: pd.Series) -> pd.Series:
for col in ["lab_codes", "lab_values"]:
row[col] = eval(row[col])
return row

def assign_to_quantile_bins(row):
def assign_to_quantile_bins(row: pd.Series) -> pd.Series:
if row["length"] == 0:
row["binned_values"] = []
return row
Expand Down Expand Up @@ -571,21 +573,21 @@ def assign_to_quantile_bins(row):
).categories

labs = labs.apply(assign_to_quantile_bins, axis=1)
labs.to_csv(self.save_dir + "/processed_labs.csv", index=False)
labs.to_csv(self.csv_dir + "/processed_labs.csv", index=False)

lab_vocab_binned = []
lab_vocab_binned.extend(
[f"{code}_{i}" for code in lab_vocab for i in range(num_bins)],
)
with open(self.save_dir + "/lab_vocab.json", "w") as f:
with open(self.vocab_dir + "/lab_vocab.json", "w") as f:
json.dump(lab_vocab_binned, f)


if __name__ == "__main__":
collector = FHIRDataCollector(
db_path="postgresql://postgres:pwd@localhost:5432/mimiciv-2.0",
schema="mimic_fhir",
save_dir="/mnt/data/odyssey/mimiciv_fhir2",
save_dir="/mnt/data/odyssey/mimiciv_fhir",
buffer_size=10000,
)
collector.get_patient_data()
Expand Down
Loading

0 comments on commit 54d829e

Please sign in to comment.