feat: Implement basic datasets for CLIP [#7]
pizzabug committed Jan 11, 2023
1 parent 49082fb commit d7b0b0b
Showing 8 changed files with 358 additions and 31 deletions.
72 changes: 72 additions & 0 deletions data/unsplashlite_bert.py
@@ -0,0 +1,72 @@
# Data Loader for Unsplash Lite Dataset

import csv
import numpy as np
import os
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from model.OldCringeBERT import OldCringeBERTWrapper
from utils import *

class UnsplashLiteDataset(Dataset):
def __init__(self, root_dir, transform=None, img_dim=256):
self.image_paths = []
self.image_captions = []

self.im_dimension = img_dim

bertWrapper = OldCringeBERTWrapper()

# Get max length
self.text_max = 512

# Open the CSV file and read the image path from it
with open(root_dir + '/manifest.csv', 'r') as file:
reader = csv.reader(file)
for row in reader:
image_path = root_dir + '/' + row[0]
image_caption = row[1]
image_caption = torch.tensor(bertWrapper.bert_tokenizer.encode(image_caption)).unsqueeze(0)

#if (image_caption.size()[1] > self.text_max):
# self.text_max = image_caption.size()[1]
if (image_caption.size()[1] >= self.text_max):
image_caption = image_caption[:, :self.text_max]
else:
image_caption = torch.nn.functional.pad(image_caption, (0, self.text_max - image_caption.size()[1]), 'constant', 0)

image_caption = image_caption.squeeze(0)

self.image_paths.append(image_path)
self.image_captions.append(image_caption)


def __len__(self):
return len(self.image_paths)

def __getitem__(self, idx):
path = self.image_paths[idx]
if (not os.path.exists(path)):
return None, None
else:
x = Image.open(path)
x = x.resize((self.im_dimension, self.im_dimension))
x = np.array(x)
if x.shape != (self.im_dimension, self.im_dimension, 3):
print(f"Warning: image shape is not ({self.im_dimension}, {self.im_dimension}, 3). Skipping")
print(x.shape)
return None, None

x = convert_to_tensor(x)
x = x.squeeze(0)
if x.shape != (3, self.im_dimension, self.im_dimension):
print(f"Warning: image shape is not (3, {self.im_dimension}, {self.im_dimension}). Skipping")
print(x.shape)
return None, None

q = self.image_captions[idx]
return x, q
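
For reference, a minimal usage sketch (editor's addition, not part of this commit) of how the BERT-tokenised dataset might be consumed: because __getitem__ returns (None, None) for missing or malformed images, a custom collate_fn is needed to drop those samples before batching. The dataset path and batch size below are placeholders.

# Usage sketch (assumption, not from the commit): filter out the (None, None)
# pairs the dataset emits for bad images before collating a batch.
from torch.utils.data import DataLoader, default_collate

from data.unsplashlite_bert import UnsplashLiteDataset  # module path taken from the diff


def skip_none_collate(batch):
    # Keep only samples whose image tensor was actually loaded.
    batch = [(x, q) for x, q in batch if x is not None]
    return default_collate(batch) if batch else None


dataset = UnsplashLiteDataset(root_dir="/path/to/unsplash-lite", img_dim=256)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=skip_none_collate)

for batch in loader:
    if batch is None:
        continue
    images, captions = batch  # images: (B, 3, 256, 256); captions: (B, 512) BERT token ids
    break
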
67 changes: 67 additions & 0 deletions data/unsplashlite_clip.py
@@ -0,0 +1,67 @@
# Data Loader for Unsplash Lite Dataset

import csv
import numpy as np
import os
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from model.CringeCLIP import CringeCLIPModel
from utils import *

class UnsplashLiteDataset(Dataset):
def __init__(self, root_dir, transform=None, img_dim=256):
self.image_paths = []
self.image_captions = []

self.im_dimension = img_dim

# Build the tokenizer only; the full CLIP model is not needed here
clip_model = CringeCLIPModel(just_the_tokenizer=True)

# Get max length
self.text_max = 512

# Open the CSV file and read the image path from it
with open(root_dir + '/manifest.csv', 'r') as file:
reader = csv.reader(file)
for row in reader:
image_path = root_dir + '/' + row[0]
image_caption = row[1]
image_caption = clip_model.tokenizer(image_caption)
image_caption = image_caption.squeeze(0)

self.image_paths.append(image_path)
self.image_captions.append(image_caption)

# Flush out the model
del clip_model

def __len__(self):
return len(self.image_paths)

def __getitem__(self, idx):
path = self.image_paths[idx]
if (not os.path.exists(path)):
return None, None
else:
x = Image.open(path)
x = x.resize((self.im_dimension, self.im_dimension))
x = np.array(x)
if x.shape != (self.im_dimension, self.im_dimension, 3):
print(f"Warning: image shape is not ({self.im_dimension}, {self.im_dimension}, 3). Skipping")
print(x.shape)
return None, None

x = convert_to_tensor(x)
x = x.squeeze(0)
if x.shape != (3, self.im_dimension, self.im_dimension):
print(f"Warning: image shape is not (3, {self.im_dimension}, {self.im_dimension}). Skipping")
print(x.shape)
return None, None

q = self.image_captions[idx]
return x, q
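
A short sketch of what the tokenisation step above produces (based on open_clip's default behaviour, which the commit does not show explicitly): the tokenizer pads or truncates every caption to a fixed context length of 77 tokens, which is why this loader needs no manual padding, unlike the BERT variant.

# Sketch (assumes open_clip's default 77-token context length):
import open_clip

tokenizer = open_clip.get_tokenizer("RN50")
tokens = tokenizer("a dog running on the beach")  # LongTensor of shape (1, 77)
caption = tokens.squeeze(0)                       # shape (77,), as stored in image_captions
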
2 changes: 1 addition & 1 deletion inference.py
@@ -3,7 +3,7 @@
import pytorch_lightning as pl
import torch

from model.CringeDenoiser import CringeDenoiserModel
from model.CringeDenoiserBert import CringeDenoiserModel
from model.CringeVAE import CringeVAEModel
from PIL import Image

5 changes: 3 additions & 2 deletions model/CringeCLIP.py
@@ -10,10 +10,11 @@ class CringeCLIPModel(pl.LightningModule):
CringeCLIP
"""

def __init__(self, model_type="RN50", hparams=None, has_cross_attention=False, img_dim=512):
def __init__(self, just_the_tokenizer=False, model_type="RN50", hparams=None, has_cross_attention=False, img_dim=512):
super().__init__()

self.clip_module, _, self.preprocess = open_clip.create_model_and_transforms("RN50") # type: ignore
if not just_the_tokenizer:
self.clip_module, _, self.preprocess = open_clip.create_model_and_transforms("RN50") # type: ignore
self.tokenizer = open_clip.get_tokenizer("RN50") # type: ignore

def forward(self, text = None, image = None):
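
A brief sketch of the new constructor flag (the forward body is truncated above, so the embedding call at the end is an assumption): with just_the_tokenizer=True only the tokenizer is created, which keeps dataset construction cheap because the RN50 weights are never loaded.

# Sketch of the just_the_tokenizer flag (editor's assumption beyond what the diff shows):
from model.CringeCLIP import CringeCLIPModel

tokenizer_only = CringeCLIPModel(just_the_tokenizer=True)  # no RN50 weights are created
tokens = tokenizer_only.tokenizer("a photo of a cat")      # (1, 77) token ids

full_model = CringeCLIPModel()                             # builds the RN50 CLIP module as well
text_embedding = full_model.forward(text=tokens)           # assumed to return the text features
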
42 changes: 26 additions & 16 deletions model/CringeDenoiser.py
@@ -5,8 +5,8 @@
from torch.nn import functional as F

from utils import add_noise
from model.OldCringeBERT import OldCringeBERTWrapper
from model.CringeVAE import CringeVAEModel
from model.CringeCLIP import CringeCLIPModel
from model.unet.unet import UNet


@@ -19,18 +19,12 @@ class CringeDenoiserModel(pl.LightningModule):

def __init__(self, hparams=None, vae_model: CringeVAEModel | None = None, diffuser_shapes=[
32, 64, 128, 256
], img_dim=256):
], img_dim=256, clip_model = None):
super().__init__()
self.img_dim = img_dim
self.dropout = 0.02
self.vae_model: CringeVAEModel = vae_model # type: ignore

"""
BERT Wrapper for the text encoding
This should be an integrated part of the model
in the future
"""
self.bertWrapper = OldCringeBERTWrapper()
self.clip_model: CringeCLIPModel = clip_model # type: ignore

# Diffusion UNet
self.UNet = UNet(
@@ -110,12 +104,15 @@ def training_step(self, train_batch, batch_idx):
# q = q.cuda()

# Get q
q = self.bertWrapper.model_output(q)
with torch.no_grad():
if q is not None:
q = self.clip_model.forward(text=q)
else:
# TODO: when no caption is available, preprocess the image and embed it with CLIP instead.
return None

# Generate x batch, which is a slightly noisier version of y
x = add_noise(y)

#x = torch.randn(y.shape[0], 3, self.img_dim, self.img_dim)
#x = x.to(y)

# Forward pass
y_hat = self.forward(q=q, x=x, steps=1)
@@ -141,7 +138,13 @@ def validation_step(self, val_batch, batch_idx):
# q = q.cuda()

# Get q
q = self.bertWrapper.model_output(q)
with torch.no_grad():
if q is not None:
q = self.clip_model.forward(text=q)
else:
# TODO: when no caption is available, preprocess the image and embed it with CLIP instead.
return None

# Forward pass
y_hat = self(q)
loss = F.l1_loss(y_hat, y)
@@ -152,12 +155,19 @@ def forward_with_q(self, query, x=None, steps=1):

# Get the BERT output
q = torch.tensor(
self.bertWrapper.bert_tokenizer.encode(query)).unsqueeze(0)
self.clip_model.tokenizer([query,])).unsqueeze(0)

# if torch.cuda.is_available():
# q = q.cuda()

q = self.bertWrapper.model_output(q)
# Get q
with torch.no_grad():
if q is not None:
q = self.clip_model.forward(text=q)
else:
# TODO: when no caption is available, preprocess the image and embed it with CLIP instead.
return None


# if torch.cuda.is_available():
# q = q.cuda()
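
Finally, a wiring sketch for the reworked denoiser (the VAE constructor arguments and the exact return value of forward_with_q are assumptions; the diff only shows the new signatures): the denoiser no longer owns a BERT wrapper and instead receives an external CringeCLIPModel whose text features condition the UNet.

# Wiring sketch (editor's assumption; only the new signatures are visible in the diff):
from model.CringeCLIP import CringeCLIPModel
from model.CringeDenoiser import CringeDenoiserModel
from model.CringeVAE import CringeVAEModel

clip_model = CringeCLIPModel()                 # full model: tokenizer + RN50 text encoder
vae_model = CringeVAEModel()                   # constructor arguments assumed
denoiser = CringeDenoiserModel(vae_model=vae_model, clip_model=clip_model, img_dim=256)

# Text-conditioned sampling through the helper updated in this commit
image = denoiser.forward_with_q("a dog running on the beach", steps=1)
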