diff --git a/data/unsplashlite_bert.py b/data/unsplashlite_bert.py
new file mode 100644
index 0000000..4769510
--- /dev/null
+++ b/data/unsplashlite_bert.py
@@ -0,0 +1,72 @@
+# Data Loader for Unsplash Lite Dataset
+
+import csv
+import numpy as np
+import os
+import torch
+
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from model.OldCringeBERT import OldCringeBERTWrapper
+from utils import *
+
+class UnsplashLiteDataset(Dataset):
+    def __init__(self, root_dir, transform=None, img_dim=256):
+        self.image_paths = []
+        self.image_captions = []
+
+        self.im_dimension = img_dim
+
+        bertWrapper = OldCringeBERTWrapper()
+
+        # Get max length
+        self.text_max = 512
+
+        # Open the CSV file and read the image path from it
+        with open(root_dir + '/manifest.csv', 'r') as file:
+            reader = csv.reader(file)
+            for row in reader:
+                image_path = root_dir + '/' + row[0]
+                image_caption = row[1]
+                image_caption = torch.tensor(bertWrapper.bert_tokenizer.encode(image_caption)).unsqueeze(0)
+
+                #if (image_caption.size()[1] > self.text_max):
+                #    self.text_max = image_caption.size()[1]
+                if (image_caption.size()[1] >= self.text_max):
+                    image_caption = image_caption[:, :self.text_max]
+                else:
+                    image_caption = torch.nn.functional.pad(image_caption, (0, self.text_max - image_caption.size()[1]), 'constant', 0)
+
+                image_caption = image_caption.squeeze(0)
+
+                self.image_paths.append(image_path)
+                self.image_captions.append(image_caption)
+
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        path = self.image_paths[idx]
+        if (not os.path.exists(path)):
+            return None, None
+        else:
+            x = Image.open(path)
+            x = x.resize((self.im_dimension, self.im_dimension))
+            x = np.array(x)
+            if x.shape != (self.im_dimension, self.im_dimension, 3):
+                print(f"Warning: image shape is not ({self.im_dimension}, {self.im_dimension}, 3). Skipping")
+                print(x.shape)
+                return None, None
+
+            x = convert_to_tensor(x)
+            x = x.squeeze(0)
+            if x.shape != (3, self.im_dimension, self.im_dimension):
+                print(f"Warning: image shape is not (3, {self.im_dimension}, {self.im_dimension}). Skipping")
+                print(x.shape)
+                return None, None
+
+            q = self.image_captions[idx]
+            return x, q
\ No newline at end of file
diff --git a/data/unsplashlite_clip.py b/data/unsplashlite_clip.py
new file mode 100644
index 0000000..fedc51c
--- /dev/null
+++ b/data/unsplashlite_clip.py
@@ -0,0 +1,67 @@
+# Data Loader for Unsplash Lite Dataset
+
+import csv
+import numpy as np
+import os
+import torch
+
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from model.CringeCLIP import CringeCLIPModel
+from utils import *
+
+class UnsplashLiteDataset(Dataset):
+    def __init__(self, root_dir, transform=None, img_dim=256):
+        self.image_paths = []
+        self.image_captions = []
+
+        self.im_dimension = img_dim
+
+        # Tokenise babie
+        clip_model = CringeCLIPModel(just_the_tokenizer=True)
+
+        # Get max length
+        self.text_max = 512
+
+        # Open the CSV file and read the image path from it
+        with open(root_dir + '/manifest.csv', 'r') as file:
+            reader = csv.reader(file)
+            for row in reader:
+                image_path = root_dir + '/' + row[0]
+                image_caption = row[1]
+                image_caption = clip_model.tokenizer(image_caption)
+                image_caption = image_caption.squeeze(0)
+
+                self.image_paths.append(image_path)
+                self.image_captions.append(image_caption)
+
+        # Flush out the model
+        del clip_model
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        path = self.image_paths[idx]
+        if (not os.path.exists(path)):
+            return None, None
+        else:
+            x = Image.open(path)
+            x = x.resize((self.im_dimension, self.im_dimension))
+            x = np.array(x)
+            if x.shape != (self.im_dimension, self.im_dimension, 3):
+                print(f"Warning: image shape is not ({self.im_dimension}, {self.im_dimension}, 3). Skipping")
+                print(x.shape)
+                return None, None
+
+            x = convert_to_tensor(x)
+            x = x.squeeze(0)
+            if x.shape != (3, self.im_dimension, self.im_dimension):
+                print(f"Warning: image shape is not (3, {self.im_dimension}, {self.im_dimension}). Skipping")
+                print(x.shape)
+                return None, None
+
+            q = self.image_captions[idx]
+            return x, q
\ No newline at end of file
diff --git a/inference.py b/inference.py
index bbbe516..94d3f1d 100644
--- a/inference.py
+++ b/inference.py
@@ -3,7 +3,7 @@
 import pytorch_lightning as pl
 import torch
 
