Refactor data/collect.py: route every buffered CSV append through a new
module-level helper, _save_to_csv, instead of repeating the DataFrame
construction and to_csv call inside each collector method.

Each full buffer is written BEFORE it is reset to an empty list, so no batch
is dropped. The per-code lab unit set is seeded with the unit value itself
rather than with set(unit).

NOTE(review): indentation and blank context lines were reconstructed from the
hunk line counts; confirm them against data/collect.py before applying.

diff --git a/data/collect.py b/data/collect.py
index 25afd02..b7df0ce 100644
--- a/data/collect.py
+++ b/data/collect.py
@@ -2,7 +2,7 @@
 
 import json
 import os
-from typing import List
+from typing import Any, Dict, List
 
 import numpy as np
 import pandas as pd
@@ -16,6 +16,30 @@
 from tqdm import tqdm
 
 
+def _save_to_csv(
+    data: List[Dict[str, Any]], columns: List[str], save_path: str
+) -> None:
+    """Append records to a csv file, adding the header if the file is new.
+
+    Parameters
+    ----------
+    data : List[Dict[str, Any]]
+        The data to save.
+    columns : List[str]
+        The column names of the data.
+    save_path : str
+        The path to save the data.
+
+    """
+    dataframe = pd.DataFrame(data, columns=columns)
+    dataframe.to_csv(
+        save_path,
+        mode="a",
+        header=(not os.path.exists(save_path)),
+        index=False,
+    )
+
+
 class FHIRDataCollector:
     """Collect data from the FHIR database and save to csv files."""
 
@@ -76,22 +100,10 @@ def get_patient_data(self) -> None:
                 }
                 buffer.append(patient_data)
                 if len(buffer) >= self.buffer_size:
-                    df_buffer = pd.DataFrame(buffer, columns=patient_cols)
+                    _save_to_csv(buffer, patient_cols, save_path)
                     buffer = []
-                    df_buffer.to_csv(
-                        save_path,
-                        mode="a",
-                        header=(not os.path.exists(save_path)),
-                        index=False,
-                    )
             if buffer:
-                df_buffer = pd.DataFrame(buffer, columns=patient_cols)
-                df_buffer.to_csv(
-                    save_path,
-                    mode="a",
-                    header=(not os.path.exists(save_path)),
-                    index=False,
-                )
+                _save_to_csv(buffer, patient_cols, save_path)
 
     def get_encounter_data(self) -> None:
         """Get encounter data from the database and save to a csv file."""
@@ -152,23 +164,11 @@ def get_encounter_data(self) -> None:
                 }
                 buffer.append(e_data)
                 if len(buffer) >= self.buffer_size:
-                    df_buffer = pd.DataFrame(buffer, columns=encounter_cols)
+                    _save_to_csv(buffer, encounter_cols, save_path)
                     buffer = []
-                    df_buffer.to_csv(
-                        save_path,
-                        mode="a",
-                        header=(not os.path.exists(save_path)),
-                        index=False,
-                    )
 
             if buffer:
-                df_buffer = pd.DataFrame(buffer, columns=encounter_cols)
-                df_buffer.to_csv(
-                    save_path,
-                    mode="a",
-                    header=(not os.path.exists(save_path)),
-                    index=False,
-                )
+                _save_to_csv(buffer, encounter_cols, save_path)
 
         patients = patients[~patients["patient_id"].isin(outpatient_ids)]
         patients.to_csv(self.csv_dir + "/inpatient.csv", index=False)
@@ -187,7 +187,6 @@ def get_procedure_data(self) -> None:
             autoload_with=self.engine,
             schema=self.schema,
         )
-
         procedure_cols = [
             "patient_id",
             "length",
@@ -208,7 +207,6 @@ def get_procedure_data(self) -> None:
                 query = select(procedure_table.c.fhir).where(
                     procedure_table.c.patient_id == patient_id,
                 )
-
                 results = connection.execute(query).fetchall()
                 proc_codes = []
                 proc_dates = []
@@ -241,23 +239,11 @@ def get_procedure_data(self) -> None:
                 }
                 buffer.append(m_data)
                 if len(buffer) >= self.buffer_size:
-                    df_buffer = pd.DataFrame(buffer, columns=procedure_cols)
+                    _save_to_csv(buffer, procedure_cols, save_path)
                     buffer = []
-                    df_buffer.to_csv(
-                        save_path,
-                        mode="a",
-                        header=(not os.path.exists(save_path)),
-                        index=False,
-                    )
 
             if buffer:
-                df_buffer = pd.DataFrame(buffer, columns=procedure_cols)
-                df_buffer.to_csv(
-                    save_path,
-                    mode="a",
-                    header=(not os.path.exists(save_path)),
-                    index=False,
-                )
+                _save_to_csv(buffer, procedure_cols, save_path)
 
         with open(self.vocab_dir + "/procedure_vocab.json", "w") as f:
             json.dump(list(procedure_vocab), f)
@@ -346,23 +332,12 @@ def get_medication_data(self) -> None:
                 }
                 buffer.append(m_data)
                 if len(buffer) >= self.buffer_size:
-                    df_buffer = pd.DataFrame(buffer, columns=medication_cols)
+                    _save_to_csv(buffer, medication_cols, save_path)
                     buffer = []
-                    df_buffer.to_csv(
-                        save_path,
-                        mode="a",
-                        header=(not os.path.exists(save_path)),
-                        index=False,
-                    )
 
             if buffer:
-                df_buffer = pd.DataFrame(buffer, columns=medication_cols)
-                df_buffer.to_csv(
-                    save_path,
-                    mode="a",
-                    header=(not os.path.exists(save_path)),
-                    index=False,
-                )
+                _save_to_csv(buffer, medication_cols, save_path)
+
         with open(self.vocab_dir + "/med_vocab.json", "w") as f:
             json.dump(list(med_vocab), f)
 
@@ -432,7 +407,7 @@ def get_lab_data(self) -> None:
                     encounters.append(event.encounter.reference.split("/")[-1])
 
                     if code not in all_units:
-                        all_units[code] = set(event.valueQuantity.unit)
+                        all_units[code] = {event.valueQuantity.unit}
                     else:
                         all_units[code].add(event.valueQuantity.unit)
 
@@ -452,23 +427,11 @@ def get_lab_data(self) -> None:
                 }
                 buffer.append(m_data)
                 if len(buffer) >= self.buffer_size:
-                    df_buffer = pd.DataFrame(buffer, columns=lab_cols)
+                    _save_to_csv(buffer, lab_cols, save_path)
                     buffer = []
-                    df_buffer.to_csv(
-                        save_path,
-                        mode="a",
-                        header=(not os.path.exists(save_path)),
-                        index=False,
-                    )
 
             if buffer:
-                df_buffer = pd.DataFrame(buffer, columns=lab_cols)
-                df_buffer.to_csv(
-                    save_path,
-                    mode="a",
-                    header=(not os.path.exists(save_path)),
-                    index=False,
-                )
+                _save_to_csv(buffer, lab_cols, save_path)
 
         with open(self.vocab_dir + "/lab_vocab.json", "w") as f:
             json.dump(list(lab_vocab), f)