Commit 047e569
Update sequence pipeline
zmgong committed Dec 16, 2024
1 parent 71ea9ca
Showing 4 changed files with 103 additions and 14 deletions.
38 changes: 36 additions & 2 deletions bioscanclip/model/dna_encoder.py
@@ -5,8 +5,9 @@
 import torch.nn as nn
 from torch import Tensor
 from torchtext.vocab import build_vocab_from_iterator
-from transformers import BertConfig, BertForMaskedLM
-from bioscanclip.util.util import PadSequence, KmerTokenizer, load_bert_model
+from bioscanclip.util.util import PadSequence, KmerTokenizer, load_bert_model, KmerTokenizer_for_5m
+from transformers import BertForMaskedLM, BertConfig, BertForTokenClassification
+from torchtext.vocab import vocab as build_vocab_from_dict
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 def remove_extra_pre_fix(state_dict):
@@ -17,6 +18,14 @@ def remove_extra_pre_fix(state_dict):
             new_state_dict[key] = value
     return new_state_dict
 
+def load_pre_trained_bioscan_bert_trained_with_5m(bioscan_bert_checkpoint, k=5):
+    ckpt = torch.load(bioscan_bert_checkpoint, map_location=device)
+    model_ckpt = remove_extra_pre_fix(ckpt["model"])
+    bert_config = BertConfig(**ckpt["bert_config"])
+    model = BertForMaskedLM(bert_config)
+    load_bert_model(model, model_ckpt)
+    return model.to(device)
+
 def load_pre_trained_bioscan_bert(bioscan_bert_checkpoint, k=5):
     kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
     vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<CLS>", "<UNK>"])
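The new load_pre_trained_bioscan_bert_trained_with_5m expects a checkpoint dict that bundles weights and configuration together. A minimal round-trip sketch of that layout, assuming remove_extra_pre_fix strips a DDP-style "module." prefix; all config values and the file path below are illustrative, not from the commit:

# Hypothetical checkpoint in the shape the loader reads: a "model" state dict
# (here with a "module." prefix for remove_extra_pre_fix to clean up) and a
# "bert_config" dict of BertConfig kwargs.
import torch
from transformers import BertConfig, BertForMaskedLM
from bioscanclip.model.dna_encoder import load_pre_trained_bioscan_bert_trained_with_5m

config_kwargs = {"vocab_size": 258, "hidden_size": 256,
                 "num_hidden_layers": 4, "num_attention_heads": 4}  # illustrative
model = BertForMaskedLM(BertConfig(**config_kwargs))
ckpt = {"model": {f"module.{k}": v for k, v in model.state_dict().items()},
        "bert_config": config_kwargs}
torch.save(ckpt, "barcode_bert_5m.pt")  # hypothetical path

reloaded = load_pre_trained_bioscan_bert_trained_with_5m("barcode_bert_5m.pt")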
@@ -28,6 +37,31 @@ def load_pre_trained_bioscan_bert(bioscan_bert_checkpoint, k=5):
     return bert_model.to(device)
 
 
+def get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M(k=4):
+    base_pairs = "ACGT"
+    special_tokens = ["[MASK]", "[UNK]"]
+    UNK_TOKEN = "[UNK]"
+    stride = k
+    max_len_of_token = 256
+
+    kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=k)]
+    kmer_dict = dict.fromkeys(kmers, 1)
+    vocab = build_vocab_from_dict(kmer_dict, specials=special_tokens)
+    vocab.set_default_index(vocab[UNK_TOKEN])
+    vocab_size = len(vocab)
+    tokenizer = KmerTokenizer_for_5m(
+        k, vocab, stride=stride, padding=True, max_len=max_len_of_token
+    )
+
+    max_len = 660
+    pad = PadSequence(max_len)
+
+    sequence_pipeline = lambda x: [0, *vocab(tokenizer(pad(x)))]
+    return sequence_pipeline
+
+
+
 def get_sequence_pipeline(k=5):
     kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
     vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<CLS>", "<UNK>"])
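Unlike get_sequence_pipeline, which builds its vocabulary with build_vocab_from_iterator and <MASK>/<CLS>/<UNK> specials, the 5M variant builds a 4-mer vocabulary from a dict with [MASK]/[UNK] specials. A small sketch of the resulting id layout, assuming torchtext's vocab() factory with its default special_first=True (specials prepended, then k-mers in insertion order):

from itertools import product
from torchtext.vocab import vocab as build_vocab_from_dict

kmers = ["".join(p) for p in product("ACGT", repeat=4)]
v = build_vocab_from_dict(dict.fromkeys(kmers, 1), specials=["[MASK]", "[UNK]"])
v.set_default_index(v["[UNK]"])

print(len(v))                   # 258 = 256 four-mers + 2 specials
print(v["[MASK]"], v["[UNK]"])  # 0 1  (specials are prepended)
print(v["AAAA"], v["ACGT"])     # 2 29 (4-mers follow in product() order)
print(v(["ACGN"]))              # [1]  (unseen k-mers fall back to [UNK])

Note that the returned pipeline prepends a literal 0, which in this vocabulary is [MASK], as a CLS-style start token.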
19 changes: 14 additions & 5 deletions bioscanclip/model/simple_clip.py
@@ -3,7 +3,7 @@
 import torch.nn as nn
 from bioscanclip.model.mlp import MLPEncoder
 from bioscanclip.model.image_encoder import LoRA_ViT_timm, LoRA_ViT_OpenCLIP
-from bioscanclip.model.dna_encoder import load_pre_trained_bioscan_bert, LoRA_barcode_bert, Freeze_DNA_Encoder
+from bioscanclip.model.dna_encoder import load_pre_trained_bioscan_bert, LoRA_barcode_bert, Freeze_DNA_Encoder, load_pre_trained_bioscan_bert_trained_with_5m
 from bioscanclip.model.language_encoder import load_pre_trained_bert, LoRA_bert, LoRA_bert_OpenCLIP
 from bioscanclip.util.util import add_lora_layer_to_open_clip
 import numpy as np
@@ -261,13 +261,22 @@ def load_clip_model(args, device=None):
     if hasattr(args.model_config, 'dna'):
         if args.model_config.dna.input_type == "sequence":
             if dna_model == "barcode_bert" or dna_model == "lora_barcode_bert":
-                barcode_BERT_ckpt = args.barcodeBERT_checkpoint
-                k = 5
-                pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
-                    bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
+
+                if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+                    barcode_BERT_ckpt = args.model_config.barcodeBERT_ckpt_path
+                    k=4
+                    pre_trained_barcode_bert = load_pre_trained_bioscan_bert_trained_with_5m(
+                        bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
+
+
+                else:
+                    barcode_BERT_ckpt = args.barcodeBERT_checkpoint
+                    k = 5
+                    pre_trained_barcode_bert = load_pre_trained_bioscan_bert(
+                        bioscan_bert_checkpoint=barcode_BERT_ckpt, k=k)
+
+
 
                 if disable_lora:
                     dna_encoder = LoRA_barcode_bert(model=pre_trained_barcode_bert, r=4,
                                                     num_classes=args.model_config.output_dim, lora_layer=[])
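The dispatch above keys on the presence of barcodeBERT_ckpt_path, not its value, so any config that defines the attribute routes to the 5M loader with k=4. A minimal sketch of that branch with hypothetical config objects (checkpoint paths are illustrative):

from types import SimpleNamespace

# Config that selects the BIOSCAN-5M BarcodeBERT checkpoint (k=4):
args_5m = SimpleNamespace(model_config=SimpleNamespace(
    dna=SimpleNamespace(input_type="sequence"),
    barcodeBERT_ckpt_path="ckpts/barcode_bert_5m.pt"))  # hypothetical path

# Config that falls back to the original checkpoint (k=5):
args_1m = SimpleNamespace(model_config=SimpleNamespace(
    dna=SimpleNamespace(input_type="sequence")),
    barcodeBERT_checkpoint="ckpts/barcode_bert.pt")     # hypothetical path

for args in (args_5m, args_1m):
    uses_5m = hasattr(args.model_config, "barcodeBERT_ckpt_path")
    print("5M loader, k=4" if uses_5m else "original loader, k=5")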
33 changes: 26 additions & 7 deletions bioscanclip/util/dataset.py
@@ -10,7 +10,8 @@
 from PIL import Image
 from torch.utils.data import Dataset
 import torchvision.transforms as transforms
-from bioscanclip.model.dna_encoder import get_sequence_pipeline
+from bioscanclip.model.dna_encoder import get_sequence_pipeline, \
+    get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M
 from torch.utils.data.distributed import DistributedSampler
 import json
 import time
@@ -460,7 +461,10 @@ def load_bioscan_dataloader_with_train_seen_and_separate_keys(args, world_size=N
 
     return_language = True
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     train_seen_dataloader = construct_dataloader(
         args,
@@ -548,7 +552,10 @@ def load_dataloader_for_everything_in_5m(args, world_size=None, rank=None):
 
     return_language = True
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     pre_train_dataloader = construct_dataloader(
         args,
@@ -642,7 +649,10 @@ def load_dataloader(args, world_size=None, rank=None, for_pretrain=True):
 
     return_language = True
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     seen_val_dataloader = construct_dataloader(
         args,
@@ -731,7 +741,10 @@ def load_bioscan_dataloader_all_small_splits(args, world_size=None, rank=None):
 
     return_language = True
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     if hasattr(args.model_config, 'dataset') and args.model_config.dataset == "bioscan_5m":
         train_seen_dataloader = construct_dataloader(
@@ -1062,7 +1075,10 @@ def load_insect_dataloader_trainval(args, num_workers=8, shuffle_for_train_seen_
     with open(filename, 'r') as file:
         specie_to_other_labels = json.load(file)
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     trainval_dataset = INSECTDataset(
         args.insect_data.path_to_att_splits_mat, args.insect_data.path_to_res_101_mat,
@@ -1084,7 +1100,10 @@ def load_insect_dataloader(args, world_size=None, rank=None, num_workers=8, load
     with open(filename, 'r') as file:
         specie_to_other_labels = json.load(file)
 
-    sequence_pipeline = get_sequence_pipeline()
+    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
+        sequence_pipeline = get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
+    else:
+        sequence_pipeline = get_sequence_pipeline()
 
     if load_all_in_one:
         all_dataset = INSECTDataset(
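The same four-line branch now appears in six loader functions in this file. A small helper like the sketch below (hypothetical, not part of this commit) would keep the dispatch in one place:

def select_sequence_pipeline(args):
    # Mirrors the repeated branch above: presence of barcodeBERT_ckpt_path
    # on the model config selects the BIOSCAN-5M tokenizer pipeline.
    if hasattr(args.model_config, "barcodeBERT_ckpt_path"):
        return get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M()
    return get_sequence_pipeline()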
27 changes: 27 additions & 0 deletions bioscanclip/util/util.py
@@ -93,6 +93,33 @@ def __call__(self, dna_sequence):
             tokens.append(k_mer)
         return tokens
 
+class KmerTokenizer_for_5m(object):
+    def __init__(self, k, vocabulary_mapper, stride=1, padding=False, max_len=660):
+        self.k = k
+        self.stride = stride
+        self.padding = padding
+        self.max_len = max_len
+        self.vocabulary_mapper = vocabulary_mapper
+
+    def __call__(self, dna_sequence, offset=0):
+        tokens = []
+        att_mask = [1] * (self.max_len // self.stride)
+        x = dna_sequence[offset:]
+        if self.padding:
+            if len(x) > self.max_len:
+                x = x[: self.max_len]
+            else:
+                att_mask[len(x) // self.stride:] = [0] * (len(att_mask) - len(x) // self.stride)
+                x = x + "N" * (self.max_len - len(x))
+        for i in range(0, len(x) - self.k + 1, self.stride):
+            k_mer = x[i: i + self.k]
+            tokens.append(k_mer)
+
+        tokens = torch.tensor(self.vocabulary_mapper(tokens), dtype=torch.int64)
+        att_mask = torch.tensor(att_mask, dtype=torch.int32)
+
+        return tokens, att_mask
+
+
 def set_seed(seed=None):
     if seed is None:
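A worked example of the new tokenizer, assuming the 4-mer vocabulary built in get_sequence_pipeline_for_barcodeBERT_pre_trained_with_5M ([MASK]=0, [UNK]=1, 4-mers from id 2 in product() order) and a shortened max_len so the output stays readable:

from itertools import product
from torchtext.vocab import vocab as build_vocab_from_dict
from bioscanclip.util.util import KmerTokenizer_for_5m

kmers = ["".join(p) for p in product("ACGT", repeat=4)]
vocab = build_vocab_from_dict(dict.fromkeys(kmers, 1), specials=["[MASK]", "[UNK]"])
vocab.set_default_index(vocab["[UNK]"])

tok = KmerTokenizer_for_5m(4, vocab, stride=4, padding=True, max_len=8)
tokens, att_mask = tok("ACGTAC")
# "ACGTAC" is padded with "N" to "ACGTACNN", then split into non-overlapping
# 4-mers ["ACGT", "ACNN"]; "ACNN" is not in the vocabulary, so it maps to [UNK].
print(tokens)    # tensor([29,  1])
print(att_mask)  # tensor([1, 0], dtype=torch.int32)

Unlike KmerTokenizer above, which returns raw k-mer strings, this variant applies the vocabulary mapper itself and returns a (token ids, attention mask) pair of tensors, with one mask entry per stride-sized block of the padded sequence.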
