refactor: Complete de-cuda. Stuck at BERT [#9]
pizzabug committed Jan 9, 2023
1 parent d679b64 commit f1d2a61
Showing 47 changed files with 120 additions and 58 deletions.
Binary file removed checkpoints/ldm/sample Poppy seeds and flowers.png
Binary file removed checkpoints/ldm/sample Woman exploring a forest.png
4 changes: 2 additions & 2 deletions data/unsplashlite.py
@@ -9,7 +9,7 @@
 from torch.utils.data import Dataset
 from torchvision import transforms
 
-from model.CringeBERT import CringeBERTWrapper
+from model.OldCringeBERT import OldCringeBERTWrapper
 from utils import *
 
 class UnsplashLiteDataset(Dataset):
@@ -19,7 +19,7 @@ def __init__(self, root_dir, transform=None, img_dim=256):
 
         self.im_dimension = img_dim
 
-        bertWrapper = CringeBERTWrapper()
+        bertWrapper = OldCringeBERTWrapper()
 
         # Get max length
         self.text_max = 512
95 changes: 60 additions & 35 deletions model/CringeBERT.py
@@ -1,44 +1,69 @@
 import pytorch_lightning as pl
 import torch
 
 from torch import Tensor
-from transformers.models.bert.modeling_bert import BertModel
-from transformers.models.bert.tokenization_bert import BertTokenizer
+from torch.nn import functional as F
 
+from model.unet.unet import UNet
 
-class CringeBERTWrapper:
+class CringeBERTModel(pl.LightningModule):
     """
-        BERT Wrapper
-        This is a wrapper for the BERT model. Ideally would be trained from the same
-        dataset as the LDM model, but for now we just use the pretrained BERT model.
+        CringeBERTModel
+        This is the VAE model. This is used as a prior to the denoiser module.
     """
+    def __init__(self, dimensions = [
+        32, 64, 128, 256
+    ], hparams = None, has_cross_attention = False, img_dim = 512):
+        super().__init__()
+
+        self.img_dim = img_dim
+        self.vae_module = UNet(dimensions=dimensions, hparams=hparams, has_cross_attention=has_cross_attention)
+
+    def forward(self, x):
+        x = self.vae_module(x)
+        return x
+
+    def configure_optimizers(self):
+        """
+            configure_optimizers
+            This is the optimizer for the model.
+        """
+        optimizer = torch.optim.Adam(self.parameters(), lr=5e-6)
+        return optimizer
+
+    def training_step(self, train_batch, batch_idx):
+        """
+            training_step
+        """
+        # Grab batch
+        y, _ = train_batch
+
+        # Skip if image is None
+        if y is None:
+            return None
+
+        # Forward pass
+        y_hat = self.forward(y)
+        loss = F.l1_loss(y_hat, y)
+        self.log('train_loss', loss)
+
+        # Skip if resulting loss is NaN or Inf
+        if torch.isnan(loss) or torch.isinf(loss):
+            return None
+
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        """
+            validation_step
+        """
+        # Grab batch
+        y, _ = val_batch
 
-    def loadModel(self, cpu):
-        self.bert_model = BertModel.from_pretrained(
-            'bert-base-uncased')  # type: ignore
-        if torch.cuda.is_available() & (not cpu):
-            self.bert_model = self.bert_model.cuda()  # type: ignore
-        self.bert_tokenizer = BertTokenizer.from_pretrained(
-            'bert-base-uncased')  # type: ignore
-
-    def __init__(self, cpu=False):
-        self.loadModel(cpu)
-        pass
-
-    def model_output(self, input_ids: torch.Tensor):
-        with torch.no_grad():
-            if torch.cuda.is_available():
-                input_ids = input_ids.cuda()
-            output = self.bert_model(input_ids)  # type: ignore
-            q = output.last_hidden_state
-            return q.unsqueeze(0)
-
-    def inference(self, query):
-        with torch.no_grad():
-            # Encode the text using BERT
-            input_ids: Tensor = torch.tensor(self.bert_tokenizer.encode(query)) \
-                .unsqueeze(0)  # Add batch dimension
-            # Normalise so that all values are between 0 and 1
-            input_ids = (input_ids + 1) / 2
-            return self.model_output(input_ids)
+        # Forward pass
+        y_hat = self.forward(y)
+        loss = F.l1_loss(y_hat, y)
+        self.log('val_loss', loss)
+        return loss
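
For reference, a minimal usage sketch of the new CringeBERTModel. The input shape is an assumption; the UNet's expected channel count and spatial size are not visible in this diff.

import torch
from model.CringeBERT import CringeBERTModel

model = CringeBERTModel(dimensions=[32, 64, 128, 256], img_dim=512)
x = torch.randn(1, 3, 64, 64)  # (batch, channels, height, width), shape assumed
y_hat = model(x)               # forwards through the UNet-based prior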
2 changes: 1 addition & 1 deletion model/CringeCLIP.py
@@ -7,7 +7,7 @@
 from torch.utils.data import DataLoader
 from torch.utils.data import random_split
 
-from model.cringe.unet import UNet
+from model.unet.unet import UNet
 
 
 class CringeCLIPModel(pl.LightningDataModule):
12 changes: 6 additions & 6 deletions model/CringeDenoiser.py
@@ -5,7 +5,7 @@
 from torch.nn import functional as F
 
 from utils import add_noise
-from model.CringeBERT import CringeBERTWrapper
+from model.OldCringeBERT import OldCringeBERTWrapper
 from model.CringeVAE import CringeVAEModel
 from model.unet.unet import UNet
 
@@ -30,7 +30,7 @@ def __init__(self, hparams=None, vae_model: CringeVAEModel | None = None, diffus
             This should be an integrated part of the model
             in the future
         """
-        self.bertWrapper = CringeBERTWrapper()
+        self.bertWrapper = OldCringeBERTWrapper()
 
         # Diffusion UNet
         self.UNet = UNet(
@@ -159,10 +159,10 @@ def forward_with_q(self, query, x=None, steps=1):
 
         q = self.bertWrapper.model_output(q)
 
-        if torch.cuda.is_available():
-            q = q.cuda()
-        if (x != None):
-            x = x.cuda()
+        # if torch.cuda.is_available():
+        #     q = q.cuda()
+        # if (x != None):
+        #     x = x.cuda()
 
         # Forward pass
         return self.forward(q, x, steps)
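
Commenting out the .cuda() calls leaves the method device-agnostic but also device-ignorant. A minimal sketch of the Lightning-idiomatic replacement, assuming the rest of forward_with_q stays as in the hunk above (its earlier lines are elided in this diff):

# Sketch only: inside a LightningModule, self.device tracks the Trainer's
# accelerator, so tensors can follow the model without explicit .cuda() calls.
q = self.bertWrapper.model_output(q)
q = q.to(self.device)
if x is not None:  # `x != None` in the original; `is not None` is the idiomatic test
    x = x.to(self.device)
return self.forward(q, x, steps)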
10 changes: 1 addition & 9 deletions model/CringeVAE.py
@@ -44,10 +44,6 @@ def training_step(self, train_batch, batch_idx):
         if y is None:
             return None
 
-        # Cuda up if needed
-        if torch.cuda.is_available():
-            y = y.cuda()
-
         # Forward pass
         y_hat = self.forward(y)
         loss = F.l1_loss(y_hat, y)
@@ -66,12 +62,8 @@ def validation_step(self, val_batch, batch_idx):
         # Grab batch
         y, _ = val_batch
 
-        # Cuda up if needed
-        if torch.cuda.is_available():
-            y = y.cuda()
-
         # Forward pass
         y_hat = self.forward(y)
         loss = F.l1_loss(y_hat, y)
         self.log('val_loss', loss)
-        return loss
+        return loss
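
With the manual y.cuda() calls gone, batch placement falls to Lightning, which moves each batch to the active device before training_step and validation_step run. A minimal sketch of the assumed Trainer invocation (not part of this commit):

import pytorch_lightning as pl

# accelerator="auto" selects CUDA when available and falls back to CPU,
# matching what the deleted torch.cuda.is_available() guards emulated.
trainer = pl.Trainer(accelerator="auto", devices=1)
# trainer.fit(model, train_dataloader, val_dataloader)  # loaders assumed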
44 changes: 44 additions & 0 deletions model/OldCringeBERT.py
@@ -0,0 +1,44 @@
+import pytorch_lightning as pl
+import torch
+
+from torch import Tensor
+from transformers.models.bert.modeling_bert import BertModel
+from transformers.models.bert.tokenization_bert import BertTokenizer
+
+
+class OldCringeBERTWrapper:
+    """
+        BERT Wrapper
+        This is a wrapper for the BERT model. Ideally would be trained from the same
+        dataset as the LDM model, but for now we just use the pretrained BERT model.
+    """
+
+    def loadModel(self, cpu):
+        self.bert_model = BertModel.from_pretrained(
+            'bert-base-uncased')  # type: ignore
+        if torch.cuda.is_available() & (not cpu):
+            self.bert_model = self.bert_model.cuda()  # type: ignore
+        self.bert_tokenizer = BertTokenizer.from_pretrained(
+            'bert-base-uncased')  # type: ignore
+
+    def __init__(self, cpu=False):
+        self.loadModel(cpu)
+        pass
+
+    def model_output(self, input_ids: torch.Tensor):
+        with torch.no_grad():
+            if torch.cuda.is_available():
+                input_ids = input_ids.cuda()
+            output = self.bert_model(input_ids)  # type: ignore
+            q = output.last_hidden_state
+            return q.unsqueeze(0)
+
+    def inference(self, query):
+        with torch.no_grad():
+            # Encode the text using BERT
+            input_ids: Tensor = torch.tensor(self.bert_tokenizer.encode(query)) \
+                .unsqueeze(0)  # Add batch dimension
+            # Normalise so that all values are between 0 and 1
+            input_ids = (input_ids + 1) / 2
+            return self.model_output(input_ids)
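
This wrapper is the likely reason the refactor is stuck at BERT: OldCringeBERTWrapper is a plain Python class rather than an nn.Module, so Lightning never registers bert_model and cannot manage its device, which forces the torch.cuda.is_available() branches above. One possible de-cuda'd replacement, sketched as an nn.Module (the class name is hypothetical and not in this commit):

import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.bert.tokenization_bert import BertTokenizer


class CringeBERTEncoder(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        # Registered as a submodule, so Trainer/.to() move it with the parent model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Follow whatever device the module currently lives on
        device = next(self.bert.parameters()).device
        output = self.bert(input_ids.to(device))
        return output.last_hidden_state.unsqueeze(0)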
11 changes: 6 additions & 5 deletions requirements.txt
@@ -1,8 +1,9 @@
+lightning-transformers
+matplotlib
+numpy
+open_clip_torch
+pytorch_lightning
 torch
 torchvision
 torchaudio
-pytorch_lightning
-lightning-transformers
-transformers
-numpy
-matplotlib
+transformers
