From 74c3ab55ad6cf0b92d911f6e3584b5eb7ca25ce9 Mon Sep 17 00:00:00 2001
From: zmgong
Date: Mon, 16 Dec 2024 14:45:18 -0800
Subject: [PATCH] Debug.

---
 bioscanclip/model/dna_encoder.py | 14 ++++++++++----
 bioscanclip/model/simple_clip.py |  4 +++-
 bioscanclip/util/util.py         |  8 ++++++--
 3 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/bioscanclip/model/dna_encoder.py b/bioscanclip/model/dna_encoder.py
index 039bfc7..04a9386 100644
--- a/bioscanclip/model/dna_encoder.py
+++ b/bioscanclip/model/dna_encoder.py
@@ -9,16 +9,22 @@ from bioscanclip.util.util import PadSequence, KmerTokenizer, load_bert_model
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-def load_pre_trained_bioscan_bert(bioscan_bert_checkpoint, k=5):
+def remove_extra_pre_fix(state_dict):
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if key.startswith("module."):
+            key = key[7:]
+        new_state_dict[key] = value
+    return new_state_dict
+
+def load_pre_trained_bioscan_bert(bioscan_bert_checkpoint, k=5, remove_extra_prefix=False):
     kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
     vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<CLS>", "<UNK>"])
     vocab.set_default_index(vocab["<UNK>"])
     vocab_size = len(vocab)
     configuration = BertConfig(vocab_size=vocab_size, output_hidden_states=True)
     bert_model = BertForMaskedLM(configuration)
-    load_bert_model(bert_model, bioscan_bert_checkpoint)
+    load_bert_model(bert_model, bioscan_bert_checkpoint, remove_extra_prefix=remove_extra_prefix)
     return bert_model.to(device)
 
 
 
diff --git a/bioscanclip/model/simple_clip.py b/bioscanclip/model/simple_clip.py
index 8a81150..0cfccff 100644
--- a/bioscanclip/model/simple_clip.py
+++ b/bioscanclip/model/simple_clip.py
@@ -263,11 +263,13 @@ def load_clip_model(args, device=None):
     if dna_model == "barcode_bert" or dna_model == "lora_barcode_bert":
         barcode_BERT_ckpt = args.barcodeBERT_checkpoint
         k = 5
+        remove_extra_prefix = False
         if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
             barcode_BERT_ckpt = args.model_config.barcodeBERT_ckpt_path
             k=4
+            remove_extra_prefix = True
         pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
-            bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
+            bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k, remove_extra_prefix=remove_extra_prefix)
         if disable_lora:
             dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
                                             num_classes=args.model_config.output_dim, lora_layer=[])
diff --git a/bioscanclip/util/util.py b/bioscanclip/util/util.py
index 2f6434b..702bb95 100644
--- a/bioscanclip/util/util.py
+++ b/bioscanclip/util/util.py
@@ -115,10 +115,14 @@ def remove_extra_pre_fix(state_dict):
     return new_state_dict
 
 
-def load_bert_model(bert_model, path_to_ckpt):
+def load_bert_model(bert_model, path_to_ckpt, remove_extra_prefix=False):
     state_dict = torch.load(path_to_ckpt, map_location=torch.device("cpu"))
     state_dict = remove_extra_pre_fix(state_dict)
-    bert_model.load_state_dict(state_dict)
+    try:
+        bert_model.load_state_dict(state_dict)
+    except:
+        print(state_dict.keys())
+
 
 
 def print_result(