From 047e569e7ca20e6eb4ef72fde04a1177c3a5221b Mon Sep 17 00:00:00 2001
From: zmgong
Date: Mon, 16 Dec 2024 15:29:52 -0800
Subject: [PATCH] Update sequence pipeline

---
 bioscanclip/model/dna_encoder.py | 38 ++++++++++++++++++++++++++++++--
 bioscanclip/model/simple_clip.py | 19 +++++++++++-----
 bioscanclip/util/dataset.py      | 33 +++++++++++++++++++++------
 bioscanclip/util/util.py         | 27 +++++++++++++++++++++++
 4 files changed, 103 insertions(+), 14 deletions(-)

diff --git a/bioscanclip/model/dna_encoder.py b/bioscanclip/model/dna_encoder.py
index 0b1e4d7..a41b7cc 100644
--- a/bioscanclip/model/dna_encoder.py
+++ b/bioscanclip/model/dna_encoder.py
@@ -5,8 +5,9 @@ import torch.nn as nn
 from torch import Tensor
 from torchtext.vocab import build_vocab_from_iterator
-from transformers import BertConfig, BertForMaskedLM
-from bioscanclip.util.util import PadSequence, KmerTokenizer, load_bert_model
+from bioscanclip.util.util import PadSequence, KmerTokenizer, load_bert_model, KmerTokenizer_for_5m
+from transformers import BertForMaskedLM, BertConfig, BertForTokenClassification
+from torchtext.vocab import vocab as build_vocab_from_dict
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def remove_extra_pre_fix(state_dict):
@@ -17,6 +18,14 @@ def remove_extra_pre_fix(state_dict):
         new_state_dict[key] = value
     return new_state_dict
 
+def load_pre_trained_bioscan_bert_trained_with_5m(bioscan_bert_checkpoint, k=5):
+    ckpt = torch.load(bioscan_bert_checkpoint, map_location=device)
+    model_ckpt = remove_extra_pre_fix(ckpt["model"])
+    bert_config = BertConfig(**ckpt["bert_config"])
+    model = BertForMaskedLM(bert_config)
+    load_bert_model(model, model_ckpt)
+    return model.to(device)
+
 def load_pre_trained_bioscan_bert(bioscan_bert_checkpoint, k=5):
     kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
     vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<CLS>", "<UNK>"])
@@ -28,6 +37,31 @@
 
     return bert_model.to(device)
 
+def get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M(k=4):
+    base_pairs = "ACGT"
+    special_tokens = ["[MASK]", "[UNK]"]
+    UNK_TOKEN = "[UNK]"
+    stride = k
+    max_len_of_token = 256
+
+    kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=k)]
+    kmer_dict = dict.fromkeys(kmers, 1)
+    vocab = build_vocab_from_dict(kmer_dict, specials=special_tokens)
+    vocab.set_default_index(vocab[UNK_TOKEN])
+    vocab_size = len(vocab)
+    tokenizer = KmerTokenizer_for_5m(
+        k, vocab, stride=stride, padding=True, max_len=max_len_of_token
+    )
+
+    max_len = 660
+    pad = PadSequence(max_len)
+
+    sequence_pipeline = lambda x: [0, *vocab(tokenizer(pad(x)))]
+    return sequence_pipeline
+
+
+
 def get_sequence_pipeline(k=5):
     kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
     vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<CLS>", "<UNK>"])
diff --git a/bioscanclip/model/simple_clip.py b/bioscanclip/model/simple_clip.py
index 8a81150..2300b0e 100644
--- a/bioscanclip/model/simple_clip.py
+++ b/bioscanclip/model/simple_clip.py
@@ -3,7 +3,7 @@ import torch.nn as nn
 from bioscanclip.model.mlp import MLPEncoder
 from bioscanclip.model.image_encoder import LoRA_ViT_timm, LoRA_ViT_OpenCLIP
-from bioscanclip.model.dna_encoder import load_pre_trained_bioscan_bert, LoRA_barcode_bert, Freeze_DNA_Encoder
+from bioscanclip.model.dna_encoder import load_pre_trained_bioscan_bert, LoRA_barcode_bert, Freeze_DNA_Encoder, load_pre_trained_bioscan_bert_trained_with_5m
 from bioscanclip.model.language_encoder import load_pre_trained_bert, LoRA_bert, LoRA_bert_OpenCLIP
 from bioscanclip.util.util import add_lora_layer_to_open_clip
 import numpy as np
@@ -261,13 +261,22 @@ def load_clip_model(args, device=None):
     if hasattr(args.model_config, 'dna'):
         if args.model_config.dna.input_type == "sequence":
             if dna_model == "barcode_bert" or dna_model == "lora_barcode_bert":
-                barcode_BERT_ckpt = args.barcodeBERT_checkpoint
-                k = 5
+                if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+                    barcode_BERT_ckpt = args.model_config.barcodeBERT_ckpt_path
+                    k = 4
-                pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
-                    bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
+                    pre_trained_barcode_bert = load_pre_trained_bioscan_bert_trained_with_5m(
+                        bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
+
+
+                else:
+                    barcode_BERT_ckpt = args.barcodeBERT_checkpoint
+                    k = 5
+                    pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
+                        bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
+
+
+
                 if disable_lora:
                     dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
                                                     num_classes=args.model_config.output_dim, lora_layer=[])
diff --git a/bioscanclip/util/dataset.py b/bioscanclip/util/dataset.py
index aa0e6e2..07d4c19 100644
--- a/bioscanclip/util/dataset.py
+++ b/bioscanclip/util/dataset.py
@@ -10,7 +10,8 @@ from PIL import Image
 from torch.utils.data import Dataset
 import torchvision.transforms as transforms
-from bioscanclip.model.dna_encoder import get_sequence_pipeline
+from bioscanclip.model.dna_encoder import get_sequence_pipeline, \
+    get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M
 from torch.utils.data.distributed import DistributedSampler
 import json
 import time
@@ -460,7 +461,10 @@ def load_bioscan_dataloader_with_train_seen_and_separate_keys(args, world_size=N
 
     return_language = True
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     train_seen_dataloader = construct_dataloader(
         args,
@@ -548,7 +552,10 @@ def load_dataloader_for_everything_in_5m(args, world_size=None, rank=None):
 
     return_language = True
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     pre_train_dataloader = construct_dataloader(
         args,
@@ -642,7 +649,10 @@ def load_dataloader(args, world_size=None, rank=None, for_pretrain=True):
 
     return_language = True
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     seen_val_dataloader = construct_dataloader(
         args,
@@ -731,7 +741,10 @@ def load_bioscan_dataloader_all_small_splits(args, world_size=None, rank=None):
 
     return_language = True
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     if hasattr(args.model_config, 'dataset') and args.model_config.dataset == "bioscan_5m":
         train_seen_dataloader = construct_dataloader(
@@ -1062,7 +1075,10 @@ def load_insect_dataloader_trainval(args, num_workers=8, shuffle_for_train_seen_
     with open(filename, 'r') as file:
         specie_to_other_labels = json.load(file)
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     trainval_dataset = INSECTDataset(
         args.insect_data.path_to_att_splits_mat, args.insect_data.path_to_res_101_mat,
@@ -1084,7 +1100,10 @@ def load_insect_dataloader(args, world_size=None, rank=None, num_workers=8, load
     with open(filename, 'r') as file:
         specie_to_other_labels = json.load(file)
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     if load_all_in_one:
         all_dataset = INSECTDataset(
diff --git a/bioscanclip/util/util.py b/bioscanclip/util/util.py
index c35aa1d..50af776 100644
--- a/bioscanclip/util/util.py
+++ b/bioscanclip/util/util.py
@@ -93,6 +93,33 @@ def __call__(self, dna_sequence):
             tokens.append(k_mer)
         return tokens
 
+class KmerTokenizer_for_5m(object):
+    def __init__(self, k, vocabulary_mapper, stride=1, padding=False, max_len=660):
+        self.k = k
+        self.stride = stride
+        self.padding = padding
+        self.max_len = max_len
+        self.vocabulary_mapper = vocabulary_mapper
+
+    def __call__(self, dna_sequence, offset=0):
+        tokens = []
+        att_mask = [1] * (self.max_len // self.stride)
+        x = dna_sequence[offset:]
+        if self.padding:
+            if len(x) > self.max_len:
+                x = x[: self.max_len]
+            else:
+                att_mask[len(x) // self.stride:] = [0] * (len(att_mask) - len(x) // self.stride)
+                x = x + "N" * (self.max_len - len(x))
+        for i in range(0, len(x) - self.k + 1, self.stride):
+            k_mer = x[i: i + self.k]
+            tokens.append(k_mer)
+
+        tokens = torch.tensor(self.vocabulary_mapper(tokens), dtype=torch.int64)
+        att_mask = torch.tensor(att_mask, dtype=torch.int32)
+
+        return tokens, att_mask
+
 
 def set_seed(seed=None):
     if seed is None:
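
Note (not part of the patch): a minimal usage sketch of the new k=4 tokenizer
path, assuming the patch is applied and torchtext is installed. The toy
barcode string and variable names are invented for illustration; the vocab
construction mirrors get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M.

    from itertools import product

    from torchtext.vocab import vocab as build_vocab_from_dict

    from bioscanclip.util.util import KmerTokenizer_for_5m

    # Build the same 4-mer vocabulary the new pipeline uses: 256 k-mers plus
    # [MASK]/[UNK], with unknown (e.g. N-containing) k-mers mapped to [UNK].
    k = 4
    kmers = ["".join(kmer) for kmer in product("ACGT", repeat=k)]
    vocab = build_vocab_from_dict(dict.fromkeys(kmers, 1), specials=["[MASK]", "[UNK]"])
    vocab.set_default_index(vocab["[UNK]"])

    tokenizer = KmerTokenizer_for_5m(k, vocab, stride=k, padding=True, max_len=256)
    tokens, att_mask = tokenizer("AACGTT" * 20)  # 120-bp toy barcode
    print(tokens.shape, att_mask.shape)  # torch.Size([64]) torch.Size([64])
    print(int(att_mask.sum()))           # 30: only 120/4 positions are real sequence

With max_len=256 and stride=k=4 the tokenizer always emits 64 token ids plus a
matching attention mask, and the "N" padding tokenizes to [UNK] via the
vocab's default index.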
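Note (not part of the patch): the dispatch added across simple_clip.py and
dataset.py keys off whether model_config defines barcodeBERT_ckpt_path. Below
is a self-contained sketch with hypothetical config objects and checkpoint
paths; the project's real args object comes from its own config system.

    from types import SimpleNamespace

    # Hypothetical stand-ins; only the attributes the patch reads are modeled.
    cfg_5m = SimpleNamespace(
        model_config=SimpleNamespace(barcodeBERT_ckpt_path="ckpts/barcodeBERT_5M.pt"))
    cfg_old = SimpleNamespace(
        model_config=SimpleNamespace(), barcodeBERT_checkpoint="ckpts/barcodeBERT.pt")

    def describe_dna_path(args):
        # Mirrors the hasattr() dispatch this patch adds in load_clip_model
        # and in every dataloader builder.
        if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
            return "5M checkpoint, k=4, 5M sequence pipeline"
        return "args.barcodeBERT_checkpoint, k=5, original sequence pipeline"

    print(describe_dna_path(cfg_5m))   # 5M checkpoint, k=4, 5M sequence pipeline
    print(describe_dna_path(cfg_old))  # args.barcodeBERT_checkpoint, k=5, original sequence pipeline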