refactor: free BERT! [#6]
pizzabug committed Jan 2, 2023
1 parent 0682d1f commit bb451f5
Showing 7 changed files with 49 additions and 50 deletions.
2 changes: 1 addition & 1 deletion data/unsplashlite.py
@@ -9,7 +9,7 @@
from torch.utils.data import Dataset
from torchvision import transforms

-from model.CringeLDM import CringeBERTWrapper
+from model.CringeBERT import CringeBERTWrapper
from utils import *

class UnsplashLiteDataset(Dataset):
2 changes: 1 addition & 1 deletion inference.py
@@ -3,7 +3,7 @@
import pytorch_lightning as pl
import torch

-from model.CringeLDM import CringeDenoiserModel
+from model.CringeDenoiser import CringeDenoiserModel
from model.CringeVAE import CringeVAEModel
from PIL import Image

44 changes: 43 additions & 1 deletion model/CringeBERT.py
@@ -6,5 +6,47 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
+from transformers.models.bert.modeling_bert import BertModel
+from transformers.models.bert.tokenization_bert import BertTokenizer
+from transformers.utils.generic import ModelOutput
+from transformers.modeling_outputs import BaseModelOutput

-from model.cringe.unet import UNet
+
+from model.cringe.unet import UNet
+
+class CringeBERTWrapper:
+    """
+    BERT Wrapper
+
+    Wraps a pretrained BERT model as a frozen text encoder. Ideally it would
+    be trained on the same dataset as the LDM model, but for now we just use
+    the pretrained bert-base-uncased weights.
+    """
+
+    def loadModel(self, cpu):
+        # Load the pretrained encoder and its matching tokenizer.
+        self.bert_model = BertModel.from_pretrained(
+            'bert-base-uncased')  # type: ignore
+        if torch.cuda.is_available() and not cpu:
+            self.bert_model = self.bert_model.cuda()  # type: ignore
+        self.bert_tokenizer = BertTokenizer.from_pretrained(
+            'bert-base-uncased')  # type: ignore
+
+    def __init__(self, cpu=False):
+        self.loadModel(cpu)
+
+    def model_output(self, input_ids: torch.Tensor):
+        with torch.no_grad():
+            # Match the input device to wherever the model was loaded.
+            input_ids = input_ids.to(
+                next(self.bert_model.parameters()).device)  # type: ignore
+            output = self.bert_model(input_ids)  # type: ignore
+            q = output.last_hidden_state
+            return q.unsqueeze(0)
+
+    def inference(self, query):
+        with torch.no_grad():
+            # Encode the text using BERT. Token IDs are integer indices into
+            # the embedding table and must be passed to the model unchanged.
+            input_ids: Tensor = torch.tensor(self.bert_tokenizer.encode(query)) \
+                .unsqueeze(0)  # Add batch dimension
+            return self.model_output(input_ids)
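
For context, a minimal usage sketch of the wrapper extracted above. It is illustrative only; the prompt string and variable names are not part of the commit:

import torch

from model.CringeBERT import CringeBERTWrapper

# Hypothetical usage of CringeBERTWrapper; fall back to CPU when
# CUDA is unavailable.
bert = CringeBERTWrapper(cpu=not torch.cuda.is_available())

# inference() tokenizes the prompt and returns BERT hidden states,
# shaped (1, 1, sequence_length, 768) for bert-base-uncased.
embeddings = bert.inference("a photo of a dog on a skateboard")
print(embeddings.shape)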
46 changes: 1 addition & 45 deletions model/CringeLDM.py → model/CringeDenoiser.py
@@ -1,59 +1,15 @@
import pytorch_lightning as pl
import torch
import torch.nn as nn
-import transformers

from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
-from transformers.models.bert.modeling_bert import BertModel
-from transformers.models.bert.tokenization_bert import BertTokenizer
-from transformers.utils.generic import ModelOutput
-from transformers.modeling_outputs import BaseModelOutput

from model.cringe.unet import UNet
+from model.CringeBERT import CringeBERTWrapper
from model.CringeVAE import CringeVAEModel


-class CringeBERTWrapper:
-    """
-    BERT Wrapper
-    This is a wrapper for the BERT model. Ideally would be trained from the same
-    dataset as the LDM model, but for now we just use the pretrained BERT model.
-    """
-
-    def loadModel(self, cpu):
-        self.bert_model = BertModel.from_pretrained(
-            'bert-base-uncased')  # type: ignore
-        if torch.cuda.is_available() & (not cpu):
-            self.bert_model = self.bert_model.cuda()  # type: ignore
-        self.bert_tokenizer = BertTokenizer.from_pretrained(
-            'bert-base-uncased')  # type: ignore
-
-    def __init__(self, cpu=False):
-        self.loadModel(cpu)
-        pass
-
-    def model_output(self, input_ids: torch.Tensor):
-        with torch.no_grad():
-            if torch.cuda.is_available():
-                input_ids = input_ids.cuda()
-            output = self.bert_model(input_ids)  # type: ignore
-            q = output.last_hidden_state
-            return q.unsqueeze(0)
-
-    def inference(self, query):
-        with torch.no_grad():
-            # Encode the text using BERT
-            input_ids: Tensor = torch.tensor(self.bert_tokenizer.encode(query)) \
-                .unsqueeze(0)  # Add batch dimension
-            # Normalise so that all values are between 0 and 1
-            input_ids = (input_ids + 1) / 2
-            return self.model_output(input_ids)


class CringeBERTEncoder(pl.LightningModule):

"""
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,6 +2,7 @@ torch
torchvision
torchaudio
pytorch_lightning
+lightning-transformers
transformers
numpy
matplotlib
2 changes: 1 addition & 1 deletion train.py
@@ -9,7 +9,7 @@

from data.dirtycollate import dirty_collate

-from model.CringeLDM import CringeDenoiserModel
+from model.CringeDenoiser import CringeDenoiserModel
from model.CringeVAE import CringeVAEModel
from data.unsplashlite import UnsplashLiteDataset
from utils import RegularCheckpoint, train_save_checkpoint
2 changes: 1 addition & 1 deletion utils.py
@@ -5,7 +5,7 @@
import torch
import torchvision

-from model.CringeLDM import CringeDenoiserModel
+from model.CringeDenoiser import CringeDenoiserModel
from PIL import Image
from pytorch_lightning.callbacks import ModelCheckpoint

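
Taken together, the changes above leave the project with this import surface; a short sketch inferred from the diffs in this commit:

# Post-refactor import layout, as reflected in the diffs above:
from model.CringeBERT import CringeBERTWrapper        # BERT text encoder, moved out of CringeLDM.py
from model.CringeDenoiser import CringeDenoiserModel  # formerly model/CringeLDM.py
from model.CringeVAE import CringeVAEModel            # unchanged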
