diff --git a/.gitignore b/.gitignore index 1ffd545..4572735 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,6 @@ dist/ downloads/ eggs/ .eggs/ -lib/ lib64/ parts/ sdist/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3939878..2438da2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: check-toml - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.2.2' + rev: 'v0.3.1' hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/data/bigbird_data/DataChecker.ipynb b/data/bigbird_data/DataChecker.ipynb index 15ac403..f0e8cc0 100644 --- a/data/bigbird_data/DataChecker.ipynb +++ b/data/bigbird_data/DataChecker.ipynb @@ -554,9 +554,9 @@ " # Combining and shuffling patient IDs\n", " finetune_patients = pos_patients_ids + neg_patients_ids\n", " random.shuffle(finetune_patients)\n", - " patient_ids_dict[\"valid\"][\"few_shot\"][\n", - " f\"{each_finetune_size}_patients\"\n", - " ] = finetune_patients\n", + " patient_ids_dict[\"valid\"][\"few_shot\"][f\"{each_finetune_size}_patients\"] = (\n", + " finetune_patients\n", + " )\n", "\n", " # Performing stratified k-fold split\n", " skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=23)\n", diff --git a/data/collect.py b/data/collect.py index e12f994..25afd02 100644 --- a/data/collect.py +++ b/data/collect.py @@ -1,5 +1,8 @@ +"""Collect data from the FHIR database and save to csv files.""" + import json import os +from typing import List import numpy as np import pandas as pd @@ -29,7 +32,12 @@ def __init__( self.save_dir = save_dir self.buffer_size = buffer_size + self.vocab_dir = os.path.join(self.save_dir, "vocab") + self.csv_dir = os.path.join(self.save_dir, "csv_files") + os.makedirs(self.save_dir, exist_ok=True) + os.makedirs(self.vocab_dir, exist_ok=True) + os.makedirs(self.csv_dir, exist_ok=True) def get_patient_data(self) -> None: """Get patient data from the database and save to a csv file.""" @@ -47,7 +55,7 @@ def get_patient_data(self) -> None: "deceasedBoolean", "deceasedDateTime", ] - save_path = os.path.join(self.save_dir, "patients.csv") + save_path = os.path.join(self.csv_dir, "patients.csv") buffer = [] with self.engine.connect() as connection: @@ -88,7 +96,7 @@ def get_patient_data(self) -> None: def get_encounter_data(self) -> None: """Get encounter data from the database and save to a csv file.""" try: - patients = pd.read_csv(self.save_dir + "/patients.csv") + patients = pd.read_csv(self.csv_dir + "/patients.csv") except FileNotFoundError: print("Patients file not found. Please run get_patient_data() first.") return @@ -101,7 +109,7 @@ def get_encounter_data(self) -> None: ) encounter_cols = ["patient_id", "length", "encounter_ids", "starts", "ends"] - save_path = os.path.join(self.save_dir, "encounters.csv") + save_path = os.path.join(self.csv_dir, "encounters.csv") buffer = [] outpatient_ids = [] @@ -163,12 +171,12 @@ def get_encounter_data(self) -> None: ) patients = patients[~patients["patient_id"].isin(outpatient_ids)] - patients.to_csv(self.save_dir + "/inpatient.csv", index=False) + patients.to_csv(self.csv_dir + "/inpatient.csv", index=False) def get_procedure_data(self) -> None: """Get procedure data from the database and save to a csv file.""" try: - patients = pd.read_csv(self.save_dir + "/inpatient.csv") + patients = pd.read_csv(self.csv_dir + "/inpatient.csv") except FileNotFoundError: print("Patients file not found. Please run get_encounter_data() first.") return @@ -187,7 +195,7 @@ def get_procedure_data(self) -> None: "proc_dates", "encounter_ids", ] - save_path = os.path.join(self.save_dir, "procedures.csv") + save_path = os.path.join(self.csv_dir, "procedures.csv") procedure_vocab = set() buffer = [] @@ -251,13 +259,13 @@ def get_procedure_data(self) -> None: index=False, ) - with open(self.save_dir + "/procedure_vocab.json", "w") as f: + with open(self.vocab_dir + "/procedure_vocab.json", "w") as f: json.dump(list(procedure_vocab), f) def get_medication_data(self) -> None: """Get medication data from the database and save to a csv file.""" try: - patients = pd.read_csv(self.save_dir + "/inpatient.csv") + patients = pd.read_csv(self.csv_dir + "/inpatient.csv") except FileNotFoundError: print("Patients file not found. Please run get_encounter_data() first.") return @@ -282,7 +290,7 @@ def get_medication_data(self) -> None: "med_dates", "encounter_ids", ] - save_path = os.path.join(self.save_dir, "med_requests.csv") + save_path = os.path.join(self.csv_dir, "med_requests.csv") med_vocab = set() buffer = [] @@ -303,36 +311,27 @@ def get_medication_data(self) -> None: data = row[0] if "medicationCodeableConcept" in data: # includes messy text data, so we skip it + # should not be of type list continue - # if isinstance(data['medicationCodeableConcept'], list): - # data['medicationCodeableConcept'] = \ - # data['medicationCodeableConcept'][0] - # med_req = MedicationRequest(data) - # if med_req.authoredOn is None or med_req.encounter is None: - # continue - # med_codes.append(med_req.medicationCodeableConcept.coding[0].code) - # med_dates.append(med_req.authoredOn.isostring) - # encounters.append(med_req.encounter.reference.split('/')[-1]) - else: - med_req = MedicationRequest(data) - if med_req.authoredOn is None or med_req.encounter is None: + med_req = MedicationRequest(data) + if med_req.authoredOn is None or med_req.encounter is None: + continue + med_query = select(medication_table.c.fhir).where( + medication_table.c.id + == med_req.medicationReference.reference.split("/")[-1], + ) + med_result = connection.execute(med_query).fetchone() + med_result = Medication(med_result[0]) if med_result else None + if med_result is not None: + code = med_result.code.coding[0].code + if not code.isdigit(): continue - med_query = select(medication_table.c.fhir).where( - medication_table.c.id - == med_req.medicationReference.reference.split("/")[-1], + med_vocab.add(code) + med_codes.append(code) + med_dates.append(med_req.authoredOn.isostring) + encounters.append( + med_req.encounter.reference.split("/")[-1], ) - med_result = connection.execute(med_query).fetchone() - med_result = Medication(med_result[0]) if med_result else None - if med_result is not None: - code = med_result.code.coding[0].code - if not code.isdigit(): - continue - med_vocab.add(code) - med_codes.append(code) - med_dates.append(med_req.authoredOn.isostring) - encounters.append( - med_req.encounter.reference.split("/")[-1], - ) assert len(med_codes) == len( med_dates, @@ -364,13 +363,13 @@ def get_medication_data(self) -> None: header=(not os.path.exists(save_path)), index=False, ) - with open(self.save_dir + "/med_vocab.json", "w") as f: + with open(self.vocab_dir + "/med_vocab.json", "w") as f: json.dump(list(med_vocab), f) def get_lab_data(self) -> None: """Get lab data from the database and save to a csv file.""" try: - patients = pd.read_csv(self.save_dir + "/inpatient.csv") + patients = pd.read_csv(self.csv_dir + "/inpatient.csv") except FileNotFoundError: print("Patients file not found. Please run get_encounter_data() first.") return @@ -391,7 +390,7 @@ def get_lab_data(self) -> None: "lab_dates", "encounter_ids", ] - save_path = os.path.join(self.save_dir, "labs.csv") + save_path = os.path.join(self.csv_dir, "labs.csv") lab_vocab = set() all_units = {} buffer = [] @@ -433,7 +432,7 @@ def get_lab_data(self) -> None: encounters.append(event.encounter.reference.split("/")[-1]) if code not in all_units: - all_units[code] = set([event.valueQuantity.unit]) + all_units[code] = set(event.valueQuantity.unit) else: all_units[code].add(event.valueQuantity.unit) @@ -471,11 +470,11 @@ def get_lab_data(self) -> None: index=False, ) - with open(self.save_dir + "/lab_vocab.json", "w") as f: + with open(self.vocab_dir + "/lab_vocab.json", "w") as f: json.dump(list(lab_vocab), f) all_units = {k: list(v) for k, v in all_units.items()} - with open(self.save_dir + "/lab_units.json", "w") as f: + with open(self.vocab_dir + "/lab_units.json", "w") as f: json.dump(all_units, f) def filter_lab_data( @@ -483,9 +482,11 @@ def filter_lab_data( ) -> None: """Filter out lab codes that have more than one units.""" try: - labs = pd.read_csv(self.save_dir + "/labs.csv") - lab_vocab = json.load(open(self.save_dir + "/lab_vocab.json", "r")) - lab_units = json.load(open(self.save_dir + "/lab_units.json", "r")) + labs = pd.read_csv(self.csv_dir + "/labs.csv") + with open(self.vocab_dir + "/lab_vocab.json", "r") as f: + lab_vocab = json.load(f) + with open(self.vocab_dir + "/lab_units.json", "r") as f: + lab_units = json.load(f) except FileNotFoundError: print("Labs file not found. Please run get_lab_data() first.") return @@ -494,7 +495,7 @@ def filter_lab_data( if len(units) > 1: lab_vocab.remove(code) - def filter_codes(row, vocab): + def filter_codes(row: pd.Series, vocab: List[str]) -> pd.Series: for col in [ "lab_codes", "lab_values", @@ -519,11 +520,11 @@ def filter_codes(row, vocab): labs = labs.apply(lambda x: filter_codes(x, lab_vocab), axis=1) - labs.to_csv(self.save_dir + "/filtered_labs.csv", index=False) - with open(self.save_dir + "/lab_vocab.json", "w") as f: + labs.to_csv(self.csv_dir + "/filtered_labs.csv", index=False) + with open(self.vocab_dir + "/lab_vocab.json", "w") as f: json.dump(list(lab_vocab), f) - def process_lab_values(self, num_bins: int = 5): + def process_lab_values(self, num_bins: int = 5) -> None: """Bin lab values into discrete values. Parameters @@ -532,18 +533,19 @@ def process_lab_values(self, num_bins: int = 5): number of bins, by default 5 """ try: - labs = pd.read_csv(self.save_dir + "/filtered_labs.csv") - lab_vocab = json.load(open(self.save_dir + "/lab_vocab.json", "r")) + labs = pd.read_csv(self.csv_dir + "/filtered_labs.csv") + with open(self.vocab_dir + "/lab_vocab.json", "r") as f: + lab_vocab = json.load(f) except FileNotFoundError: print("Labs file not found. Please run get_lab_data() first.") return - def apply_eval(row): + def apply_eval(row: pd.Series) -> pd.Series: for col in ["lab_codes", "lab_values"]: row[col] = eval(row[col]) return row - def assign_to_quantile_bins(row): + def assign_to_quantile_bins(row: pd.Series) -> pd.Series: if row["length"] == 0: row["binned_values"] = [] return row @@ -571,13 +573,13 @@ def assign_to_quantile_bins(row): ).categories labs = labs.apply(assign_to_quantile_bins, axis=1) - labs.to_csv(self.save_dir + "/processed_labs.csv", index=False) + labs.to_csv(self.csv_dir + "/processed_labs.csv", index=False) lab_vocab_binned = [] lab_vocab_binned.extend( [f"{code}_{i}" for code in lab_vocab for i in range(num_bins)], ) - with open(self.save_dir + "/lab_vocab.json", "w") as f: + with open(self.vocab_dir + "/lab_vocab.json", "w") as f: json.dump(lab_vocab_binned, f) @@ -585,7 +587,7 @@ def assign_to_quantile_bins(row): collector = FHIRDataCollector( db_path="postgresql://postgres:pwd@localhost:5432/mimiciv-2.0", schema="mimic_fhir", - save_dir="/mnt/data/odyssey/mimiciv_fhir2", + save_dir="/mnt/data/odyssey/mimiciv_fhir", buffer_size=10000, ) collector.get_patient_data() diff --git a/data/sequence.py b/data/sequence.py index 684e6ef..cec49bd 100644 --- a/data/sequence.py +++ b/data/sequence.py @@ -1,5 +1,8 @@ +"""Create patient sequences from the events dataframes.""" + import os from datetime import datetime +from typing import Dict, List, Union import numpy as np import pandas as pd @@ -7,16 +10,17 @@ class SequenceGenerator: - """Generates patient sequences from the events dataframes.s""" + """Generate patient sequences from the events dataframes.""" def __init__( self, - max_seq_length: int = 2048, + max_seq_length: int, pad_token: str = "[PAD]", mask_token: str = "[MASK]", start_token: str = "[VS]", end_token: str = "[VE]", class_token: str = "[CLS]", + register_token: str = "[REG]", unknown_token: str = "[UNK]", reference_time: str = "2020-01-01 00:00:00", data_dir: str = "data_files", @@ -28,17 +32,18 @@ def __init__( self.start_token = start_token self.end_token = end_token self.class_token = class_token + self.register_token = register_token self.unknown_token = unknown_token self.reference_time = parser.parse(reference_time) self.data_dir = data_dir self.save_dir = save_dir - self.after_death_events = [] + self.after_death_events: List[str] = [] os.makedirs(save_dir, exist_ok=True) @property - def time_delta_tokens(self) -> list: - """Gets the time delta tokens.""" + def time_delta_tokens(self) -> List[str]: + """Get the time delta tokens.""" return ( [f"[W_{i}]" for i in range(0, 4)] + [f"[M_{i}]" for i in range(0, 13)] @@ -46,19 +51,21 @@ def time_delta_tokens(self) -> list: ) @property - def special_tokens(self) -> list: - """Gets the special tokens.""" + def special_tokens(self) -> List[str]: + """Get the special tokens.""" return [ self.pad_token, self.mask_token, self.start_token, self.end_token, self.class_token, + self.register_token, self.unknown_token, ] + self.time_delta_tokens @property - def get_token_type_dict(self) -> dict: + def get_token_type_dict(self) -> Dict[str, int]: + """Get the token type dictionary.""" return { "pad": 0, "class": 1, @@ -68,47 +75,18 @@ def get_token_type_dict(self) -> dict: "lab": 5, "med": 6, "proc": 7, + "reg": 8, } - def _load_patients(self) -> pd.DataFrame: - """Loads the patients dataframe.""" - patients = pd.read_csv(os.path.join(self.data_dir, "inpatient.csv"), nrows=1000) - return patients - - def _load_encounters(self) -> pd.DataFrame: - """Loads the encounters dataframe.""" - encounters = pd.read_csv( - os.path.join(self.data_dir, "encounters.csv"), - nrows=1000, - ) - return encounters - - def _load_procedures(self) -> pd.DataFrame: - """Loads the procedures dataframe.""" - procedures = pd.read_csv( - os.path.join(self.data_dir, "procedures.csv"), - nrows=1000, - ) - return procedures - - def _load_medications(self) -> pd.DataFrame: - """Loads the medications dataframe.""" - medications = pd.read_csv( - os.path.join(self.data_dir, "med_requests.csv"), - nrows=1000, - ) - return medications - - def _load_labs(self) -> pd.DataFrame: - """Loads the labs dataframe.""" - labs = pd.read_csv( - os.path.join(self.data_dir, "processed_labs.csv"), - nrows=1000, + @staticmethod + def to_list(input_value: Union[np.ndarray, List]) -> List: + """Convert the input value to a list if it is an instance of numpy array.""" + return ( + input_value.tolist() if isinstance(input_value, np.ndarray) else input_value ) - return labs def _sort_encounters(self, encounter_row: pd.Series) -> pd.Series: - """Sorts the encounters by start time. + """Sort the encounters by start time. Parameters ---------- @@ -164,7 +142,7 @@ def _get_encounters_age( return encounter_row def _get_encounters_time(self, encounter_row: pd.Series) -> pd.Series: - """Gets the time of the encounters in weeks with respect to a reference start time. + """Get the time of the encounters in weeks with respect to a reference time. Parameters ---------- @@ -184,7 +162,7 @@ def _get_encounters_time(self, encounter_row: pd.Series) -> pd.Series: return encounter_row def _calculate_intervals(self, encounter_row: pd.Series) -> pd.Series: - """Calculates the intervals between encounters. + """Calculate the intervals between encounters. Parameters ---------- @@ -198,8 +176,8 @@ def _calculate_intervals(self, encounter_row: pd.Series) -> pd.Series: """ start_times = encounter_row["starts"] end_times = encounter_row["ends"] - intervals = {} - eq_encounters = {} + intervals: Dict[str, str] = {} + eq_encounters: Dict[str, List[str]] = {} for i in range(len(start_times) - 1): start = parser.parse(start_times[i + 1]) start_id = encounter_row["encounter_ids"][i + 1] @@ -243,8 +221,9 @@ def _edit_datetimes( encounter_row: pd.Series, concept_name: str, ) -> pd.Series: - """Edits the datetimes of the events so that they won't fall - out of the corresponding encounter time frame. + """Edit the datetimes of the events. + + Done so that they won't fall out of the corresponding encounter time frame. Parameters ---------- @@ -287,13 +266,15 @@ def _edit_datetimes( encounter_start = encounter_row["starts"][encounter_index] encounter_end = encounter_row["ends"][encounter_index] start_parsed = parser.parse(encounter_start) - start_parsed = parser.parse(encounter_end) + end_parsed = parser.parse(encounter_end) date_parsed = parser.parse(date) + + enc_date = date if date_parsed < start_parsed: - date = encounter_start - elif date_parsed > start_parsed: - date = encounter_end - dates.append(date) + enc_date = encounter_start + elif date_parsed > end_parsed: + enc_date = encounter_end + dates.append(enc_date) row[date_column] = dates return row @@ -304,7 +285,7 @@ def _concat_concepts( labs: pd.DataFrame, encounters: pd.DataFrame, ) -> pd.DataFrame: - """Concatenates the events of different concepts. + """Concatenate the events of different concepts. Parameters ---------- @@ -377,8 +358,13 @@ def _concat_concepts( return procedures - def _add_tokens(self, row: pd.Series, encounter_row: pd.Series) -> pd.Series: - """Adds tokens to the events. + # pylint: disable=too-many-statements + def _add_tokens( + self, + row: pd.Series, + encounter_row: pd.Series, + ) -> pd.Series: + """Add tokens to the events. Parameters ---------- @@ -409,6 +395,7 @@ def _add_tokens(self, row: pd.Series, encounter_row: pd.Series) -> pd.Series: eq_encounters = encounter_row["eq_encounters"] age_mapping = dict(zip(ecounters, encounter_row["ages"])) time_mapping = dict(zip(ecounters, encounter_row["times"])) + event_tokens = [self.class_token] type_tokens = [self.get_token_type_dict["class"]] age_tokens = [0] @@ -417,7 +404,7 @@ def _add_tokens(self, row: pd.Series, encounter_row: pd.Series) -> pd.Series: position_tokens = [0] segment_value = 1 - position_value = 1 + position_value = 0 prev_encounter = None @@ -428,14 +415,22 @@ def _add_tokens(self, row: pd.Series, encounter_row: pd.Series) -> pd.Series: ): if ( event_encounter != prev_encounter - and event_encounter not in eq_encounters.keys() + and event_encounter not in eq_encounters ): if prev_encounter is not None: # Adding Visit End Token event_tokens.append(self.end_token) type_tokens.append(self.get_token_type_dict["end"]) - age_tokens.append(age_mapping[event_encounter]) - time_tokens.append(time_mapping[event_encounter]) + age_tokens.append(age_mapping[prev_encounter]) + time_tokens.append(time_mapping[prev_encounter]) + visit_segments.append(segment_value) + position_tokens.append(position_value) + + # Adding Register Token + event_tokens.append(self.register_token) + type_tokens.append(self.get_token_type_dict["reg"]) + age_tokens.append(age_mapping[prev_encounter]) + time_tokens.append(time_mapping[prev_encounter]) visit_segments.append(segment_value) position_tokens.append(position_value) @@ -492,7 +487,7 @@ def _add_tokens(self, row: pd.Series, encounter_row: pd.Series) -> pd.Series: row["time_tokens"] = time_tokens row["visit_tokens"] = visit_segments row["position_tokens"] = position_tokens - row["num_visits"] = len(set(position_tokens)) + row["num_visits"] = len(set(position_tokens)) - 1 return row def _get_mortality_label( @@ -501,7 +496,7 @@ def _get_mortality_label( patient_row: pd.Series, encounter_row: pd.Series, ) -> pd.Series: - """Gets the mortality label for the patient. + """Get the mortality label for the patient. Parameters ---------- @@ -541,8 +536,12 @@ def _get_mortality_label( row["death_after_end"] = death_end.days return row - def _truncate_or_pad(self, row: pd.Series) -> pd.Series: - """Truncates or pads the sequence to max_seq_length. + def _truncate_or_pad( + self, + row: pd.Series, + pad_events: bool = False, + ) -> pd.Series: + """Truncate or pads the sequence to max_seq_length. Parameters ---------- @@ -554,17 +553,23 @@ def _truncate_or_pad(self, row: pd.Series) -> pd.Series: pd.Series Updated row with truncated or padded sequence """ - sequence = row["event_tokens"] - type = row["type_tokens"] - age = row["age_tokens"] - time = row["time_tokens"] - visit = row["visit_tokens"] - position = row["position_tokens"] + sequence = self.to_list(row["event_tokens"]) + t_type = self.to_list(row["type_tokens"]) + age = self.to_list(row["age_tokens"]) + time = self.to_list(row["time_tokens"]) + visit = self.to_list(row["visit_tokens"]) + position = self.to_list(row["position_tokens"]) seq_length = row["token_length"] truncated = False + padded_length = 0 if seq_length == self.max_seq_length: - row["event_tokens"] = sequence + row[f"event_tokens_{self.max_seq_length}"] = sequence + row[f"type_tokens_{self.max_seq_length}"] = t_type + row[f"age_tokens_{self.max_seq_length}"] = age + row[f"time_tokens_{self.max_seq_length}"] = time + row[f"visit_tokens_{self.max_seq_length}"] = visit + row[f"position_tokens_{self.max_seq_length}"] = position return row if seq_length > self.max_seq_length: @@ -574,6 +579,8 @@ def _truncate_or_pad(self, row: pd.Series) -> pd.Series: end_index = int(seq_length) if sequence[start_index] == self.end_token: + start_index += 3 + elif sequence[start_index] == self.register_token: start_index += 2 elif sequence[start_index].startswith(("[W_", "[M_", "[LT")): start_index += 1 @@ -585,7 +592,7 @@ def _truncate_or_pad(self, row: pd.Series) -> pd.Series: new_type = [ self.get_token_type_dict["class"], self.get_token_type_dict["start"], - ] + type[start_index:end_index] + ] + t_type[start_index:end_index] new_age = [0, age[start_index]] + age[start_index:end_index] new_time = [0, time[start_index]] + time[start_index:end_index] new_visit = [0, visit[start_index]] + visit[start_index:end_index] @@ -594,7 +601,7 @@ def _truncate_or_pad(self, row: pd.Series) -> pd.Series: ] else: new_sequence = [self.class_token] + sequence[start_index:end_index] - new_type = [self.get_token_type_dict["class"]] + type[ + new_type = [self.get_token_type_dict["class"]] + t_type[ start_index:end_index ] new_age = [0] + age[start_index:end_index] @@ -615,10 +622,17 @@ def _truncate_or_pad(self, row: pd.Series) -> pd.Series: if seq_length < self.max_seq_length: padded_length = int(max(0, self.max_seq_length - seq_length)) if truncated: - row[f"event_tokens_{self.max_seq_length}"] = ( - row[f"event_tokens_{self.max_seq_length}"] - + [self.pad_token] * padded_length - ) + if pad_events: + row[f"event_tokens_{self.max_seq_length}"] = ( + row[f"event_tokens_{self.max_seq_length}"] + + [self.pad_token] * padded_length + ) + else: + # padding will be done in the tokenizer + row[f"event_tokens_{self.max_seq_length}"] = row[ + f"event_tokens_{self.max_seq_length}" + ] + row[f"type_tokens_{self.max_seq_length}"] = ( row[f"type_tokens_{self.max_seq_length}"] + [self.get_token_type_dict["pad"]] * padded_length @@ -637,28 +651,25 @@ def _truncate_or_pad(self, row: pd.Series) -> pd.Series: + [self.max_seq_length + 1] * padded_length ) else: - row[f"event_tokens_{self.max_seq_length}"] = ( - row["event_tokens"] + [self.pad_token] * padded_length - ) + if pad_events: + row[f"event_tokens_{self.max_seq_length}"] = ( + sequence + [self.pad_token] * padded_length + ) + else: + # padding will be done in the tokenizer + row[f"event_tokens_{self.max_seq_length}"] = sequence + row[f"type_tokens_{self.max_seq_length}"] = ( - row["type_tokens"] - + [self.get_token_type_dict["pad"]] * padded_length - ) - row[f"age_tokens_{self.max_seq_length}"] = ( - row["age_tokens"] + [0] * padded_length - ) - row[f"time_tokens_{self.max_seq_length}"] = ( - row["time_tokens"] + [0] * padded_length - ) - row[f"visit_tokens_{self.max_seq_length}"] = ( - row["visit_tokens"] + [0] * padded_length + t_type + [self.get_token_type_dict["pad"]] * padded_length ) + row[f"age_tokens_{self.max_seq_length}"] = age + [0] * padded_length + row[f"time_tokens_{self.max_seq_length}"] = time + [0] * padded_length + row[f"visit_tokens_{self.max_seq_length}"] = visit + [0] * padded_length row[f"position_tokens_{self.max_seq_length}"] = ( - row["position_tokens"] + [self.max_seq_length + 1] * padded_length + position + [self.max_seq_length + 1] * padded_length ) for key in [ - f"event_tokens_{self.max_seq_length}", f"type_tokens_{self.max_seq_length}", f"age_tokens_{self.max_seq_length}", f"time_tokens_{self.max_seq_length}", @@ -672,8 +683,15 @@ def _truncate_or_pad(self, row: pd.Series) -> pd.Series: return row - def create_patient_sequence(self) -> None: - """Creates patient sequences and saves them as a parquet file.""" + def create_patient_sequence( + self, + chunksize: int = None, + min_events: int = 0, + min_visits: int = 0, + pad_events: bool = False, + all_columns: bool = True, + ) -> None: + """Create patient sequences and saves them as a parquet file.""" file_paths = [ f"{self.data_dir}/inpatient.csv", f"{self.data_dir}/encounters.csv", @@ -682,17 +700,13 @@ def create_patient_sequence(self) -> None: f"{self.data_dir}/processed_labs.csv", ] rounds = 0 - readers = [pd.read_csv(path, chunksize=10000) for path in file_paths] + readers = [pd.read_csv(path, chunksize=chunksize) for path in file_paths] while True: try: - # patients = self._load_patients() - # encounters = self._load_encounters() - # procedures = self._load_procedures() - # medications = self._load_medications() - # labs = self._load_labs() - # process encounters + # read dataframes dataframes = [next(reader).reset_index(drop=True) for reader in readers] patients, encounters, procedures, medications, labs = dataframes + # process encounters encounters = encounters.apply(self._sort_encounters, axis=1) encounters = encounters.apply( lambda row: self._get_encounters_age(row, patients.iloc[row.name]), @@ -735,11 +749,20 @@ def create_patient_sequence(self) -> None: labs, encounters, ) - combined_events = combined_events[combined_events["length"] > 0] + # filter patients based on min_events + combined_events = combined_events[ + combined_events["length"] > min_events + ] + # add special tokens to the events combined_events = combined_events.apply( lambda row: self._add_tokens(row, encounters.iloc[row.name]), axis=1, ) + # filter patients based on min_visits + combined_events = combined_events[ + combined_events["num_visits"] > min_visits + ] + # get mortality label combined_events = combined_events.apply( lambda row: self._get_mortality_label( row, @@ -752,10 +775,104 @@ def create_patient_sequence(self) -> None: ~combined_events["patient_id"].isin(self.after_death_events) ] combined_events = combined_events.apply( - lambda row: self._truncate_or_pad(row), + lambda row: self._truncate_or_pad(row, pad_events=pad_events), axis=1, ) + if not all_columns: + output_columns = [ + "patient_id", + "num_visits", + "deceased", + "death_after_start", + "death_after_end", + "length", + "token_length", + f"event_tokens_{self.max_seq_length}", + f"type_tokens_{self.max_seq_length}", + f"age_tokens_{self.max_seq_length}", + f"time_tokens_{self.max_seq_length}", + f"visit_tokens_{self.max_seq_length}", + f"position_tokens_{self.max_seq_length}", + ] + else: + output_columns = [ + "patient_id", + "num_visits", + "deceased", + "death_after_start", + "death_after_end", + "length", + "token_length", + "event_tokens", + "type_tokens", + "age_tokens", + "time_tokens", + "visit_tokens", + "position_tokens", + f"event_tokens_{self.max_seq_length}", + f"type_tokens_{self.max_seq_length}", + f"age_tokens_{self.max_seq_length}", + f"time_tokens_{self.max_seq_length}", + f"visit_tokens_{self.max_seq_length}", + f"position_tokens_{self.max_seq_length}", + ] + combined_events = combined_events[output_columns] + combined_events = combined_events.dropna( + subset=[f"event_tokens_{self.max_seq_length}"], + ) + combined_events.to_parquet( + self.save_dir + + f"/patient_sequences_{self.max_seq_length}_{rounds}.parquet", + engine="pyarrow", + ) + print(f"Round {rounds} done") + rounds += 1 + except StopIteration: + break + + def reapply_truncation( + self, + file_paths: Union[str, List[str]], + all_columns: bool = False, + ) -> None: + """ + Reapply truncation to Parquet file(s). + + Parameters + ---------- + file_paths : Union[str, List[str]] + Path or list of paths to Parquet files to be processed. + Returns + ------- + None + """ + if isinstance(file_paths, str): + file_paths = [file_paths] + for i, file_path in enumerate(sorted(file_paths)): + df = pd.read_parquet(file_path) + df = df.apply( + lambda row: self._truncate_or_pad(row, pad_events=False), + axis=1, + ) + + if not all_columns: + output_columns = [ + "patient_id", + "num_visits", + "deceased", + "death_after_start", + "death_after_end", + "length", + "token_length", + f"event_tokens_{self.max_seq_length}", + f"type_tokens_{self.max_seq_length}", + f"age_tokens_{self.max_seq_length}", + f"time_tokens_{self.max_seq_length}", + f"visit_tokens_{self.max_seq_length}", + f"position_tokens_{self.max_seq_length}", + ] + else: output_columns = [ "patient_id", "num_visits", @@ -777,23 +894,25 @@ def create_patient_sequence(self) -> None: f"visit_tokens_{self.max_seq_length}", f"position_tokens_{self.max_seq_length}", ] - combined_events = combined_events[output_columns] - combined_events = combined_events.dropna( - subset=[f"event_tokens_{self.max_seq_length}"], - ) - combined_events.to_parquet( - self.save_dir + f"/patient_sequences_{rounds}.parquet", - engine="pyarrow", - ) - print(f"Round {rounds} done") - rounds += 1 - except StopIteration: - break + df = df[output_columns] + base_name = f"patient_sequences_{self.max_seq_length}" + suffix = f"_{i}" if len(file_paths) > 1 else "" + file_name = f"{base_name}{suffix}.parquet" + df.to_parquet( + os.path.join(self.save_dir, file_name), + engine="pyarrow", + ) if __name__ == "__main__": generator = SequenceGenerator( - data_dir="/mnt/data/odyssey/mimiciv_fhir2", - save_dir="/mnt/data/odyssey/mimiciv_fhir2/parquets", + max_seq_length=512, + data_dir="/mnt/data/odyssey/mimiciv_fhir/csv_files", + save_dir="/mnt/data/odyssey/mimiciv_fhir/parquet_files", + ) + generator.create_patient_sequence( + chunksize=10000, + min_events=10, + min_visits=0, + all_columns=True, ) - generator.create_patient_sequence() diff --git a/finetune.py b/finetune.py index 240d6f4..a62af00 100644 --- a/finetune.py +++ b/finetune.py @@ -1,9 +1,10 @@ +"""Finetune the pre-trained model.""" + import argparse import os -from os.path import join +import sys +from typing import Any, Dict -import numpy as np -import pandas as pd import pytorch_lightning as pl import torch from lightning.pytorch.loggers import WandbLogger @@ -16,47 +17,62 @@ from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader -from models.cehr_bert.data import FinetuneDataset +from lib.data import FinetuneDataset +from lib.tokenizer import ConceptTokenizer +from lib.utils import ( + get_latest_checkpoint, + get_run_id, + load_config, + load_finetune_data, + seed_everything, +) +from models.big_bird_cehr.model import BigBirdFinetune, BigBirdPretrain from models.cehr_bert.model import BertFinetune, BertPretrain -from models.cehr_bert.tokenizer import ConceptTokenizer - -def main(args): - torch.manual_seed(args.seed) - np.random.seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - pl.seed_everything(args.seed) +def main( + args: Dict[str, Any], + pre_model_config: Dict[str, Any], + fine_model_config: Dict[str, Any], +) -> None: + """Train the model.""" + # Setup environment + seed_everything(args.seed) os.environ["CUDA_LAUNCH_BLOCKING"] = "1" torch.cuda.empty_cache() torch.set_float32_matmul_precision("medium") - fine_tune = pd.read_parquet(join(args.data_dir, "fine_tune.parquet")) - fine_test = pd.read_parquet(join(args.data_dir, "fine_test.parquet")) - # fine_data = pd.read_parquet(join(args.data_dir, "fine_tune.parquet")) - # fine_train, fine_valtest = train_test_split( - # fine_data, - # train_size=args.train_size, - # random_state=args.seed, - # stratify=fine_data["label"], - # ) + # Load data + fine_tune, fine_test = load_finetune_data( + args.data_dir, + args.sequence_file, + args.id_file, + args.valid_scheme, + args.num_finetune_patients, + ) + + fine_tune.rename(columns={args.label_name: "label"}, inplace=True) + fine_test.rename(columns={args.label_name: "label"}, inplace=True) + + # Split data fine_train, fine_val = train_test_split( fine_tune, - test_size=args.test_size, + test_size=args.val_size, random_state=args.seed, stratify=fine_tune["label"], ) - tokenizer = ConceptTokenizer() + # Train Tokenizer + tokenizer = ConceptTokenizer(data_dir=args.vocab_dir) tokenizer.fit_on_vocab() + # Load datasets train_dataset = FinetuneDataset( data=fine_train, tokenizer=tokenizer, max_len=args.max_len, ) + args.dataset_len = len(train_dataset) val_dataset = FinetuneDataset( data=fine_val, @@ -74,20 +90,23 @@ def main(args): train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, + persistent_workers=args.persistent_workers, shuffle=True, - pin_memory=True, + pin_memory=args.pin_memory, ) + val_loader = DataLoader( val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, - pin_memory=True, + pin_memory=args.pin_memory, ) + test_loader = DataLoader( test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, - pin_memory=True, + pin_memory=args.pin_memory, ) callbacks = [ @@ -101,132 +120,242 @@ def main(args): dirpath=args.checkpoint_dir, ), LearningRateMonitor(logging_interval="step"), - EarlyStopping(monitor="val_loss", patience=5, verbose=True, mode="min"), + EarlyStopping( + monitor="val_loss", + patience=args.patience, + verbose=True, + mode="min", + ), ] + + # Create model + if args.model_type == "cehr_bert": + pretrained_model = BertPretrain( + args=args, + vocab_size=tokenizer.get_vocab_size(), + padding_idx=tokenizer.get_pad_token_id(), + **pre_model_config, + ) + pretrained_model.load_state_dict(torch.load(args.pretrained_path)["state_dict"]) + + model = BertFinetune( + args=args, + pretrained_model=pretrained_model, + **fine_model_config, + ) + + elif args.model_type == "cehr_bigbird": + pretrained_model = BigBirdPretrain( + args=args, + vocab_size=tokenizer.get_vocab_size(), + padding_idx=tokenizer.get_pad_token_id(), + **pre_model_config, + ) + + pretrained_model.load_state_dict(torch.load(args.pretrained_path)["state_dict"]) + + model = BigBirdFinetune( + args=args, + pretrained_model=pretrained_model, + **fine_model_config, + ) + + latest_checkpoint = get_latest_checkpoint(args.checkpoint_dir) + + run_id = get_run_id(args.checkpoint_dir) + wandb_logger = WandbLogger( - project="finetune", + project=args.exp_name, save_dir=args.log_dir, + entity=args.workspace_name, + id=run_id, + resume="allow", ) + + # Setup PyTorchLightning trainer trainer = pl.Trainer( accelerator="gpu", devices=args.gpus, strategy=DDPStrategy(find_unused_parameters=True) if args.gpus > 1 else "auto", - precision=16, + precision="16-mixed", check_val_every_n_epoch=1, max_epochs=args.max_epochs, callbacks=callbacks, + deterministic=False, + enable_checkpointing=True, + enable_progress_bar=True, + enable_model_summary=True, logger=wandb_logger, log_every_n_steps=args.log_every_n_steps, + accumulate_grad_batches=args.acc, + gradient_clip_val=1.0, ) - pretrained_model = BertPretrain( - vocab_size=tokenizer.get_vocab_size(), - padding_idx=tokenizer.get_pad_token_id(), - ) - pretrained_model.load_state_dict(torch.load(args.pretrained_path)["state_dict"]) - - model = BertFinetune( - pretrained_model=pretrained_model, - ) - + # Train the model trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, + ckpt_path=latest_checkpoint if latest_checkpoint else None, ) - trainer.test( - model=model, - dataloaders=test_loader, - ) + # Test the model + if args.test_last: + trainer.test( + dataloaders=test_loader, + ckpt_path="last", + ) + else: + trainer.test( + dataloaders=test_loader, + ckpt_path="best", + ) if __name__ == "__main__": parser = argparse.ArgumentParser() + # project configuration parser.add_argument( - "--seed", - type=int, - default=42, - help="Random seed for reproducibility", + "--model-type", + type=str, + required=True, + help="Model type: 'cehr_bert' or 'cehr_bigbird'", ) parser.add_argument( - "--resume", - action="store_true", - help="Flag to resume training from a checkpoint", + "--exp-name", + type=str, + required=True, + help="Path to model config file", + ) + parser.add_argument( + "--pretrained-path", + type=str, + required=True, + help="Pretrained model", + ) + parser.add_argument( + "--label-name", + type=str, + required=True, + help="Name of the label column", + ) + parser.add_argument( + "--workspace-name", + type=str, + default=None, + help="Name of the Wandb workspace", + ) + parser.add_argument( + "--config-dir", + type=str, + default="models/configs", + help="Path to model config file", ) + + # data-related arguments parser.add_argument( - "--data_dir", + "--data-dir", type=str, default="data_files", help="Path to the data directory", ) parser.add_argument( - "--train_size", - type=float, - default=0.3, - help="Train set size for splitting the data", + "--sequence-file", + type=str, + default="patient_sequences_2048_labeled.parquet", + help="Path to the patient sequence file", ) parser.add_argument( - "--test_size", - type=float, - default=0.6, - help="Test set size for splitting the data", + "--id-file", + type=str, + default="dataset_2048_mortality_1month.pkl", + help="Path to the patient id file", ) parser.add_argument( - "--max_len", - type=int, - default=512, - help="Maximum length of the sequence", + "--vocab-dir", + type=str, + default="data_files/vocab", + help="Path to the vocabulary directory of json files", ) parser.add_argument( - "--batch_size", - type=int, - default=32, - help="Batch size for training", + "--val-size", + type=float, + default=0.1, + help="Validation set size for splitting the data", ) parser.add_argument( - "--num_workers", - type=int, - default=4, - help="Number of workers for training", + "--valid_scheme", + type=str, + default="few_shot", + help="Define the type of validation, few_shot or kfold", ) parser.add_argument( - "--checkpoint_dir", + "--num_finetune_patients", type=str, - default="checkpoints/finetuning", - help="Path to the training checkpoint", + default="20000_patients", + help="Define the number of patients to be fine_tuned on", ) + + # checkpointing and loggig arguments parser.add_argument( - "--log_dir", + "--checkpoint-dir", type=str, - default="logs", - help="Path to the log directory", + default="checkpoints", + help="Path to the checkpoint directory", ) parser.add_argument( - "--gpus", - type=int, - default=1, - help="Number of gpus for training", + "--log-dir", + type=str, + default="logs", + help="Path to the log directory", ) parser.add_argument( - "--max_epochs", - type=int, - default=10, - help="Number of epochs for training", + "--checkpoint-path", + type=str, + default=None, + help="Checkpoint to resume finetuning from", ) + parser.add_argument( "--log_every_n_steps", type=int, default=10, help="Number of steps to log the training", ) + + # other arguments parser.add_argument( - "--pretrained_path", - type=str, - default=None, - required=True, - help="Checkpoint to the pretrained model", + "--test-last", + action="store_true", + help="Test the last checkpoint", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility", ) args = parser.parse_args() - main(args) + + if args.model_type not in ["cehr_bert", "cehr_bigbird"]: + print("Invalid model type. Choose 'cehr_bert' or 'cehr_bigbird'.") + sys.exit(1) + + args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.exp_name) + os.makedirs(args.checkpoint_dir, exist_ok=True) + os.makedirs(args.log_dir, exist_ok=True) + + config = load_config(args.config_dir, args.model_type) + + finetune_config = config["finetune"] + for key, value in finetune_config.items(): + if not hasattr(args, key) or getattr(args, key) is None: + setattr(args, key, value) + + pre_model_config = config["model"] + args.max_len = pre_model_config["max_seq_length"] + + fine_model_config = config["model_finetune"] + + main(args, pre_model_config, fine_model_config) diff --git a/finetune_bigbird.py b/finetune_bigbird.py deleted file mode 100644 index 87c7f7b..0000000 --- a/finetune_bigbird.py +++ /dev/null @@ -1,286 +0,0 @@ -""" -File: finetune_bigbird.py. - -Finetune an already pretrained bigbird model on MIMIC-IV FHIR data. -The finetuning objective is binary classification on patient mortality or -hospital readmission labels. -""" - -import argparse -import glob -import os -from os.path import join -from typing import Any, Dict - -import numpy as np -import pandas as pd -import pytorch_lightning as pl -import torch -from lightning.pytorch.loggers import WandbLogger -from pytorch_lightning.callbacks import ( - EarlyStopping, - LearningRateMonitor, - ModelCheckpoint, -) -from pytorch_lightning.strategies.ddp import DDPStrategy -from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader - -from models.big_bird_cehr.data import FinetuneDataset -from models.big_bird_cehr.model import BigBirdFinetune, BigBirdPretrain -from models.big_bird_cehr.tokenizer import HuggingFaceConceptTokenizer - - -def seed_everything(seed: int) -> None: - """Seed all components of the model.""" - torch.manual_seed(seed) - np.random.seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - pl.seed_everything(seed) - - -def get_latest_checkpoint(checkpoint_dir: str) -> Any: - """Return the most recent checkpointed file to resume training from.""" - list_of_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt")) - return max(list_of_files, key=os.path.getctime) if list_of_files else None - - -def main(args: Dict[str, Any]) -> None: - """Train the model.""" - # Setup environment - seed_everything(args.seed) - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - torch.cuda.empty_cache() - torch.set_float32_matmul_precision("medium") - - # Load data - fine_tune = pd.read_parquet(join(args.data_dir, "fine_tune.parquet")) - fine_test = pd.read_parquet(join(args.data_dir, "fine_test.parquet")) - - # Split data - fine_train, fine_val = train_test_split( - fine_tune, - test_size=args.test_size, - random_state=args.seed, - stratify=fine_tune["label"], - ) - - # Train Tokenizer - tokenizer = HuggingFaceConceptTokenizer(data_dir=args.data_dir) - tokenizer.fit_on_vocab() - - # Load datasets - train_dataset = FinetuneDataset( - data=fine_train, - tokenizer=tokenizer, - max_len=args.max_len, - ) - - val_dataset = FinetuneDataset( - data=fine_val, - tokenizer=tokenizer, - max_len=args.max_len, - ) - - test_dataset = FinetuneDataset( - data=fine_test, - tokenizer=tokenizer, - max_len=args.max_len, - ) - - train_loader = DataLoader( - train_dataset, - batch_size=args.batch_size, - num_workers=args.num_workers, - persistent_workers=True, - shuffle=True, - pin_memory=True, - ) - - val_loader = DataLoader( - val_dataset, - batch_size=args.batch_size, - num_workers=args.num_workers, - persistent_workers=True, - pin_memory=True, - ) - - test_loader = DataLoader( - test_dataset, - batch_size=args.batch_size, - num_workers=args.num_workers, - persistent_workers=True, - pin_memory=True, - ) - - # Setup model dependencies - callbacks = [ - ModelCheckpoint( - monitor="val_loss", - mode="min", - filename="best", - save_top_k=1, - save_last=True, - verbose=True, - dirpath=args.checkpoint_dir, - ), - LearningRateMonitor(logging_interval="step"), - EarlyStopping(monitor="val_loss", patience=5, verbose=True, mode="min"), - ] - - wandb_logger = WandbLogger( - project="bigbird_finetune", - save_dir=args.log_dir, - ) - - # Load latest checkpoint to continue training - latest_checkpoint = get_latest_checkpoint(args.checkpoint_path) - - # Setup PyTorchLightning trainer - trainer = pl.Trainer( - accelerator="gpu", - devices=args.gpus, - strategy=DDPStrategy(find_unused_parameters=True) if args.gpus > 1 else "auto", - precision="16-mixed", - check_val_every_n_epoch=1, - max_epochs=args.max_epochs, - callbacks=callbacks, - deterministic=False, - enable_checkpointing=True, - enable_progress_bar=True, - enable_model_summary=True, - logger=wandb_logger, - resume_from_checkpoint=latest_checkpoint if args.resume else None, - log_every_n_steps=args.log_every_n_steps, - accumulate_grad_batches=args.acc, - gradient_clip_val=1.0, - ) - - # Create pretrain BigBird model and load the pretrained state_dict - pretrained_model = BigBirdPretrain( - args=args, - dataset_len=len(train_dataset), - vocab_size=tokenizer.get_vocab_size(), - padding_idx=tokenizer.get_pad_token_id(), - ) - pretrained_model.load_state_dict(torch.load(args.pretrained_path)["state_dict"]) - - # Create fine tune BigBird model - model = BigBirdFinetune( - args, - dataset_len=len(train_dataset), - pretrained_model=pretrained_model, - ) - - # Train the model - trainer.fit( - model=model, - train_dataloaders=train_loader, - val_dataloaders=val_loader, - ) - - # Test the model - trainer.test( - model=model, - dataloaders=test_loader, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--seed", - type=int, - default=42, - help="Random seed for reproducibility", - ) - parser.add_argument( - "--resume", - action="store_true", - default=False, - help="Flag to resume training from a checkpoint", - ) - parser.add_argument( - "--data_dir", - type=str, - default="/h/afallah/odyssey/odyssey/data/slurm_data/one_month", - help="Path to the data directory", - ) - parser.add_argument( - "--train_size", - type=float, - default=0.3, - help="Train set size for splitting the data", - ) - parser.add_argument( - "--test_size", - type=float, - default=0.6, - help="Test set size for splitting the data", - ) - parser.add_argument( - "--max_len", - type=int, - default=2048, - help="Maximum length of the sequence", - ) - parser.add_argument( - "--batch_size", - type=int, - default=16, - help="Batch size for training", - ) - parser.add_argument( - "--num_workers", - type=int, - default=4, - help="Number of workers for training", - ) - parser.add_argument( - "--checkpoint_dir", - type=str, - default="checkpoints/finetuning", - help="Path to the training checkpoint", - ) - parser.add_argument( - "--log_dir", - type=str, - default="logs", - help="Path to the log directory", - ) - parser.add_argument( - "--gpus", - type=int, - default=1, - help="Number of gpus for training", - ) - parser.add_argument( - "--max_epochs", - type=int, - default=10, - help="Number of epochs for training", - ) - parser.add_argument( - "--acc", - type=int, - default=1, - help="Gradient accumulation", - ) - parser.add_argument( - "--log_every_n_steps", - type=int, - default=10, - help="Number of steps to log the training", - ) - parser.add_argument( - "--pretrained_path", - type=str, - default=None, - required=True, - help="Checkpoint to the pretrained model", - ) - - args = parser.parse_args() - main(args) diff --git a/lib/data.py b/lib/data.py new file mode 100644 index 0000000..dfc37ef --- /dev/null +++ b/lib/data.py @@ -0,0 +1,170 @@ +""" +data.py. + +Create custom pretrain and finetune PyTorch Dataset objects for MIMIC-IV FHIR dataset. +""" + +from typing import Any, Dict, List, Tuple, Union + +import pandas as pd +import torch +from torch.utils.data import Dataset + +from .tokenizer import ConceptTokenizer + + +class PretrainDataset(Dataset): + """Dataset for pretraining the model.""" + + def __init__( + self, + data: pd.DataFrame, + tokenizer: ConceptTokenizer, + max_len: int = 2048, + mask_prob: float = 0.15, + ): + """Initiate the class.""" + super(PretrainDataset, self).__init__() + + self.data = data + self.tokenizer = tokenizer + self.max_len = max_len + self.mask_prob = mask_prob + + def __len__(self) -> int: + """Return the length of the dataset.""" + return len(self.data) + + def tokenize_data(self, sequence: Union[str, List[str]]) -> Any: + """Tokenize the sequence and return input_ids and attention mask.""" + return self.tokenizer(sequence, max_length=self.max_len) + + def mask_tokens(self, sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Mask the tokens in the sequence using vectorized operations.""" + mask_token_id = self.tokenizer.get_mask_token_id() + + masked_sequence = sequence.clone() + + # Ignore [PAD], [UNK], [MASK] tokens + prob_matrix = torch.full(masked_sequence.shape, self.mask_prob) + prob_matrix[torch.where(masked_sequence <= mask_token_id)] = 0 + selected = torch.bernoulli(prob_matrix).bool() + + # 80% of the time, replace masked input tokens with respective mask tokens + replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected + masked_sequence[replaced] = mask_token_id + + # 10% of the time, we replace masked input tokens with random vector. + randomized = ( + torch.bernoulli(torch.full(selected.shape, 0.1)).bool() + & selected + & ~replaced + ) + random_idx = torch.randint( + low=self.tokenizer.get_first_token_index(), + high=self.tokenizer.get_last_token_index(), + size=prob_matrix.shape, + dtype=torch.long, + ) + masked_sequence[randomized] = random_idx[randomized] + + labels = torch.where(selected, sequence, -100) + + return masked_sequence, labels + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get data at corresponding index. + + Return it as a dictionary including + all different token sequences along with attention mask and labels. + """ + data = self.data.iloc[idx] + tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"]) + concept_tokens = tokenized_input["input_ids"].squeeze() + attention_mask = tokenized_input["attention_mask"].squeeze() + + type_tokens = data[f"type_tokens_{self.max_len}"] + age_tokens = data[f"age_tokens_{self.max_len}"] + time_tokens = data[f"time_tokens_{self.max_len}"] + visit_tokens = data[f"visit_tokens_{self.max_len}"] + position_tokens = data[f"position_tokens_{self.max_len}"] + + masked_tokens, labels = self.mask_tokens(concept_tokens) + + type_tokens = torch.tensor(type_tokens) + age_tokens = torch.tensor(age_tokens) + time_tokens = torch.tensor(time_tokens) + visit_tokens = torch.tensor(visit_tokens) + position_tokens = torch.tensor(position_tokens) + + return { + "concept_ids": masked_tokens, + "type_ids": type_tokens, + "ages": age_tokens, + "time_stamps": time_tokens, + "visit_orders": position_tokens, + "visit_segments": visit_tokens, + "labels": labels, + "attention_mask": attention_mask, + } + + +class FinetuneDataset(Dataset): + """Dataset for finetuning the model.""" + + def __init__( + self, + data: pd.DataFrame, + tokenizer: ConceptTokenizer, + max_len: int = 2048, + ): + """Initiate the class.""" + super(FinetuneDataset, self).__init__() + + self.data = data + self.tokenizer = tokenizer + self.max_len = max_len + + def __len__(self) -> int: + """Return the length of dataset.""" + return len(self.data) + + def tokenize_data(self, sequence: Union[str, List[str]]) -> Any: + """Tokenize the sequence and return input_ids and attention mask.""" + return self.tokenizer(sequence, max_length=self.max_len) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Get data at corresponding index. + + Return it as a dictionary including + all different token sequences along with attention mask and labels. + """ + data = self.data.iloc[idx] + tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"]) + concept_tokens = tokenized_input["input_ids"].squeeze() + attention_mask = tokenized_input["attention_mask"].squeeze() + + type_tokens = data[f"type_tokens_{self.max_len}"] + age_tokens = data[f"age_tokens_{self.max_len}"] + time_tokens = data[f"time_tokens_{self.max_len}"] + visit_tokens = data[f"visit_tokens_{self.max_len}"] + position_tokens = data[f"position_tokens_{self.max_len}"] + labels = data["label"] + + type_tokens = torch.tensor(type_tokens) + age_tokens = torch.tensor(age_tokens) + time_tokens = torch.tensor(time_tokens) + visit_tokens = torch.tensor(visit_tokens) + position_tokens = torch.tensor(position_tokens) + labels = torch.tensor(labels) + + return { + "concept_ids": concept_tokens, + "type_ids": type_tokens, + "ages": age_tokens, + "time_stamps": time_tokens, + "visit_orders": position_tokens, + "visit_segments": visit_tokens, + "labels": labels, + "attention_mask": attention_mask, + } diff --git a/lib/tokenizer.py b/lib/tokenizer.py new file mode 100644 index 0000000..a2f9619 --- /dev/null +++ b/lib/tokenizer.py @@ -0,0 +1,192 @@ +""" +file: tokenizer.py. + +Custom HuggingFace tokenizer for medical concepts in MIMIC-IV FHIR dataset. +""" + +import glob +import json +import os +from itertools import chain +from typing import Any, Dict, List, Optional, Set, Union + +from tokenizers import Tokenizer, models, pre_tokenizers +from transformers import BatchEncoding, PreTrainedTokenizerFast + + +class ConceptTokenizer: + """Tokenizer for event concepts using HuggingFace Library.""" + + def __init__( + self, + pad_token: str = "[PAD]", + mask_token: str = "[MASK]", + start_token: str = "[VS]", + end_token: str = "[VE]", + class_token: str = "[CLS]", + reg_token: str = "[REG]", + unknown_token: str = "[UNK]", + data_dir: str = "data_files", + tokenizer_object: Optional[Tokenizer] = None, + tokenizer: Optional[PreTrainedTokenizerFast] = None, + ) -> None: + self.mask_token = mask_token + self.pad_token = pad_token + self.unknown_token = unknown_token + self.special_tokens = ( + [ + pad_token, + unknown_token, + mask_token, + start_token, + end_token, + class_token, + reg_token, + ] + + [f"[W_{i}]" for i in range(0, 4)] + + [f"[M_{i}]" for i in range(0, 13)] + + ["[LT]"] + ) + + self.tokenizer_object = tokenizer_object + self.tokenizer = tokenizer + + self.tokenizer_vocab: Dict[str, int] = {} + self.token_type_vocab: Dict[str, Any] = {} + self.data_dir = data_dir + + self.special_token_ids: List[int] = [] + self.first_token_index: Optional[int] = None + self.last_token_index: Optional[int] = None + + def fit_on_vocab(self) -> None: + """Fit the tokenizer on the vocabulary.""" + # Create dictionary of all possible medical concepts + self.token_type_vocab["special_tokens"] = self.special_tokens + vocab_json_files = glob.glob(os.path.join(self.data_dir, "*_vocab.json")) + + for file in vocab_json_files: + with open(file, "r") as vocab_file: + vocab = json.load(vocab_file) + vocab_type = file.split("/")[-1].split(".")[0] + self.token_type_vocab[vocab_type] = vocab + + # Create the tokenizer dictionary + tokens = list(chain.from_iterable(list(self.token_type_vocab.values()))) + self.tokenizer_vocab = {token: i for i, token in enumerate(tokens)} + + # Create the tokenizer object + self.tokenizer_object = Tokenizer( + models.WordPiece( + vocab=self.tokenizer_vocab, + unk_token=self.unknown_token, + max_input_chars_per_word=1000, + ), + ) + self.tokenizer_object.pre_tokenizer = pre_tokenizers.WhitespaceSplit() + self.tokenizer = self.create_tokenizer(self.tokenizer_object) + + # Get the first, last , and special token indexes from the dictionary + self.first_token_index = self.get_first_token_index() + self.last_token_index = self.get_last_token_index() + self.special_token_ids = self.get_special_token_ids() + + # Check to make sure tokenizer follows the same vocabulary + assert ( + self.tokenizer_vocab == self.tokenizer.get_vocab() + ), "Tokenizer vocabulary does not match original" + + def create_tokenizer( + self, + tokenizer_obj: Tokenizer, + ) -> PreTrainedTokenizerFast: + """Load the tokenizer from a JSON file on disk.""" + self.tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer_obj, + bos_token="[VS]", + eos_token="[VE]", + unk_token="[UNK]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + ) + return self.tokenizer + + def __call__( + self, + batch: Union[str, List[str]], + return_attention_mask: bool = True, + return_token_type_ids: bool = False, + truncation: bool = False, + padding: str = "max_length", + max_length: int = 2048, + ) -> BatchEncoding: + """Return the tokenized dictionary of input batch.""" + return self.tokenizer( + batch, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + truncation=truncation, + padding=padding, + max_length=max_length, + return_tensors="pt", + ) + + def encode(self, concept_tokens: str) -> List[int]: + """Encode the concept tokens into token ids.""" + return self.tokenizer_object.encode(concept_tokens).ids + + def decode(self, concept_ids: List[int]) -> str: + """Decode the concept sequence token id into token concept.""" + return self.tokenizer_object.decode(concept_ids) + + def token_to_id(self, token: str) -> int: + """Return the id corresponding to token.""" + return self.tokenizer_object.token_to_id(token) + + def id_to_token(self, token_id: int) -> str: + """Return the token corresponding to id.""" + return self.tokenizer_object.id_to_token(token_id) + + def get_all_token_indexes(self, with_special_tokens: bool = True) -> Set[int]: + """Return a set of all possible token ids.""" + all_token_ids = set(self.tokenizer_vocab.values()) + special_token_ids = set(self.get_special_token_ids()) + + return ( + all_token_ids if with_special_tokens else all_token_ids - special_token_ids + ) + + def get_first_token_index(self) -> int: + """Return the smallest token id in vocabulary.""" + return min(self.tokenizer_vocab.values()) + + def get_last_token_index(self) -> int: + """Return the largest token id in vocabulary.""" + return max(self.tokenizer_vocab.values()) + + def get_vocab_size(self) -> int: + """Return the number of possible tokens in vocabulary.""" + return len(self.tokenizer) + + def get_pad_token_id(self) -> int: + """Return the token id of PAD token.""" + return self.token_to_id(self.pad_token) + + def get_mask_token_id(self) -> int: + """Return the token id of MASK token.""" + return self.token_to_id(self.mask_token) + + def get_special_token_ids(self) -> List[int]: + """Get a list of ids representing special tokens.""" + self.special_token_ids = [] + + for special_token in self.special_tokens: + special_token_id = self.token_to_id(special_token) + self.special_token_ids.append(special_token_id) + + return self.special_token_ids + + def save_tokenizer_to_disk(self, save_dir: str) -> None: + """Save the tokenizer object to disk as a JSON file.""" + self.tokenizer.save(path=save_dir) diff --git a/lib/utils.py b/lib/utils.py new file mode 100644 index 0000000..a2ac8f3 --- /dev/null +++ b/lib/utils.py @@ -0,0 +1,111 @@ +"""Utility functions for the model training and finetuning.""" + +import glob +import os +import pickle +import uuid +from os.path import join +from typing import Any + +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +import yaml + + +def load_config(config_dir: str, model_type: str) -> Any: + """Load the model configuration.""" + config_file = join(config_dir, f"{model_type}.yaml") + with open(config_file, "r") as file: + return yaml.safe_load(file) + + +def seed_everything(seed: int) -> None: + """Seed all components of the model.""" + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + pl.seed_everything(seed) + + +def get_latest_checkpoint(checkpoint_dir: str) -> Any: + """Return the most recent checkpointed file to resume training from.""" + list_of_files = glob.glob(join(checkpoint_dir, "last*.ckpt")) + return max(list_of_files, key=os.path.getctime) if list_of_files else None + + +def load_pretrain_data( + data_dir: str, + sequence_file: str, + id_file: str, +) -> pd.DataFrame: + """Load the pretraining data.""" + sequence_path = join(data_dir, sequence_file) + id_path = join(data_dir, id_file) + + if not os.path.exists(sequence_path): + raise FileNotFoundError(f"Sequence file not found: {sequence_path}") + + if not os.path.exists(id_path): + raise FileNotFoundError(f"ID file not found: {id_path}") + + data = pd.read_parquet(sequence_path) + with open(id_path, "rb") as file: + patient_ids = pickle.load(file) + + return data.loc[data["patient_id"].isin(patient_ids["pretrain"])] + + +def load_finetune_data( + data_dir: str, + sequence_file: str, + id_file: str, + valid_scheme: str, + num_finetune_patients: int, +) -> pd.DataFrame: + """Load the finetuning data.""" + sequence_path = join(data_dir, sequence_file) + id_path = join(data_dir, id_file) + + if not os.path.exists(sequence_path): + raise FileNotFoundError(f"Sequence file not found: {sequence_path}") + + if not os.path.exists(id_path): + raise FileNotFoundError(f"ID file not found: {id_path}") + + data = pd.read_parquet(sequence_path) + with open(id_path, "rb") as file: + patient_ids = pickle.load(file) + + fine_tune = data.loc[ + data["patient_id"].isin( + patient_ids["valid"][valid_scheme][num_finetune_patients], + ) + ] + fine_test = data.loc[data["patient_id"].isin(patient_ids["test"])] + return fine_tune, fine_test + + +def get_run_id( + checkpoint_dir: str, + retrieve: bool = False, + run_id_file: str = "wandb_run_id.txt", + length: int = 8, +) -> str: + """ + Return the run ID for the current run. + + If the run ID file exists, retrieve the run ID from the file. + """ + run_id_path = os.path.join(checkpoint_dir, run_id_file) + if retrieve and os.path.exists(run_id_path): + with open(run_id_path, "r") as file: + run_id = file.read().strip() + else: + run_id = str(uuid.uuid4())[:length] + with open(run_id_path, "w") as file: + file.write(run_id) + return run_id diff --git a/models/cehr_bert/data.py b/models/cehr_bert/data.py deleted file mode 100644 index 383f901..0000000 --- a/models/cehr_bert/data.py +++ /dev/null @@ -1,163 +0,0 @@ -import random -from typing import Sequence, Union - -import numpy as np -import pandas as pd -import torch -from torch.utils.data import Dataset - -from models.cehr_bert.tokenizer import ConceptTokenizer - - -class PretrainDataset(Dataset): - """Dataset for pretraining the model.""" - - def __init__( - self, - data: pd.DataFrame, - tokenizer: ConceptTokenizer, - max_len: int = 512, - mask_prob: float = 0.15, - ): - self.data = data - self.tokenizer = tokenizer - self.max_len = max_len # TODO: max_len is not used - self.mask_prob = mask_prob - - def __len__(self): - """Return the length of the dataset.""" - return len(self.data) - - def tokenize_data(self, sequence: Union[str, Sequence[str]]) -> np.ndarray: - """Tokenize the sequence.""" - tokenized = self.tokenizer.encode(sequence) - tokenized = np.array(tokenized).flatten() - return tokenized - - def get_attention_mask(self, sequence: np.ndarray) -> np.ndarray: - """Get the attention mask for the sequence.""" - attention_mask = [ - float(token != self.tokenizer.get_pad_token_id()) for token in sequence - ] - return attention_mask - - def mask_tokens(self, sequence: np.ndarray) -> tuple: - """Mask the tokens in the sequence.""" - masked_sequence = [] - labels = [] - for token in sequence: - if token in self.tokenizer.get_special_token_ids(): - masked_sequence.append(token) - labels.append(-100) - continue - prob = random.random() - if prob < self.mask_prob: - dice = random.random() - if dice < 0.8: - masked_sequence.append(self.tokenizer.get_mask_token_id()) - elif dice < 0.9: - random_token = random.randint( - self.tokenizer.get_first_token_index(), - self.tokenizer.get_last_token_index(), - ) - masked_sequence.append(random_token) - else: - masked_sequence.append(token) - labels.append(token) - else: - masked_sequence.append(token) - labels.append(-100) - return masked_sequence, labels - - def __getitem__(self, idx: int) -> dict: - data = self.data.iloc[idx] - concept_tokens = self.tokenize_data(data[f"event_tokens_{self.max_len}"]) - type_tokens = data[f"type_tokens_{self.max_len}"] - age_tokens = data[f"age_tokens_{self.max_len}"] - time_tokens = data[f"time_tokens_{self.max_len}"] - visit_tokens = data[f"visit_tokens_{self.max_len}"] - position_tokens = data[f"position_tokens_{self.max_len}"] - - attention_mask = self.get_attention_mask(concept_tokens) - masked_tokens, labels = self.mask_tokens(concept_tokens) - - masked_tokens = torch.tensor(masked_tokens) - type_tokens = torch.tensor(type_tokens) - age_tokens = torch.tensor(age_tokens) - time_tokens = torch.tensor(time_tokens) - visit_tokens = torch.tensor(visit_tokens) - position_tokens = torch.tensor(position_tokens) - labels = torch.tensor(labels) - attention_mask = torch.tensor(attention_mask) - - return { - "concept_ids": masked_tokens, - "type_ids": type_tokens, - "ages": age_tokens, - "time_stamps": time_tokens, - "visit_orders": position_tokens, - "visit_segments": visit_tokens, - "labels": labels, - "attention_mask": attention_mask, - } - - -class FinetuneDataset(Dataset): - """Dataset for finetuning the model.""" - - def __init__( - self, - data: pd.DataFrame, - tokenizer: ConceptTokenizer, - max_len: int = 512, - ): - self.data = data - self.tokenizer = tokenizer - self.max_len = max_len - - def __len__(self) -> int: - return len(self.data) - - def tokenize_data(self, sequence): - """Tokenize the sequence.""" - tokenized = self.tokenizer.encode(sequence) - tokenized = np.array(tokenized).flatten() - return tokenized - - def get_attention_mask(self, sequence): - """Get the attention mask for the sequence.""" - attention_mask = [ - float(token != self.tokenizer.get_pad_token_id()) for token in sequence - ] - return attention_mask - - def __getitem__(self, idx): - data = self.data.iloc[idx] - concept_tokens = self.tokenize_data(data[f"event_tokens_{self.max_len}"]) - type_tokens = data[f"type_tokens_{self.max_len}"] - age_tokens = data[f"age_tokens_{self.max_len}"] - time_tokens = data[f"time_tokens_{self.max_len}"] - visit_tokens = data[f"visit_tokens_{self.max_len}"] - position_tokens = data[f"position_tokens_{self.max_len}"] - labels = data["label"] - attention_mask = self.get_attention_mask(concept_tokens) - - concept_tokens = torch.tensor(concept_tokens) - type_tokens = torch.tensor(type_tokens) - age_tokens = torch.tensor(age_tokens) - time_tokens = torch.tensor(time_tokens) - visit_tokens = torch.tensor(visit_tokens) - position_tokens = torch.tensor(position_tokens) - labels = torch.tensor(labels) - attention_mask = torch.tensor(attention_mask) - - return { - "concept_ids": concept_tokens, - "type_ids": type_tokens, - "ages": age_tokens, - "time_stamps": time_tokens, - "visit_orders": position_tokens, - "visit_segments": visit_tokens, - "labels": labels, - "attention_mask": attention_mask, - } diff --git a/models/cehr_bert/tokenizer.py b/models/cehr_bert/tokenizer.py deleted file mode 100644 index 6636e44..0000000 --- a/models/cehr_bert/tokenizer.py +++ /dev/null @@ -1,101 +0,0 @@ -import glob -import json -import os -from typing import Sequence, Union - -from keras.preprocessing.text import Tokenizer - - -class ConceptTokenizer: - """Tokenizer for event concepts.""" - - def __init__( - self, - pad_token: str = "[PAD]", - mask_token: str = "[MASK]", - start_token: str = "[VS]", - end_token: str = "[VE]", - class_token: str = "[CLS]", - oov_token="-1", - data_dir: str = "data_files", - ): - self.tokenizer = Tokenizer(oov_token=oov_token, filters="", lower=False) - self.mask_token = mask_token - self.pad_token = pad_token - self.special_tokens = ( - [pad_token, mask_token, start_token, end_token, class_token] - + [f"[W_{i}]" for i in range(0, 4)] - + [f"[M_{i}]" for i in range(0, 13)] - + ["[LT]"] - ) - self.data_dir = data_dir - - def fit_on_vocab(self) -> None: - """Fit the tokenizer on the vocabulary.""" - vocab_json_files = glob.glob(os.path.join(self.data_dir, "*_vocab.json")) - for file in vocab_json_files: - vocab = json.load(open(file, "r")) - self.tokenizer.fit_on_texts(vocab) - self.tokenizer.fit_on_texts(self.special_tokens) - - def encode( - self, - concept_sequences: Union[str, Sequence[str]], - is_generator: bool = False, - ) -> Union[int, Sequence[int]]: - """Encode the concept sequences into token ids.""" - return ( - self.tokenizer.texts_to_sequences_generator(concept_sequences) - if is_generator - else self.tokenizer.texts_to_sequences(concept_sequences) - ) - - def decode( - self, - concept_sequence_token_ids: Union[int, Sequence[int]], - ) -> Sequence[str]: - """Decode the concept sequence token ids into concepts.""" - return self.tokenizer.sequences_to_texts(concept_sequence_token_ids) - - def get_all_token_indexes(self) -> set: - all_keys = set(self.tokenizer.index_word.keys()) - - if self.tokenizer.oov_token is not None: - all_keys.remove(self.tokenizer.word_index[self.tokenizer.oov_token]) - - if self.special_tokens is not None: - excluded = set( - [ - self.tokenizer.word_index[special_token] - for special_token in self.special_tokens - ], - ) - all_keys = all_keys - excluded - return all_keys - - def get_first_token_index(self) -> int: - return min(self.get_all_token_indexes()) - - def get_last_token_index(self) -> int: - return max(self.get_all_token_indexes()) - - def get_vocab_size(self) -> int: - # + 1 because oov_token takes the index 0 - return len(self.tokenizer.index_word) + 1 - - def get_pad_token_id(self): - pad_token_id = self.encode(self.pad_token) - while isinstance(pad_token_id, list): - pad_token_id = pad_token_id[0] - return pad_token_id - - def get_mask_token_id(self): - mask_token_id = self.encode(self.mask_token) - while isinstance(mask_token_id, list): - mask_token_id = mask_token_id[0] - return mask_token_id - - def get_special_token_ids(self): - special_ids = self.encode(self.special_tokens) - flat_special_ids = [item[0] for item in special_ids] - return flat_special_ids diff --git a/models/configs/cehr_bert.yaml b/models/configs/cehr_bert.yaml new file mode 100644 index 0000000..c860340 --- /dev/null +++ b/models/configs/cehr_bert.yaml @@ -0,0 +1,42 @@ +model: + embedding_size: 768 + time_embeddings_size: 32 + type_vocab_size: 9 + max_seq_length: 512 + depth: 5 + num_heads: 8 + intermediate_size: 3072 + learning_rate: 5.e-5 + eta_min: 1.e-8 + num_iterations: 10 + increase_factor: 2 + dropout_prob: 0.1 + use_adamw: True +model_finetune: + num_labels: 2 + hidden_size: 768 + classifier_dropout: 0.1 + hidden_dropout_prob: 0.1 + learning_rate: 5.e-6 + eta_min: 1.e-8 + num_iterations: 10 + increase_factor: 2 + use_adamw: True +train: + batch_size: 32 + num_workers: 4 + gpus: 1 + max_epochs: 30 + acc: 1 + mask_prob: 0.15 + persistent_workers: True + pin_memory: True +finetune: + batch_size: 32 + num_workers: 4 + gpus: 1 + max_epochs: 5 + acc: 1 + patience: 2 + persistent_workers: True + pin_memory: True diff --git a/models/configs/cehr_bigbird.yaml b/models/configs/cehr_bigbird.yaml new file mode 100644 index 0000000..4288847 --- /dev/null +++ b/models/configs/cehr_bigbird.yaml @@ -0,0 +1,36 @@ +model: + embedding_size: 768 + time_embeddings_size: 32 + visit_order_size: 3 + type_vocab_size: 8 + max_seq_length: 2048 + depth: 6 + num_heads: 12 + intermediate_size: 3072 + learning_rate: 5.e-5 + eta_min: 1.e-8 + num_iterations: 10 + increase_factor: 2 + dropout_prob: 0.1 +model_finetune: + num_labels: 2 + learning_rate: 5.e-6 + classifier_dropout: 0.1 +train: + batch_size: 12 + num_workers: 4 + gpus: 4 + max_epochs: 10 + acc: 1 + mask_prob: 0.15 + persistent_workers: True + pin_memory: True +finetune: + batch_size: 3 + num_workers: 2 + gpus: 1 + max_epochs: 5 + acc: 1 + patience: 5 + persistent_workers: True + pin_memory: True diff --git a/poetry.lock b/poetry.lock index 058bc5a..ed58e2b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1319,38 +1319,38 @@ files = [ [[package]] name = "mypy" -version = "1.8.0" +version = "1.9.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485a8942f671120f76afffff70f259e1cd0f0cfe08f81c05d8816d958d4577d3"}, - {file = "mypy-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df9824ac11deaf007443e7ed2a4a26bebff98d2bc43c6da21b2b64185da011c4"}, - {file = "mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afecd6354bbfb6e0160f4e4ad9ba6e4e003b767dd80d85516e71f2e955ab50d"}, - {file = "mypy-1.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8963b83d53ee733a6e4196954502b33567ad07dfd74851f32be18eb932fb1cb9"}, - {file = "mypy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e46f44b54ebddbeedbd3d5b289a893219065ef805d95094d16a0af6630f5d410"}, - {file = "mypy-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:855fe27b80375e5c5878492f0729540db47b186509c98dae341254c8f45f42ae"}, - {file = "mypy-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c886c6cce2d070bd7df4ec4a05a13ee20c0aa60cb587e8d1265b6c03cf91da3"}, - {file = "mypy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19c413b3c07cbecf1f991e2221746b0d2a9410b59cb3f4fb9557f0365a1a817"}, - {file = "mypy-1.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9261ed810972061388918c83c3f5cd46079d875026ba97380f3e3978a72f503d"}, - {file = "mypy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:51720c776d148bad2372ca21ca29256ed483aa9a4cdefefcef49006dff2a6835"}, - {file = "mypy-1.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52825b01f5c4c1c4eb0db253ec09c7aa17e1a7304d247c48b6f3599ef40db8bd"}, - {file = "mypy-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5ac9a4eeb1ec0f1ccdc6f326bcdb464de5f80eb07fb38b5ddd7b0de6bc61e55"}, - {file = "mypy-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afe3fe972c645b4632c563d3f3eff1cdca2fa058f730df2b93a35e3b0c538218"}, - {file = "mypy-1.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:42c6680d256ab35637ef88891c6bd02514ccb7e1122133ac96055ff458f93fc3"}, - {file = "mypy-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:720a5ca70e136b675af3af63db533c1c8c9181314d207568bbe79051f122669e"}, - {file = "mypy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:028cf9f2cae89e202d7b6593cd98db6759379f17a319b5faf4f9978d7084cdc6"}, - {file = "mypy-1.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4e6d97288757e1ddba10dd9549ac27982e3e74a49d8d0179fc14d4365c7add66"}, - {file = "mypy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f1478736fcebb90f97e40aff11a5f253af890c845ee0c850fe80aa060a267c6"}, - {file = "mypy-1.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42419861b43e6962a649068a61f4a4839205a3ef525b858377a960b9e2de6e0d"}, - {file = "mypy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b5b6c721bd4aabaadead3a5e6fa85c11c6c795e0c81a7215776ef8afc66de02"}, - {file = "mypy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5c1538c38584029352878a0466f03a8ee7547d7bd9f641f57a0f3017a7c905b8"}, - {file = "mypy-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ef4be7baf08a203170f29e89d79064463b7fc7a0908b9d0d5114e8009c3a259"}, - {file = "mypy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178def594014aa6c35a8ff411cf37d682f428b3b5617ca79029d8ae72f5402b"}, - {file = "mypy-1.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ab3c84fa13c04aeeeabb2a7f67a25ef5d77ac9d6486ff33ded762ef353aa5592"}, - {file = "mypy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:99b00bc72855812a60d253420d8a2eae839b0afa4938f09f4d2aa9bb4654263a"}, - {file = "mypy-1.8.0-py3-none-any.whl", hash = "sha256:538fd81bb5e430cc1381a443971c0475582ff9f434c16cd46d2c66763ce85d9d"}, - {file = "mypy-1.8.0.tar.gz", hash = "sha256:6ff8b244d7085a0b425b56d327b480c3b29cafbd2eff27316a004f9a7391ae07"}, + {file = "mypy-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f8a67616990062232ee4c3952f41c779afac41405806042a8126fe96e098419f"}, + {file = "mypy-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d357423fa57a489e8c47b7c85dfb96698caba13d66e086b412298a1a0ea3b0ed"}, + {file = "mypy-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49c87c15aed320de9b438ae7b00c1ac91cd393c1b854c2ce538e2a72d55df150"}, + {file = "mypy-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:48533cdd345c3c2e5ef48ba3b0d3880b257b423e7995dada04248725c6f77374"}, + {file = "mypy-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:4d3dbd346cfec7cb98e6cbb6e0f3c23618af826316188d587d1c1bc34f0ede03"}, + {file = "mypy-1.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:653265f9a2784db65bfca694d1edd23093ce49740b2244cde583aeb134c008f3"}, + {file = "mypy-1.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3a3c007ff3ee90f69cf0a15cbcdf0995749569b86b6d2f327af01fd1b8aee9dc"}, + {file = "mypy-1.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2418488264eb41f69cc64a69a745fad4a8f86649af4b1041a4c64ee61fc61129"}, + {file = "mypy-1.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:68edad3dc7d70f2f17ae4c6c1b9471a56138ca22722487eebacfd1eb5321d612"}, + {file = "mypy-1.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:85ca5fcc24f0b4aeedc1d02f93707bccc04733f21d41c88334c5482219b1ccb3"}, + {file = "mypy-1.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aceb1db093b04db5cd390821464504111b8ec3e351eb85afd1433490163d60cd"}, + {file = "mypy-1.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0235391f1c6f6ce487b23b9dbd1327b4ec33bb93934aa986efe8a9563d9349e6"}, + {file = "mypy-1.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4d5ddc13421ba3e2e082a6c2d74c2ddb3979c39b582dacd53dd5d9431237185"}, + {file = "mypy-1.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:190da1ee69b427d7efa8aa0d5e5ccd67a4fb04038c380237a0d96829cb157913"}, + {file = "mypy-1.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe28657de3bfec596bbeef01cb219833ad9d38dd5393fc649f4b366840baefe6"}, + {file = "mypy-1.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e54396d70be04b34f31d2edf3362c1edd023246c82f1730bbf8768c28db5361b"}, + {file = "mypy-1.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5e6061f44f2313b94f920e91b204ec600982961e07a17e0f6cd83371cb23f5c2"}, + {file = "mypy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a10926e5473c5fc3da8abb04119a1f5811a236dc3a38d92015cb1e6ba4cb9e"}, + {file = "mypy-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b685154e22e4e9199fc95f298661deea28aaede5ae16ccc8cbb1045e716b3e04"}, + {file = "mypy-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:5d741d3fc7c4da608764073089e5f58ef6352bedc223ff58f2f038c2c4698a89"}, + {file = "mypy-1.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:587ce887f75dd9700252a3abbc9c97bbe165a4a630597845c61279cf32dfbf02"}, + {file = "mypy-1.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f88566144752999351725ac623471661c9d1cd8caa0134ff98cceeea181789f4"}, + {file = "mypy-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61758fabd58ce4b0720ae1e2fea5cfd4431591d6d590b197775329264f86311d"}, + {file = "mypy-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e49499be624dead83927e70c756970a0bc8240e9f769389cdf5714b0784ca6bf"}, + {file = "mypy-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:571741dc4194b4f82d344b15e8837e8c5fcc462d66d076748142327626a1b6e9"}, + {file = "mypy-1.9.0-py3-none-any.whl", hash = "sha256:a260627a570559181a9ea5de61ac6297aa5af202f06fd7ab093ce74e7181e43e"}, + {file = "mypy-1.9.0.tar.gz", hash = "sha256:3cc5da0127e6a478cddd906068496a97a7618a21ce9b54bde5bf7e539c7af974"}, ] [package.dependencies] @@ -1398,13 +1398,13 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] [[package]] name = "nbqa" -version = "1.8.3" +version = "1.8.4" description = "Run any standard Python code quality tool on a Jupyter Notebook" optional = false python-versions = ">=3.8.0" files = [ - {file = "nbqa-1.8.3-py3-none-any.whl", hash = "sha256:54d174c785d604a95c188b027717cd0d92b217de9dd77374d78f6c49319e1dc7"}, - {file = "nbqa-1.8.3.tar.gz", hash = "sha256:985d252bf3fb56558b138ebd306f773a3f9c659aed5fc6f9be5601471e230225"}, + {file = "nbqa-1.8.4-py3-none-any.whl", hash = "sha256:0e2acd73320ad1aa56f15200f9ea517c0ecb5ac388d217aee97fab66272c604b"}, + {file = "nbqa-1.8.4.tar.gz", hash = "sha256:ca983e115d81f5cf149f70c4bf5b8baa36694a3eaf0783fe508dbf05b9767e07"}, ] [package.dependencies] @@ -1561,13 +1561,13 @@ wheel = "*" [[package]] name = "packaging" -version = "23.2" +version = "24.0" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] [[package]] @@ -1694,6 +1694,21 @@ files = [ docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] +[[package]] +name = "plotly" +version = "5.19.0" +description = "An open-source, interactive data visualization library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "plotly-5.19.0-py3-none-any.whl", hash = "sha256:906abcc5f15945765328c5d47edaa884bc99f5985fbc61e8cd4dc361f4ff8f5a"}, + {file = "plotly-5.19.0.tar.gz", hash = "sha256:5ea91a56571292ade3e3bc9bf712eba0b95a1fb0a941375d978cc79432e055f4"}, +] + +[package.dependencies] +packaging = "*" +tenacity = ">=6.2.0" + [[package]] name = "pluggy" version = "1.4.0" @@ -1836,6 +1851,54 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pyarrow" +version = "15.0.1" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-15.0.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:c2ddb3be5ea938c329a84171694fc230b241ce1b6b0ff1a0280509af51c375fa"}, + {file = "pyarrow-15.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7543ea88a0ff72f8e6baaf9bfdbec2c62aeabdbede9e4a571c71cc3bc43b6302"}, + {file = "pyarrow-15.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1519e218a6941fc074e4501088d891afcb2adf77c236e03c34babcf3d6a0d1c7"}, + {file = "pyarrow-15.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28cafa86e1944761970d3b3fc0411b14ff9b5c2b73cd22aaf470d7a3976335f5"}, + {file = "pyarrow-15.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:be5c3d463e33d03eab496e1af7916b1d44001c08f0f458ad27dc16093a020638"}, + {file = "pyarrow-15.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:47b1eda15d3aa3f49a07b1808648e1397e5dc6a80a30bf87faa8e2d02dad7ac3"}, + {file = "pyarrow-15.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:e524a31be7db22deebbbcf242b189063ab9a7652c62471d296b31bc6e3cae77b"}, + {file = "pyarrow-15.0.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:a476fefe8bdd56122fb0d4881b785413e025858803cc1302d0d788d3522b374d"}, + {file = "pyarrow-15.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:309e6191be385f2e220586bfdb643f9bb21d7e1bc6dd0a6963dc538e347b2431"}, + {file = "pyarrow-15.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83bc586903dbeb4365cbc72b602f99f70b96c5882e5dfac5278813c7d624ca3c"}, + {file = "pyarrow-15.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07e652daac6d8b05280cd2af31c0fb61a4490ec6a53dc01588014d9fa3fdbee9"}, + {file = "pyarrow-15.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:abad2e08652df153a72177ce20c897d083b0c4ebeec051239e2654ddf4d3c996"}, + {file = "pyarrow-15.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:cde663352bc83ad75ba7b3206e049ca1a69809223942362a8649e37bd22f9e3b"}, + {file = "pyarrow-15.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:1b6e237dd7a08482a8b8f3f6512d258d2460f182931832a8c6ef3953203d31e1"}, + {file = "pyarrow-15.0.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:7bd167536ee23192760b8c731d39b7cfd37914c27fd4582335ffd08450ff799d"}, + {file = "pyarrow-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c08bb31eb2984ba5c3747d375bb522e7e536b8b25b149c9cb5e1c49b0ccb736"}, + {file = "pyarrow-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0f9c1d630ed2524bd1ddf28ec92780a7b599fd54704cd653519f7ff5aec177a"}, + {file = "pyarrow-15.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5186048493395220550bca7b524420471aac2d77af831f584ce132680f55c3df"}, + {file = "pyarrow-15.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:31dc30c7ec8958da3a3d9f31d6c3630429b2091ede0ecd0d989fd6bec129f0e4"}, + {file = "pyarrow-15.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:3f111a014fb8ac2297b43a74bf4495cc479a332908f7ee49cb7cbd50714cb0c1"}, + {file = "pyarrow-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:a6d1f7c15d7f68f08490d0cb34611497c74285b8a6bbeab4ef3fc20117310983"}, + {file = "pyarrow-15.0.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:9ad931b996f51c2f978ed517b55cb3c6078272fb4ec579e3da5a8c14873b698d"}, + {file = "pyarrow-15.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:738f6b53ab1c2f66b2bde8a1d77e186aeaab702d849e0dfa1158c9e2c030add3"}, + {file = "pyarrow-15.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c1c3fc16bc74e33bf8f1e5a212938ed8d88e902f372c4dac6b5bad328567d2f"}, + {file = "pyarrow-15.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1fa92512128f6c1b8dde0468c1454dd70f3bff623970e370d52efd4d24fd0be"}, + {file = "pyarrow-15.0.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:b4157f307c202cbbdac147d9b07447a281fa8e63494f7fc85081da351ec6ace9"}, + {file = "pyarrow-15.0.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:b75e7da26f383787f80ad76143b44844ffa28648fcc7099a83df1538c078d2f2"}, + {file = "pyarrow-15.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:3a99eac76ae14096c209850935057b9e8ce97a78397c5cde8724674774f34e5d"}, + {file = "pyarrow-15.0.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:dd532d3177e031e9b2d2df19fd003d0cc0520d1747659fcabbd4d9bb87de508c"}, + {file = "pyarrow-15.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ce8c89848fd37e5313fc2ce601483038ee5566db96ba0808d5883b2e2e55dc53"}, + {file = "pyarrow-15.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:862eac5e5f3b6477f7a92b2f27e560e1f4e5e9edfca9ea9da8a7478bb4abd5ce"}, + {file = "pyarrow-15.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f0ea3a29cd5cb99bf14c1c4533eceaa00ea8fb580950fb5a89a5c771a994a4e"}, + {file = "pyarrow-15.0.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bb902f780cfd624b2e8fd8501fadab17618fdb548532620ef3d91312aaf0888a"}, + {file = "pyarrow-15.0.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:4f87757f02735a6bb4ad2e1b98279ac45d53b748d5baf52401516413007c6999"}, + {file = "pyarrow-15.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:efd3816c7fbfcbd406ac0f69873cebb052effd7cdc153ae5836d1b00845845d7"}, + {file = "pyarrow-15.0.1.tar.gz", hash = "sha256:21d812548d39d490e0c6928a7c663f37b96bf764034123d4b4ab4530ecc757a9"}, +] + +[package.dependencies] +numpy = ">=1.16.6,<2" + [[package]] name = "pycodestyle" version = "2.11.1" @@ -2237,39 +2300,39 @@ files = [ [[package]] name = "ruff" -version = "0.2.2" +version = "0.3.2" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0a9efb032855ffb3c21f6405751d5e147b0c6b631e3ca3f6b20f917572b97eb6"}, - {file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d450b7fbff85913f866a5384d8912710936e2b96da74541c82c1b458472ddb39"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecd46e3106850a5c26aee114e562c329f9a1fbe9e4821b008c4404f64ff9ce73"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e22676a5b875bd72acd3d11d5fa9075d3a5f53b877fe7b4793e4673499318ba"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1695700d1e25a99d28f7a1636d85bafcc5030bba9d0578c0781ba1790dbcf51c"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:b0c232af3d0bd8f521806223723456ffebf8e323bd1e4e82b0befb20ba18388e"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f63d96494eeec2fc70d909393bcd76c69f35334cdbd9e20d089fb3f0640216ca"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a61ea0ff048e06de273b2e45bd72629f470f5da8f71daf09fe481278b175001"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1439c8f407e4f356470e54cdecdca1bd5439a0673792dbe34a2b0a551a2fe3"}, - {file = "ruff-0.2.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:940de32dc8853eba0f67f7198b3e79bc6ba95c2edbfdfac2144c8235114d6726"}, - {file = "ruff-0.2.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0c126da55c38dd917621552ab430213bdb3273bb10ddb67bc4b761989210eb6e"}, - {file = "ruff-0.2.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3b65494f7e4bed2e74110dac1f0d17dc8e1f42faaa784e7c58a98e335ec83d7e"}, - {file = "ruff-0.2.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1ec49be4fe6ddac0503833f3ed8930528e26d1e60ad35c2446da372d16651ce9"}, - {file = "ruff-0.2.2-py3-none-win32.whl", hash = "sha256:d920499b576f6c68295bc04e7b17b6544d9d05f196bb3aac4358792ef6f34325"}, - {file = "ruff-0.2.2-py3-none-win_amd64.whl", hash = "sha256:cc9a91ae137d687f43a44c900e5d95e9617cb37d4c989e462980ba27039d239d"}, - {file = "ruff-0.2.2-py3-none-win_arm64.whl", hash = "sha256:c9d15fc41e6054bfc7200478720570078f0b41c9ae4f010bcc16bd6f4d1aacdd"}, - {file = "ruff-0.2.2.tar.gz", hash = "sha256:e62ed7f36b3068a30ba39193a14274cd706bc486fad521276458022f7bccb31d"}, + {file = "ruff-0.3.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:77f2612752e25f730da7421ca5e3147b213dca4f9a0f7e0b534e9562c5441f01"}, + {file = "ruff-0.3.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9966b964b2dd1107797be9ca7195002b874424d1d5472097701ae8f43eadef5d"}, + {file = "ruff-0.3.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b83d17ff166aa0659d1e1deaf9f2f14cbe387293a906de09bc4860717eb2e2da"}, + {file = "ruff-0.3.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb875c6cc87b3703aeda85f01c9aebdce3d217aeaca3c2e52e38077383f7268a"}, + {file = "ruff-0.3.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be75e468a6a86426430373d81c041b7605137a28f7014a72d2fc749e47f572aa"}, + {file = "ruff-0.3.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:967978ac2d4506255e2f52afe70dda023fc602b283e97685c8447d036863a302"}, + {file = "ruff-0.3.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1231eacd4510f73222940727ac927bc5d07667a86b0cbe822024dd00343e77e9"}, + {file = "ruff-0.3.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c6d613b19e9a8021be2ee1d0e27710208d1603b56f47203d0abbde906929a9b"}, + {file = "ruff-0.3.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8439338a6303585d27b66b4626cbde89bb3e50fa3cae86ce52c1db7449330a7"}, + {file = "ruff-0.3.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:de8b480d8379620cbb5ea466a9e53bb467d2fb07c7eca54a4aa8576483c35d36"}, + {file = "ruff-0.3.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b74c3de9103bd35df2bb05d8b2899bf2dbe4efda6474ea9681280648ec4d237d"}, + {file = "ruff-0.3.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f380be9fc15a99765c9cf316b40b9da1f6ad2ab9639e551703e581a5e6da6745"}, + {file = "ruff-0.3.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0ac06a3759c3ab9ef86bbeca665d31ad3aa9a4b1c17684aadb7e61c10baa0df4"}, + {file = "ruff-0.3.2-py3-none-win32.whl", hash = "sha256:9bd640a8f7dd07a0b6901fcebccedadeb1a705a50350fb86b4003b805c81385a"}, + {file = "ruff-0.3.2-py3-none-win_amd64.whl", hash = "sha256:0c1bdd9920cab5707c26c8b3bf33a064a4ca7842d91a99ec0634fec68f9f4037"}, + {file = "ruff-0.3.2-py3-none-win_arm64.whl", hash = "sha256:5f65103b1d76e0d600cabd577b04179ff592064eaa451a70a81085930e907d0b"}, + {file = "ruff-0.3.2.tar.gz", hash = "sha256:fa78ec9418eb1ca3db392811df3376b46471ae93792a81af2d1cbb0e5dcb5142"}, ] [[package]] name = "sentry-sdk" -version = "1.40.6" +version = "1.41.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.40.6.tar.gz", hash = "sha256:f143f3fb4bb57c90abef6e2ad06b5f6f02b2ca13e4060ec5c0549c7a9ccce3fa"}, - {file = "sentry_sdk-1.40.6-py2.py3-none-any.whl", hash = "sha256:becda09660df63e55f307570e9817c664392655a7328bbc414b507e9cb874c67"}, + {file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"}, + {file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"}, ] [package.dependencies] @@ -2295,7 +2358,7 @@ huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] -pure-eval = ["asttokens", "executing", "pure_eval"] +pure-eval = ["asttokens", "executing", "pure-eval"] pymongo = ["pymongo (>=3.1)"] pyspark = ["pyspark (>=2.4.4)"] quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] @@ -2550,6 +2613,20 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "tenacity" +version = "8.2.3" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"}, + {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"}, +] + +[package.extras] +doc = ["reno", "sphinx", "tornado (>=4.5)"] + [[package]] name = "tokenize-rt" version = "5.2.0" @@ -2928,4 +3005,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.11" -content-hash = "bf2f52893fd2ea7968ee7af64e8e557b47b9d7da465c1b3f21c19ff4868df9de" +content-hash = "d53220a87080d84cb8b378a43818ed754d0893350efb01ae7d5fb17f6a4a9fb6" diff --git a/pretrain.py b/pretrain.py index 3295683..9ae7140 100644 --- a/pretrain.py +++ b/pretrain.py @@ -1,9 +1,10 @@ +"""Train the model.""" + import argparse import os -from os.path import join +import sys +from typing import Any, Dict -import numpy as np -import pandas as pd import pytorch_lightning as pl import torch from lightning.pytorch.loggers import WandbLogger @@ -12,47 +13,34 @@ from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader -from models.cehr_bert.data import PretrainDataset +from lib.data import PretrainDataset +from lib.tokenizer import ConceptTokenizer +from lib.utils import ( + get_latest_checkpoint, + get_run_id, + load_config, + load_pretrain_data, + seed_everything, +) +from models.big_bird_cehr.model import BigBirdPretrain from models.cehr_bert.model import BertPretrain -from models.cehr_bert.tokenizer import ConceptTokenizer - -def main(args): - torch.manual_seed(args.seed) - np.random.seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - pl.seed_everything(args.seed) +def main(args: Dict[str, Any], model_config: Dict[str, Any]) -> None: + """Train the model.""" + seed_everything(args.seed) os.environ["CUDA_LAUNCH_BLOCKING"] = "1" torch.cuda.empty_cache() torch.set_float32_matmul_precision("medium") - # if not args.resume: - # data = pd.read_parquet(join(args.data_dir, "patient_sequences.parquet")) - - # data["label"] = ( - # (data["death_after_end"] >= 0) & (data["death_after_end"] < 365) - # ).astype(int) - # neg_pos_data = data[(data["deceased"] == 0) | (data["label"] == 1)] - - # pre_data = data[~data.index.isin(neg_pos_data.index)] - - # pre_df, fine_df = train_test_split( - # neg_pos_data, - # test_size=args.finetune_size, - # random_state=args.seed, - # stratify=neg_pos_data["label"], - # ) - - # pre_data = pd.concat([pre_data, pre_df]) - - # fine_df.to_parquet(join(args.data_dir, "fine_tune.parquet")) - # pre_data.to_parquet(join(args.data_dir, "pretrain.parquet")) - # else: - pre_data = pd.read_parquet(join(args.data_dir, "pretrain.parquet")) + pre_data = load_pretrain_data( + args.data_dir, + args.sequence_file, + args.id_file, + ) + pre_data.rename(columns={args.label_name: "label"}, inplace=True) + # Split data pre_train, pre_val = train_test_split( pre_data, test_size=args.val_size, @@ -60,15 +48,18 @@ def main(args): stratify=pre_data["label"], ) - tokenizer = ConceptTokenizer(data_dir=args.data_dir) + # Train Tokenizer + tokenizer = ConceptTokenizer(data_dir=args.vocab_dir) tokenizer.fit_on_vocab() + # Load datasets train_dataset = PretrainDataset( data=pre_train, tokenizer=tokenizer, max_len=args.max_len, mask_prob=args.mask_prob, ) + args.dataset_len = len(train_dataset) val_dataset = PretrainDataset( data=pre_val, @@ -81,14 +72,17 @@ def main(args): train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, + persistent_workers=args.persistent_workers, shuffle=True, - pin_memory=True, + pin_memory=args.pin_memory, ) + val_loader = DataLoader( val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, - pin_memory=True, + persistent_workers=args.persistent_workers, + pin_memory=args.pin_memory, ) callbacks = [ @@ -103,113 +97,147 @@ def main(args): ), LearningRateMonitor(logging_interval="step"), ] + + # Create model + if args.model_type == "cehr_bert": + model = BertPretrain( + args=args, + vocab_size=tokenizer.get_vocab_size(), + padding_idx=tokenizer.get_pad_token_id(), + **model_config, + ) + elif args.model_type == "cehr_bigbird": + model = BigBirdPretrain( + args=args, + vocab_size=tokenizer.get_vocab_size(), + padding_idx=tokenizer.get_pad_token_id(), + **model_config, + ) + + latest_checkpoint = get_latest_checkpoint(args.checkpoint_dir) + + run_id = get_run_id(args.checkpoint_dir, retrieve=(latest_checkpoint is not None)) + wandb_logger = WandbLogger( - project="pretrain", + project=args.exp_name, save_dir=args.log_dir, + entity=args.workspace_name, + id=run_id, + resume="allow", ) + + # Setup PyTorchLightning trainer trainer = pl.Trainer( accelerator="gpu", devices=args.gpus, strategy=DDPStrategy(find_unused_parameters=True) if args.gpus > 1 else "auto", - precision=16, + precision="16-mixed", check_val_every_n_epoch=1, max_epochs=args.max_epochs, callbacks=callbacks, + deterministic=False, + enable_checkpointing=True, + enable_progress_bar=True, + enable_model_summary=True, logger=wandb_logger, - resume_from_checkpoint=args.checkpoint_path if args.resume else None, log_every_n_steps=args.log_every_n_steps, + accumulate_grad_batches=args.acc, + gradient_clip_val=1.0, ) - model = BertPretrain( - vocab_size=tokenizer.get_vocab_size(), - padding_idx=tokenizer.get_pad_token_id(), - ) - + # Train the model trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, + ckpt_path=latest_checkpoint if latest_checkpoint else None, ) if __name__ == "__main__": parser = argparse.ArgumentParser() + # project configuration parser.add_argument( - "--seed", - type=int, - default=42, - help="Random seed for reproducibility", + "--model-type", + type=str, + required=True, + help="Model type: 'cehr_bert' or 'cehr_bigbird'", ) parser.add_argument( - "--resume", - action="store_true", - help="Flag to resume training from a checkpoint", + "--exp-name", + type=str, + required=True, + help="Path to model config file", ) parser.add_argument( - "--data_dir", + "--label-name", type=str, - default="data_files", - help="Path to the data directory", + required=True, + help="Name of the label column", ) parser.add_argument( - "--finetune_size", - type=float, - default=0.1, - help="Finetune dataset size for splitting the data", + "--workspace-name", + type=str, + default=None, + help="Name of the Wandb workspace", ) parser.add_argument( - "--val_size", - type=float, - default=0.1, - help="Validation set size for splitting the data", + "--config-dir", + type=str, + default="models/configs", + help="Path to model config file", ) + + # data-related arguments parser.add_argument( - "--max_len", - type=int, - default=512, - help="Maximum length of the sequence", + "--data-dir", + type=str, + default="data_files", + help="Path to the data directory", ) parser.add_argument( - "--mask_prob", - type=float, - default=0.15, - help="Probability of masking the token", + "--sequence-file", + type=str, + default="patient_sequences_2048_labeled.parquet", + help="Path to the patient sequence file", ) parser.add_argument( - "--batch_size", - type=int, - default=32, - help="Batch size for training", + "--id-file", + type=str, + default="dataset_2048_mortality_1month.pkl", + help="Path to the patient id file", ) parser.add_argument( - "--num_workers", - type=int, - default=4, - help="Number of workers for training", + "--vocab-dir", + type=str, + default="data_files/vocab", + help="Path to the vocabulary directory of json files", + ) + parser.add_argument( + "--val-size", + type=float, + default=0.1, + help="Validation set size for splitting the data", ) + + # checkpointing and loggig arguments parser.add_argument( - "--checkpoint_dir", + "--checkpoint-dir", type=str, - default="checkpoints/pretraining", - help="Path to the training checkpoint", + default="checkpoints", + help="Path to the checkpoint directory", ) parser.add_argument( - "--log_dir", + "--log-dir", type=str, default="logs", help="Path to the log directory", ) parser.add_argument( - "--gpus", - type=int, - default=2, - help="Number of gpus for training", - ) - parser.add_argument( - "--max_epochs", - type=int, - default=50, - help="Number of epochs for training", + "--checkpoint-path", + type=str, + default=None, + help="Checkpoint to resume training from", ) parser.add_argument( "--log_every_n_steps", @@ -217,12 +245,33 @@ def main(args): default=10, help="Number of steps to log the training", ) + + # Other arguments parser.add_argument( - "--checkpoint_path", - type=str, - default=None, - help="Checkpoint to resume training from", + "--seed", + type=int, + default=42, + help="Random seed for reproducibility", ) args = parser.parse_args() - main(args) + + if args.model_type not in ["cehr_bert", "cehr_bigbird"]: + print("Invalid model type. Choose 'cehr_bert' or 'cehr_bigbird'.") + sys.exit(1) + + args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.exp_name) + os.makedirs(args.checkpoint_dir, exist_ok=True) + os.makedirs(args.log_dir, exist_ok=True) + + config = load_config(args.config_dir, args.model_type) + + train_config = config["train"] + for key, value in train_config.items(): + if not hasattr(args, key) or getattr(args, key) is None: + setattr(args, key, value) + + model_config = config["model"] + args.max_len = model_config["max_seq_length"] + + main(args, model_config) diff --git a/pretrain_bigbird.py b/pretrain_bigbird.py deleted file mode 100644 index 29e9661..0000000 --- a/pretrain_bigbird.py +++ /dev/null @@ -1,276 +0,0 @@ -""" -File: pretrain_bigbird.py. - -Pretrain a bigbird model on MIMIC-IV FHIR data using Masked Language Modeling objective. -""" - -import argparse -import glob -import os -import pickle -from os.path import join -from typing import Any, Dict - -import numpy as np -import pandas as pd -import pytorch_lightning as pl -import torch -from lightning.pytorch.loggers import WandbLogger -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.strategies.ddp import DDPStrategy -from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader - -from models.big_bird_cehr.data import PretrainDataset -from models.big_bird_cehr.model import BigBirdPretrain -from models.big_bird_cehr.tokenizer import HuggingFaceConceptTokenizer - - -def seed_everything(seed: int) -> None: - """Seed all components of the model.""" - torch.manual_seed(seed) - np.random.seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - pl.seed_everything(seed) - - -def get_latest_checkpoint(checkpoint_dir: str) -> Any: - """Return the most recent checkpointed file to resume training from.""" - list_of_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt")) - return max(list_of_files, key=os.path.getctime) if list_of_files else None - - -def main(args: Dict[str, Any]) -> None: - """Train the model.""" - # Setup environment - seed_everything(args.seed) - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - torch.cuda.empty_cache() - torch.set_float32_matmul_precision("medium") - - # Load data - data = pd.read_parquet( - join(args.data_dir, "patient_sequences_2048_labeled.parquet"), - ) - patient_ids = pickle.load( - open(join(args.data_dir, "dataset_2048_mortality_1month.pkl"), "rb"), - ) - pre_data = data.loc[data["patient_id"].isin(patient_ids["pretrain"])] - - # Split data - pre_train, pre_val = train_test_split( - pre_data, - test_size=args.val_size, - random_state=args.seed, - stratify=pre_data["label_mortality_1month"], - ) - - # Train Tokenizer - tokenizer = HuggingFaceConceptTokenizer(data_dir=args.vocab_dir) - tokenizer.fit_on_vocab() - - # Load datasets - train_dataset = PretrainDataset( - data=pre_train, - tokenizer=tokenizer, - max_len=args.max_len, - mask_prob=args.mask_prob, - ) - - val_dataset = PretrainDataset( - data=pre_val, - tokenizer=tokenizer, - max_len=args.max_len, - mask_prob=args.mask_prob, - ) - - train_loader = DataLoader( - train_dataset, - batch_size=args.batch_size, - num_workers=args.num_workers, - persistent_workers=True, - shuffle=True, - pin_memory=True, - ) - - val_loader = DataLoader( - val_dataset, - batch_size=args.batch_size, - num_workers=args.num_workers, - persistent_workers=True, - pin_memory=True, - ) - - # Setup model dependencies - callbacks = [ - ModelCheckpoint( - monitor="val_loss", - mode="min", - filename="best", - save_top_k=1, - save_last=True, - verbose=True, - dirpath=args.checkpoint_dir, - ), - LearningRateMonitor(logging_interval="step"), - ] - - wandb_logger = WandbLogger( - project="bigbird_pretrain_a100", - save_dir=args.log_dir, - ) - - # Load latest checkpoint to continue training - # latest_checkpoint = get_latest_checkpoint(args.checkpoint_path) - - # Setup PyTorchLightning trainer - trainer = pl.Trainer( - accelerator="gpu", - devices=args.gpus, - strategy=DDPStrategy(find_unused_parameters=True) if args.gpus > 1 else "auto", - precision="16-mixed", - check_val_every_n_epoch=1, - max_epochs=args.max_epochs, - callbacks=callbacks, - deterministic=False, - enable_checkpointing=True, - enable_progress_bar=True, - enable_model_summary=True, - logger=wandb_logger, - # resume_from_checkpoint=latest_checkpoint if args.resume else None, - log_every_n_steps=args.log_every_n_steps, - accumulate_grad_batches=args.acc, - gradient_clip_val=1.0, - ) - - # Create BigBird model - model = BigBirdPretrain( - args=args, - dataset_len=len(train_dataset), - vocab_size=tokenizer.get_vocab_size(), - padding_idx=tokenizer.get_pad_token_id(), - ) - - # Train the model - trainer.fit( - model=model, - train_dataloaders=train_loader, - val_dataloaders=val_loader, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--seed", - type=int, - default=42, - help="Random seed for reproducibility", - ) - parser.add_argument( - "--resume", - action="store_true", - default=False, - help="Flag to resume training from a checkpoint", - ) - parser.add_argument( - "--data_dir", - type=str, - default="/h/afallah/odyssey/odyssey/data/bigbird_data", - help="Path to the data directory", - ) - parser.add_argument( - "--vocab_dir", - type=str, - default="/h/afallah/odyssey/odyssey/data/vocab", - help="Path to the vocabulary directory of json files", - ) - parser.add_argument( - "--finetune_size", - type=float, - default=0.1, - help="Finetune dataset size for splitting the data", - ) - parser.add_argument( - "--val_size", - type=float, - default=0.1, - help="Validation set size for splitting the data", - ) - parser.add_argument( - "--max_len", - type=int, - default=2048, - help="Maximum length of the sequence", - ) - parser.add_argument( - "--mask_prob", - type=float, - default=0.15, - help="Probability of masking the token", - ) - parser.add_argument( - "--batch_size", - type=int, - default=12, - help="Batch size for training", - ) - parser.add_argument( - "--num_workers", - type=int, - default=4, - help="Number of workers for training", - ) - parser.add_argument( - "--checkpoint_dir", - type=str, - default="checkpoints/bigbird_pretraining_a100", - help="Path to the training checkpoint", - ) - parser.add_argument( - "--log_dir", - type=str, - default="logs", - help="Path to the log directory", - ) - parser.add_argument( - "--gpus", - type=int, - default=4, - help="Number of gpus for training", - ) - parser.add_argument( - "--max_epochs", - type=int, - default=10, - help="Number of epochs for training", - ) - parser.add_argument( - "--acc", - type=int, - default=1, - help="Gradient accumulation", - ) - parser.add_argument( - "--log_every_n_steps", - type=int, - default=10, - help="Number of steps to log the training", - ) - parser.add_argument( - "--checkpoint_path", - type=str, - default=None, - help="Checkpoint to resume training from", - ) - parser.add_argument( - "--output_model_path", - type=str, - default="/h/afallah/odyssey/odyssey/bigbird_pretrained_2048.pt", - help="Directory to save the model", - ) - - args = parser.parse_args() - main(args) diff --git a/pyproject.toml b/pyproject.toml index a0fe3a6..8f4d8d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,8 @@ pandas = "^2.2.1" sqlalchemy = "^2.0.28" psycopg2 = "^2.9.9" fhir-resources = "^5.1.1" +pyarrow = "^15.0.1" +plotly = "^5.7.0" [tool.poetry.group.test] optional = true @@ -31,7 +33,7 @@ pytest-cov = "^3.0.0" codecov = "^2.1.13" nbstripout = "^0.6.1" mypy = "^1.7.0" -ruff = "^0.2.0" +ruff = "^0.3.0" nbqa = { version = "^1.7.0", extras = ["toolchain"] } [tool.mypy]