Skip to content

Commit

Permalink
Debug.
Browse files Browse the repository at this point in the history
  • Loading branch information
zmgong committed Dec 16, 2024
1 parent 9fb4e20 commit 74c3ab5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
14 changes: 10 additions & 4 deletions bioscanclip/model/dna_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,22 @@
from bioscanclip.util.util import PadSequence, KmerTokenizer, load_bert_model

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_pre_trained_bioscan_bert(bioscan_bert_checkpoint, k=5):
def remove_extra_pre_fix(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.startswith("module."):
key = key[7:]
new_state_dict[key] = value
return new_state_dict

def load_pre_trained_bioscan_bert(bioscan_bert_checkpoint, k=5, remove_extra_prefix=False):
kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<CLS>", "<UNK>"])
vocab.set_default_index(vocab["<UNK>"])
vocab_size = len(vocab)
configuration = BertConfig(vocab_size=vocab_size, output_hidden_states=True)
bert_model = BertForMaskedLM(configuration)
load_bert_model(bert_model, bioscan_bert_checkpoint)
load_bert_model(bert_model, bioscan_bert_checkpoint, remove_extra_prefix=remove_extra_prefix)
return bert_model.to(device)


Expand Down
4 changes: 3 additions & 1 deletion bioscanclip/model/simple_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,13 @@ def load_clip_model(args, device=None):
if dna_model == "barcode_bert" or dna_model == "lora_barcode_bert":
barcode_BERT_ckpt = args.barcodeBERT_checkpoint
k = 5
remove_extra_prefix = False
if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
barcode_BERT_ckpt = args.model_config.barcodeBERT_ckpt_path
k=4
remove_extra_prefix = True
pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k, remove_extra_prefix=remove_extra_prefix)
if disable_lora:
dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
num_classes=args.model_config.output_dim, lora_layer=[])
Expand Down
8 changes: 6 additions & 2 deletions bioscanclip/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,14 @@ def remove_extra_pre_fix(state_dict):
return new_state_dict


def load_bert_model(bert_model, path_to_ckpt):
def load_bert_model(bert_model, path_to_ckpt, remove_extra_prefix=False):
state_dict = torch.load(path_to_ckpt, map_location=torch.device("cpu"))
state_dict = remove_extra_pre_fix(state_dict)
bert_model.load_state_dict(state_dict)
try:
bert_model.load_state_dict(state_dict)
except:
print(state_dict.keys())



def print_result(
Expand Down

0 comments on commit 74c3ab5

Please sign in to comment.