-from model.CringeDenoiser import CringeDenoiserModel
+from model.CringeDenoiserBert import CringeDenoiserModel
 from model.CringeVAE import CringeVAEModel
 
 from PIL import Image
diff --git a/model/CringeCLIP.py b/model/CringeCLIP.py
index 8ee2289..dd27f64 100644
--- a/model/CringeCLIP.py
+++ b/model/CringeCLIP.py
@@ -10,10 +10,11 @@ class CringeCLIPModel(pl.LightningModule):
     """
     CringeCLIP
    """
-    def __init__(self, model_type="RN50", hparams=None, has_cross_attention=False, img_dim=512):
+    def __init__(self, just_the_tokenizer=False, model_type="RN50", hparams=None, has_cross_attention=False, img_dim=512):
         super().__init__()
 
-        self.clip_module, _, self.preprocess = open_clip.create_model_and_transforms("RN50") # type: ignore
+        if not just_the_tokenizer:
+            self.clip_module, _, self.preprocess = open_clip.create_model_and_transforms("RN50") # type: ignore
         self.tokenizer = open_clip.get_tokenizer("RN50") # type: ignore
 
     def forward(self, text = None, image = None):
diff --git a/model/CringeDenoiser.py b/model/CringeDenoiser.py
index c57af33..95e8bde 100644
--- a/model/CringeDenoiser.py
+++ b/model/CringeDenoiser.py
@@ -5,8 +5,8 @@
 from torch.nn import functional as F
 
 from utils import add_noise
-from model.OldCringeBERT import OldCringeBERTWrapper
 from model.CringeVAE import CringeVAEModel
+from model.CringeCLIP import CringeCLIPModel
 from model.unet.unet import UNet
 
 
@@ -19,18 +19,12 @@
 
     def __init__(self, hparams=None, vae_model: CringeVAEModel | None = None, diffuser_shapes=[
         32, 64, 128, 256
-    ], img_dim=256):
+    ], img_dim=256, clip_model = None):
         super().__init__()
         self.img_dim = img_dim
         self.dropout = 0.02
         self.vae_model: CringeVAEModel = vae_model # type: ignore
-
-        """
-        BERT Wrapper for the text encoding
-        This should be an integrated part of the model
-        in the future
-        """
-        self.bertWrapper = OldCringeBERTWrapper()
+        self.clip_model: CringeCLIPModel = clip_model # type: ignore
 
         # Diffusion UNet
         self.UNet = UNet(
@@ -110,12 +104,15 @@ def training_step(self, train_batch, batch_idx):
         #     q = q.cuda()
 
         # Get q
-        q = self.bertWrapper.model_output(q)
+        with torch.no_grad():
+            if q is not None:
+                q = self.clip_model.forward(text=q)
+            else:
+                # Preprocess image and send it through clip. Right now a bit hard, so TODO.
+                return None
+
         # Generate x batch, which is a slightly noisier version of y
         x = add_noise(y)
-
-        #x = torch.randn(y.shape[0], 3, self.img_dim, self.img_dim)
-        #x = x.to(y)
 
         # Forward pass
         y_hat = self.forward(q=q, x=x, steps=1)
@@ -141,7 +138,13 @@ def validation_step(self, val_batch, batch_idx):
         #     q = q.cuda()
 
         # Get q
-        q = self.bertWrapper.model_output(q)
+        with torch.no_grad():
+            if q is not None:
+                q = self.clip_model.forward(text=q)
+            else:
+                # Preprocess image and send it through clip. Right now a bit hard, so TODO.
+                return None
+
         # Forward pass
         y_hat = self(q)
         loss = F.l1_loss(y_hat, y)
@@ -152,12 +155,19 @@ def forward_with_q(self, query, x=None, steps=1):
 
         # Get the BERT output
         q = torch.tensor(
-            self.bertWrapper.bert_tokenizer.encode(query)).unsqueeze(0)
+            self.clip_model.tokenizer([query,])).unsqueeze(0)
 
         # if torch.cuda.is_available():
         #     q = q.cuda()
 
-        q = self.bertWrapper.model_output(q)
+        # Get q
+        with torch.no_grad():
+            if q is not None:
+                q = self.clip_model.forward(text=q)
+            else:
+                # Preprocess image and send it through clip. Right now a bit hard, so TODO.
+                return None
+
 
         # if torch.cuda.is_available():
         #     q = q.cuda()
