From b317b2c8b2eceec36b1e6f111fb201a9d9f843e1 Mon Sep 17 00:00:00 2001 From: Michael Pieler <36303596+MicPie@users.noreply.github.com> Date: Sat, 10 Aug 2024 17:51:12 +0200 Subject: [PATCH] Class balanced sampling (#524) Co-authored-by: Michael Pieler <36303596+MicPie@users.noreply.github.com> Co-authored-by: Michael Pieler Co-authored-by: Kevin Maik Jablonka Co-authored-by: Kevin M Jablonka <32935233+kjappelbaum@users.noreply.github.com> --- data/text_sampling/text_sampling.py | 488 +++++++++++++++++++--------- 1 file changed, 329 insertions(+), 159 deletions(-) diff --git a/data/text_sampling/text_sampling.py b/data/text_sampling/text_sampling.py index 94dadbe33..0f2a5e522 100644 --- a/data/text_sampling/text_sampling.py +++ b/data/text_sampling/text_sampling.py @@ -542,6 +542,7 @@ def check_targets_and_identifiers(meta: dict, df: pd.DataFrame): if "split" not in df.columns: df["split"] = "train" self.df = df + self.df_orig = None # only used for class_balanced sampling to keep a copy of the original self.df # text templates self.benchmarking_templates = benchmarking_templates @@ -906,12 +907,59 @@ def sample(self, sample: pd.Series, template_idx: int = None): def __getitem__(self, sample_idx: int, template_idx: int = None): """Get item from data with sample and template index. - A random template will be ised if no template index is handed over.""" + A random template will be used if no template index is handed over.""" sample = self.df.iloc[sample_idx] return self.sample(sample, template_idx) - def apply_sampling(self, template_idx: int = None): + def apply_sampling( + self, template_idx: int = None, class_balanced: bool = True + ): # TODO: set class_balanced to False !!! """Applies the sampling to the entire data frame.""" + if template_idx is not None and class_balanced is True: + # create a copy of the original self.df to restore self.df after class balanced sampling + if self.df_orig is None: + self.df_orig = self.df.copy() + + # get targets for balancing + template = self.get_prompt_template_from_template_idx(template_idx) + target_to_balance = [] + for target in self.meta["targets"]: + for var in template.input_variables: + if (target["id"] in var.replace("#", "")) or ( + target["id"] in var.replace("%", "") + ): + # print(f"{target['id']=}") + target_to_balance.append(target["id"]) + target_to_balance = list(set(target_to_balance)) + + # create class balanced self.df + if len(target_to_balance) > 1: + print("TEMPLATE USES MORE THAN ONE TARGET!") + print(f"{target_to_balance=}") + target_to_balance = random.sample(target_to_balance, k=1)[0] + print(f"{target_to_balance=}") + else: + # unwrap list of length 1 + target_to_balance = target_to_balance[0] + df_vc = self.df_orig[target_to_balance].value_counts() + vc_min = df_vc.min() + vc_max = df_vc.max() + if vc_max > 1: + dfs = [] + # cycle through all values and get a sample of size vc_min + for values in df_vc.index.tolist(): + dfs.append( + self.df_orig[self.df_orig[target_to_balance] == values].sample( + vc_min + ) + ) + self.df = pd.concat(dfs) + else: + self.df = self.df_orig + print(self.df[target_to_balance].value_counts()) + # else: + # assert template_idx is None and class_balanced is True, "class_balanced sampling is only supported with template_idx." # noqa: E501 + self.df["sample"] = self.df.apply( lambda sample: self.sample(sample, template_idx), axis=1 ) @@ -1047,84 +1095,226 @@ def export(self, fn_suffix: str = None): return pd.DataFrame(print_data) def apply_sampling_and_export( - self, template_idx: int = None, fn_suffix: str = None + self, + template_idx: int = None, + fn_suffix: str = None, + class_balanced=True, ): """Applies the sampling and exports the data.""" - self.apply_sampling(template_idx=template_idx) + self.apply_sampling(template_idx=template_idx, class_balanced=class_balanced) df_results = self.export(fn_suffix=fn_suffix) + + # if class_balanced restore self.df to original df that is not balanced + if class_balanced: + self.df = self.df_orig + print(f"\n### results\n{df_results.to_string()}") if __name__ == "__main__": path_base = __file__.replace("text_sampling/text_sampling.py", "") - path_data_dir = sorted(glob.glob(path_base + "tabular/*")) - path_data_dir += sorted( - [p for p in glob.glob(path_base + "kg/*") if os.path.isdir(p)] - ) - path_lm_eval_data_dir = path_base + "text_sampling/export" - - # index = [i for i, x in enumerate(path_data_dir) if x.find("RedDB") != -1][0] + # path_data_dir = sorted(glob.glob(path_base + "tabular/*")) + # path_data_dir += sorted( + # [p for p in glob.glob(path_base + "kg/*") if os.path.isdir(p)] + # ) + # path_lm_eval_data_dir = path_base + "text_sampling/export_class_balanced" + # path_lm_eval_data_dir = path_base + "text_sampling/export_class_balanced_benchmark" + # path_lm_eval_data_dir = path_base + "text_sampling/export_standard" + path_lm_eval_data_dir = path_base + "text_sampling/export_standard_benchmark" + # path_lm_eval_data_dir = path_base + "text_sampling/export_inverse" + + path_data_dir = [ + # CLASS BALANCED DATASETS, set class_balanced = True !!! + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/ames_mutagenicity", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/aminoacids", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/BACE", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/BBBP", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bc5chem", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bc5disease", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_10", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_11", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_12", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_13", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_14", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_15", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_16", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_17", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_18", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_19", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_2", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_20", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_22", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_23", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_24", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_25", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_26", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_27", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_28", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_29", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_3", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_30", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_31", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_33", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_34", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_35", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_36", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_37", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_38", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_39", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_4", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_40", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_47", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_48", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_5", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_52", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_57", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_58", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_6", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_7", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_8", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bio_ner_9", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bioavailability_ma_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/block_polymers_morphology", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/blood_brain_barrier_martins_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/caco2_wang", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/carcinogens", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cav3_t-type_calcium_channels_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/chemcaption_fragments", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/chemcaption_rdkit", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/chemdner", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/choline_transporter_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/clearance_astrazeneca", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/clintox", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cyp_p450_1a2_inhibition_veith_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cyp_p450_2c19_inhibition_veith_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cyp_p450_2c9_inhibition_veith_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cyp_p450_2d6_inhibition_veith_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cyp_p450_3a4_inhibition_veith_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cyp2c9_substrate_carbonmangels", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cyp2d6_substrate_carbonmangels", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/cyp3a4_substrate_carbonmangels", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/drug_induced_liver_injury", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/flashpoint", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/formation_energies", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/freesolv", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/half_life_obach", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/herg_blockers", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/herg_central_at_10uM", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/herg_central_at_1uM", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/herg_central_inhib", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/herg_karim_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/hiv", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/human_intestinal_absorption", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/kcnq2_potassium_channel_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/ld50_zhu", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/lipophilicity", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/m1_muscarinic_receptor_agonists_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/m1_muscarinic_receptor_antagonists_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/melting_points", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/mona", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/mp_anisotropy", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/mp_bulk_modulus", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/mp_shear_modulus", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_466", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_548", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_600", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_644", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_652", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_689", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_692", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_712", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_713", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_733", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_737", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_810", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_832", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_846", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_852", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_858", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/MUV_859", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/ncbi_disease", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/nr_ahr_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/nr_ar_lbd_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/nr_ar_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/nr_aromatase_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/nr_er_lbd_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/nr_er_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/nr_ppar_gamma_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/ocp", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/opv", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/oqmd", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/ord_rxn_smiles_yield_pred", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/ord_steps_yield", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/orexin1_receptor_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/p_glycoprotein_inhibition_broccatelli_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/pampa_ncats", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/peptides_hemolytic", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/peptides_nonfouling", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/peptides_soluble", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/potassium_ion_channel_kir2_1_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/qm8", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/qm9", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/qmof_gcmc", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/qmof_quantum", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/rhea_db_predictions", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/sarscov2_3clpro_diamond", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/sarscov2_vitro_touret", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/serine_threonine_kinase_33_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/SIDER", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/skin_reaction", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/solubility_aqsoldb", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/sr_are_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/sr_atad5_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/sr_hse_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/sr_mmp_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/sr_p53_tox21", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/thermosol", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/tyrosyl-dna_phosphodiesterase_butkiewicz", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/uniprot_binding_sites_multiple", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/uniprot_organisms", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/volume_of_distribution_at_steady_state_lombardo_et_al", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bicerano_dataset", + # STANDARD SAMPLING DATASETS, set class_balanced = False !!! + "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/RedDB", + "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/mattermodeling_stackexchange", + "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/mp_descriptions", + "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/mp_self_supervised", + "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/rdkit_features", + "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/smiles_to_3d", + "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/uniprot_binding_sites_multiple", + # INVERSE DATASETS, set class_balanced = False !!! + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/inverse_1", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/inverse_2", + # "/weka/proj-chemnlp/micpie/chemnlp/data/tabular/inverse_3", + ] + # index = [i for i, x in enumerate(path_data_dir) if x.find("qmof_quantum") != -1][0] # print(index) # path_data_dir = path_data_dir[index:] + # path_data_dir = path_data_dir[index + 1 :] # path_data_dir = [path_data_dir[index]] - # path_data_dir = [ - # '/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bioavailability_ma_et_al', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/RedDB', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/SIDER', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/aminoacids', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/bc5chem', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/bc5disease', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/bicerano_dataset', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/block_polymers_morphology', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/buchwald_hartwig', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/chem_caption_smarts', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/chemdner', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/chemistry_stackexchange', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/compound_chebi_chebi_chebi_1', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/compound_chebi_chebi_chebi_2', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/drug_chebi_chebi_chebi', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/fda_adverse_reactions', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/formation_energies', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/h2_storage_materials', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/mattermodeling_stackexchange', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/melting_points', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/mofdscribe', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/mol2svg', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/moses', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/mp_anisotropy', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/mp_bulk_modulus', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/mp_descriptions', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/mp_self_supervised', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/mp_shear_modulus', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/ncbi_disease', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/nomad_structure', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/ocp', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/opv', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/oqmd', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/ord_masked', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/ord_predictions', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/ord_procedure_steps', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/ord_rxn_smiles_procedure', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/ord_rxn_smiles_yield_pred', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/ord_steps_yield', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/perovskite_db', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/physics_stackexchange', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/qm8', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/qm9', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/suzuki_miyaura_sach', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/uspto', - # '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/uspto_yield', - # ] for path in path_data_dir: # subselect one path - # if path.find("data/tabular/") == -1: continue # if path.find("data/kg/") == -1: continue - # if path.find("chembl33") != -1: continue - # if path.find("data/kg/compound_chebi") == -1: continue - # if path.find("data/tabular/cyp3a4_substrate_carbonmangels") == -1: continue - # if path.find("data/tabular/bio_ner") == -1: continue - # if path.find("rdkit_features") != -1: continue + # if path.find("data/tabular/") == -1: + # continue + # if path.find("data/kg/") == -1: continue + + # exclude data_clean.csv files with more than 1GB + # if path.find("rdkit_features") != -1: + # continue + # if path.find("iupac_smiles") != -1: + # continue + # if path.find("orbnet_denali") != -1: + # continue + # if path.find("ord_masked") != -1: + # continue + # if path.find("ord_predictions") != -1: + # continue + # if path.find("chembl_v29") != -1: + # continue print(f"\n###### {path}") path_meta = path + "/meta.yaml" @@ -1179,33 +1369,8 @@ def apply_sampling_and_export( if "templates" in meta: multiple_choice_rnd_symbols = ["", ".", ".)", ")", ":", "()", "[]"] print(f"Running sampling for: {path}") - # uncomment to randomly sample from all templates and save the output to a single file - # TemplateSampler( - # path, - # path_lm_eval_data_dir, - # multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, - # additional_templates=additional_templates, - # benchmarking_templates=False, - # multiple_choice_benchmarking_templates=False, - # ).apply_sampling_and_export() - - # tempsamp = TemplateSampler( - # path, - # path_lm_eval_data_dir, - # multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, - # additional_templates=additional_templates, - # benchmarking_templates=False, - # multiple_choice_benchmarking_templates=False, - # ) - # for i, template in enumerate( - # [t for t in meta["templates"] if "" not in t] - # ): - # print(f"\nRunning sampling for template {i}:\n{template}") - # tempsamp.apply_sampling_and_export( - # template_idx=i, - # fn_suffix=i, - # ) + # CHUNKED TRAIN SAMPLING chunksize = 1_000_000 path_data_csv = path + "/data_clean.csv" with pd.read_csv( @@ -1213,74 +1378,79 @@ def apply_sampling_and_export( ) as reader: chunk_idx = 0 for df_chunk in reader: - tempsamp = TemplateSampler( - path, - df_chunk, - path_lm_eval_data_dir, - multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, - additional_templates=additional_templates, - benchmarking_templates=False, - multiple_choice_benchmarking_templates=False, - ) - for i, template in enumerate( - [t for t in meta["templates"] if "" not in t] + # tempsamp = TemplateSampler( + # path, + # df_chunk, + # path_lm_eval_data_dir, + # multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, + # additional_templates=additional_templates, + # benchmarking_templates=False, + # multiple_choice_benchmarking_templates=False, + # ) + # for i, template in enumerate( + # [t for t in meta["templates"] if "" not in t] + # ): + # print(f"\nRunning sampling for template {i}:\n{template}") + # tempsamp.apply_sampling_and_export( + # template_idx=i, + # fn_suffix=f"{chunk_idx}-{i}", + # ) + + # STANDARD BENCHMARKING SAMPLING + if any(["" in t for t in meta["templates"]]): + tempsamp = TemplateSampler( + path, + df_chunk, + path_lm_eval_data_dir, + multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, + additional_templates=additional_templates, + benchmarking_templates=True, + multiple_choice_benchmarking_templates=False, + ) + for i, template in enumerate( + [ + t + for t in meta["templates"] + if "" in t and "%multiple_choice_" not in t + ] + ): + print( + f"\nRunning sampling for template {i}:\n{template}" + ) + tempsamp.apply_sampling_and_export( + template_idx=i, + fn_suffix=f"{chunk_idx}-{i}", + ) + + # MULTIPLE CHOICE BENCHMARKING SAMPLING + if any( + [ + "" in t and "%multiple_choice_" in t + for t in meta["templates"] + ] ): - print(f"\nRunning sampling for template {i}:\n{template}") - tempsamp.apply_sampling_and_export( - template_idx=i, - fn_suffix=f"{chunk_idx}-{i}", + tempsamp = TemplateSampler( + path, + df_chunk, + path_lm_eval_data_dir, + multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, + additional_templates=additional_templates, + benchmarking_templates=True, + multiple_choice_benchmarking_templates=True, ) - chunk_idx += 1 + for i, template in enumerate( + [ + t + for t in meta["templates"] + if "" in t and "%multiple_choice_" in t + ] + ): + print( + f"\nRunning sampling for template {i}:\n{template}" + ) + tempsamp.apply_sampling_and_export( + template_idx=i, + fn_suffix=f"{chunk_idx}-{i}", + ) - # if any(["" in t for t in meta["templates"]]): - # # uncomment to randomly sample from all templates and save the output to a single file - # # TemplateSampler( - # # path, - # # path_lm_eval_data_dir, - # # multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, - # # additional_templates=additional_templates, - # # benchmarking_templates=True, - # # multiple_choice_benchmarking_templates=False, - # # ).apply_sampling_and_export() - - # tempsamp = TemplateSampler( - # path, - # path_lm_eval_data_dir, - # multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, - # additional_templates=additional_templates, - # benchmarking_templates=True, - # multiple_choice_benchmarking_templates=False, - # ) - # for i, template in enumerate( - # [ - # t - # for t in meta["templates"] - # if "" in t and "%multiple_choice_" not in t - # ] - # ): - # print(f"\nRunning sampling for template {i}:\n{template}") - # tempsamp.apply_sampling_and_export( - # template_idx=i, - # fn_suffix=i, - # ) - - # if any(["%multiple_choice_" in t for t in meta["templates"]]): - # TemplateSampler( - # path, - # path_lm_eval_data_dir, - # multiple_choice_rnd_symbols=multiple_choice_rnd_symbols, - # additional_templates=additional_templates, - # benchmarking_templates=True, - # multiple_choice_benchmarking_templates=True, - # ).apply_sampling_and_export() - - # # for i, s in enumerate(multiple_choice_rnd_symbols): - # # TemplateSampler( - # # path, - # # path_lm_eval_data_dir, - # # multiple_choice_rnd_symbols=[s], - # # additional_templates=additional_templates, - # # benchmarking_templates=True, - # # multiple_choice_benchmarking_templates=True, - # # multiple_choice_benchmarking_format=i, - # # ).apply_sampling_and_export() + chunk_idx += 1