diff --git a/data/unsplashlite.py b/data/unsplashlite.py index ef6381d..eb74ee7 100644 --- a/data/unsplashlite.py +++ b/data/unsplashlite.py @@ -9,7 +9,7 @@ from torch.utils.data import Dataset from torchvision import transforms -from model.CringeLDM import CringeBERTWrapper +from model.CringeBERT import CringeBERTWrapper from utils import * class UnsplashLiteDataset(Dataset): diff --git a/inference.py b/inference.py index 9b1d07d..34d8695 100644 --- a/inference.py +++ b/inference.py @@ -3,7 +3,7 @@ import pytorch_lightning as pl import torch -from model.CringeLDM import CringeDenoiserModel +from model.CringeDenoiser import CringeDenoiserModel from model.CringeVAE import CringeVAEModel from PIL import Image diff --git a/model/CringeBERT.py b/model/CringeBERT.py index 1035dfd..2a37f1d 100644 --- a/model/CringeBERT.py +++ b/model/CringeBERT.py @@ -6,5 +6,47 @@ from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.data import random_split +from transformers.models.bert.modeling_bert import BertModel +from transformers.models.bert.tokenization_bert import BertTokenizer +from transformers.utils.generic import ModelOutput +from transformers.modeling_outputs import BaseModelOutput -from model.cringe.unet import UNet \ No newline at end of file + +from model.cringe.unet import UNet + +class CringeBERTWrapper: + """ + BERT Wrapper + + This is a wrapper for the BERT model. Ideally would be trained from the same + dataset as the LDM model, but for now we just use the pretrained BERT model. 
+ """ + + def loadModel(self, cpu): + self.bert_model = BertModel.from_pretrained( + 'bert-base-uncased') # type: ignore + if torch.cuda.is_available() & (not cpu): + self.bert_model = self.bert_model.cuda() # type: ignore + self.bert_tokenizer = BertTokenizer.from_pretrained( + 'bert-base-uncased') # type: ignore + + def __init__(self, cpu=False): + self.loadModel(cpu) + pass + + def model_output(self, input_ids: torch.Tensor): + with torch.no_grad(): + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + output = self.bert_model(input_ids) # type: ignore + q = output.last_hidden_state + return q.unsqueeze(0) + + def inference(self, query): + with torch.no_grad(): + # Encode the text using BERT + input_ids: Tensor = torch.tensor(self.bert_tokenizer.encode(query)) \ + .unsqueeze(0) # Add batch dimension + # Normalise so that all values are between 0 and 1 + input_ids = (input_ids + 1) / 2 + return self.model_output(input_ids) \ No newline at end of file diff --git a/model/CringeLDM.py b/model/CringeDenoiser.py similarity index 74% rename from model/CringeLDM.py rename to model/CringeDenoiser.py index 792682c..2a33395 100644 --- a/model/CringeLDM.py +++ b/model/CringeDenoiser.py @@ -1,59 +1,15 @@ import pytorch_lightning as pl import torch import torch.nn as nn -import transformers -from torch import Tensor from torch.nn import functional as F -from torch.utils.data import DataLoader -from torch.utils.data import random_split from transformers.models.bert.modeling_bert import BertModel from transformers.models.bert.tokenization_bert import BertTokenizer -from transformers.utils.generic import ModelOutput -from transformers.modeling_outputs import BaseModelOutput from model.cringe.unet import UNet +from model.CringeBERT import CringeBERTWrapper from model.CringeVAE import CringeVAEModel - -class CringeBERTWrapper: - """ - BERT Wrapper - - This is a wrapper for the BERT model. 
Ideally would be trained from the same - dataset as the LDM model, but for now we just use the pretrained BERT model. - """ - - def loadModel(self, cpu): - self.bert_model = BertModel.from_pretrained( - 'bert-base-uncased') # type: ignore - if torch.cuda.is_available() & (not cpu): - self.bert_model = self.bert_model.cuda() # type: ignore - self.bert_tokenizer = BertTokenizer.from_pretrained( - 'bert-base-uncased') # type: ignore - - def __init__(self, cpu=False): - self.loadModel(cpu) - pass - - def model_output(self, input_ids: torch.Tensor): - with torch.no_grad(): - if torch.cuda.is_available(): - input_ids = input_ids.cuda() - output = self.bert_model(input_ids) # type: ignore - q = output.last_hidden_state - return q.unsqueeze(0) - - def inference(self, query): - with torch.no_grad(): - # Encode the text using BERT - input_ids: Tensor = torch.tensor(self.bert_tokenizer.encode(query)) \ - .unsqueeze(0) # Add batch dimension - # Normalise so that all values are between 0 and 1 - input_ids = (input_ids + 1) / 2 - return self.model_output(input_ids) - - class CringeBERTEncoder(pl.LightningModule): """ diff --git a/requirements.txt b/requirements.txt index f5f98ec..a8c7018 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ torch torchvision torchaudio pytorch_lightning +lightning-transformers transformers numpy matplotlib \ No newline at end of file diff --git a/train.py b/train.py index 828455b..5b9136d 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ from data.dirtycollate import dirty_collate -from model.CringeLDM import CringeDenoiserModel +from model.CringeDenoiser import CringeDenoiserModel from model.CringeVAE import CringeVAEModel from data.unsplashlite import UnsplashLiteDataset from utils import RegularCheckpoint, train_save_checkpoint diff --git a/utils.py b/utils.py index 94c048f..671f0ab 100644 --- a/utils.py +++ b/utils.py @@ -5,7 +5,7 @@ import torch import torchvision -from model.CringeLDM import CringeDenoiserModel +from model.CringeDenoiser import CringeDenoiserModel from PIL import Image from pytorch_lightning.callbacks import ModelCheckpoint