diff --git a/model/CringeDenoiserBert.py b/model/CringeDenoiserBert.py
new file mode 100644
index 0000000..b80bec8
--- /dev/null
+++ b/model/CringeDenoiserBert.py
@@ -0,0 +1,168 @@
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+from torch.nn import functional as F
+
+from utils import add_noise
+from model.OldCringeBERT import OldCringeBERTWrapper
+from model.CringeVAE import CringeVAEModel
+from model.CringeCLIP import CringeCLIPModel
+from model.unet.unet import UNet
+
+
+class CringeDenoiserModel(pl.LightningModule):
+    """
+    Denoiser Model
+
+    This is the definition of the LDM denoiser model.
+    """
+
+    def __init__(self, hparams=None, vae_model: CringeVAEModel | None = None, diffuser_shapes=[
+        32, 64, 128, 256
+    ], img_dim=256, clip_model = None):
+        super().__init__()
+        self.img_dim = img_dim
+        self.dropout = 0.02
+        self.vae_model: CringeVAEModel = vae_model # type: ignore
+        self.clip_model: CringeCLIPModel = clip_model # type: ignore
+
+        """
+        BERT Wrapper for the text encoding
+        This should be an integrated part of the model
+        in the future
+        """
+        self.bertWrapper = OldCringeBERTWrapper()
+
+        # Diffusion UNet
+        self.UNet = UNet(
+            dimensions=diffuser_shapes,
+            hparams=hparams,
+            has_cross_attention=True
+        )
+
+        # Image space decoder
+        self.imageSpaceDecoder = nn.Sequential(
+            nn.Conv2d(3, 6, 12, padding='same'),
+            nn.BatchNorm2d(6),
+            nn.ReLU(),
+            nn.Conv2d(6, 3, 24, padding='same'),
+            nn.BatchNorm2d(3),
+            nn.Dropout(self.dropout),
+            nn.ReLU(),
+            nn.Conv2d(3, 3, 12, padding='same'),
+        )
+
+    def forward(self, q, x=None, steps=20):
+        """
+        forward
+
+        self: The model
+        q: Query tensor from BERT
+        x: Image tensor
+        steps: Number of steps to denoise the image
+        """
+
+        # if torch.cuda.is_available():
+        #     x = x.cuda()
+        #     q = q.cuda()
+
+        # Load the image
+        if x is None:
+            # Generate noise; q's batch dimension is at 1th element
+            x = torch.randn(q.shape[1], 3, self.img_dim, self.img_dim)
+            x = x.to(q)
+
+        # Put the image through the VAE
+        with torch.no_grad():
+            x = self.vae_model(x)
+
+        # We denoise for multiple steps
+        for i in range(steps):
+            # This is the latent space
+            x = self.UNet(x, q)
+            # Image space decoder
+            x = self.imageSpaceDecoder(x)
+
+        return x
+
+    def configure_optimizers(self):
+        """
+        configure_optimizers
+
+        This is the optimizer for the model.
+        """
+        optimizer = torch.optim.Adam(self.parameters(), lr=5e-5)
+        return optimizer
+
+    def training_step(self, train_batch, batch_idx):
+        """
+        training_step
+        """
+        # Grab batch
+        y, q = train_batch
+
+        # Skip if image is None
+        if y is None:
+            return None
+
+        # Cuda up if needed
+        # if torch.cuda.is_available():
+        #     y = y.cuda()
+        #     q = q.cuda()
+
+        # Get q
+        q = self.bertWrapper.model_output(q)
+
+        # Generate x batch, which is a slightly noisier version of y
+        x = add_noise(y)
+
+        # Forward pass
+        y_hat = self.forward(q=q, x=x, steps=1)
+        loss = F.l1_loss(y_hat, y)
+        self.log('train_loss', loss)
+
+        # Skip if resulting loss is NaN or Inf
+        if torch.isnan(loss) or torch.isinf(loss):
+            return None
+
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        """
+        validation_step
+        """
+        # Grab batch
+        y, q = val_batch
+
+        # Cuda up if needed
+        # if torch.cuda.is_available():
+        #     y = y.cuda()
+        #     q = q.cuda()
+
+        # Get q
+        q = self.bertWrapper.model_output(q)
+        # Forward pass
+        y_hat = self(q)
+        loss = F.l1_loss(y_hat, y)
+        self.log('val_loss', loss)
+        return loss
+
+    def forward_with_q(self, query, x=None, steps=1):
+
+        # Get the BERT output
+        q = torch.tensor(
+            self.bertWrapper.bert_tokenizer.encode(query)).unsqueeze(0)
+
+        # if torch.cuda.is_available():
+        #     q = q.cuda()
+
+        q = self.bertWrapper.model_output(q)
+
+        # if torch.cuda.is_available():
+        #     q = q.cuda()
+        # if (x != None):
+        #     x = x.cuda()
+
+        # Forward pass
+        return self.forward(q, x, steps)
diff --git a/model/unet/unet.py b/model/unet/unet.py
index b379ada..4f5b3bd 100644
--- a/model/unet/unet.py
+++ b/model/unet/unet.py
@@ -136,7 +136,7 @@ class UNet(nn.Module):
 
     def __init__(self, dimensions=[
         32, 64, 128, 256
-    ], hparams=None, has_cross_attention=True):
+    ], hparams=None, has_cross_attention=True, query_channels=1024):
         super(UNet, self).__init__()
 
         self.dimensions = dimensions
diff --git a/train.py b/train.py
index 4dd1e69..81cdda8 100644
--- a/train.py
+++ b/train.py
@@ -8,10 +8,10 @@
 from torch.utils.data import DataLoader
 
 from data.dirtycollate import dirty_collate
-
+from data.unsplashlite_clip import UnsplashLiteDataset
+from model.CringeCLIP import CringeCLIPModel
 from model.CringeDenoiser import CringeDenoiserModel
 from model.CringeVAE import CringeVAEModel
-from data.unsplashlite import UnsplashLiteDataset
 from utils import RegularCheckpoint, train_save_checkpoint
 
 os.environ['CUDA_VISIBLE_DEVICES'] ='0'
@@ -24,16 +24,25 @@ def train_denoiser():
     dataset = UnsplashLiteDataset(root_dir='/mnt/e/Source/unsplash-lite-corpus-preprocess/db', img_dim=img_dim)
     training_set, validation_set = torch.utils.data.random_split(dataset, [int(len(dataset)*0.8), int(len(dataset)*0.2)])
 
-    train_loader = DataLoader(training_set, batch_size=10, collate_fn=dirty_collate)
-    val_loader = DataLoader(validation_set, batch_size=10, collate_fn=dirty_collate)
+    train_loader = DataLoader(training_set, batch_size=1, collate_fn=dirty_collate)
+    val_loader = DataLoader(validation_set, batch_size=1, collate_fn=dirty_collate)
 
-    # Load checkpoint if it exists
-    vae_model = CringeVAEModel(dimensions=[16,32,64,128]).to("cuda:0" if torch.cuda.is_available() else "cpu")
+    # Load CLIP checkpoint if it exists
+    clip_model = CringeCLIPModel() #.to("cuda:0" if torch.cuda.is_available() else "cpu")
+    if (os.path.exists("checkpoints/clip/model.ckpt")):
+        clip_model.load_state_dict(torch.load("checkpoints/clip/model.ckpt")["state_dict"])
+
+    # Load VAE checkpoint if it exists
+    vae_model = CringeVAEModel(dimensions=[16,32,64,128]) #.to("cuda:0" if torch.cuda.is_available() else "cpu")
     if (os.path.exists("checkpoints/vae/model.ckpt")):
         vae_model.load_state_dict(torch.load("checkpoints/vae/model.ckpt")["state_dict"])
 
     # Load checkpoint if it exists
-    denoiser_model = CringeDenoiserModel(vae_model=vae_model, diffuser_shapes=[32,64,128,256], img_dim=img_dim).to("cuda:0" if torch.cuda.is_available() else "cpu")
+    denoiser_model = CringeDenoiserModel(
+        vae_model=vae_model,
+        clip_model=clip_model,
+        diffuser_shapes=[32,64,128,256],
+        img_dim=img_dim) #.to("cuda:0" if torch.cuda.is_available() else "cpu")
 
     # Logger
     denoiser_logger = TensorBoardLogger("tb_logs", name="cringeldm")
@@ -54,15 +63,15 @@ def train_denoiser():
         accumulate_grad_batches=10,
         logger=denoiser_logger)
     while True:
-        try:
+        #try:
             # Load checkpoint if it exists
             if (os.path.exists("checkpoints/ldm/model.ckpt")):
                 denoiser_trainer.fit(denoiser_model, train_loader, val_loader, ckpt_path="checkpoints/ldm/model.ckpt")
             else:
                 denoiser_trainer.fit(denoiser_model, train_loader, val_loader)
-        except Exception as e:
-            tb = sys.exc_info()[2]
-            print(e.with_traceback(tb))
+        #except Exception as e:
+        #    tb = sys.exc_info()[2]
+        #    print(e.with_traceback(tb))
 
 def train_vae():
     # hparams while i'm working on it