refactor: Complete de-cuda. Stuck at BERT [#9]
Showing 47 changed files with 120 additions and 58 deletions.
25 binary files changed (contents not shown).
Diffs for 15 further files are not rendered.
@@ -1,44 +1,69 @@
 import pytorch_lightning as pl
 import torch
 
 from torch import Tensor
 from transformers.models.bert.modeling_bert import BertModel
 from transformers.models.bert.tokenization_bert import BertTokenizer
+from torch.nn import functional as F
+
+from model.unet.unet import UNet
 
 
-class CringeBERTWrapper:
+class CringeBERTModel(pl.LightningModule):
     """
-    BERT Wrapper
-    This is a wrapper for the BERT model. Ideally would be trained from the same
-    dataset as the LDM model, but for now we just use the pretrained BERT model.
+    CringeBERTModel
+    This is the VAE model. This is used as a prior to the denoiser module.
     """
+    def __init__(self, dimensions = [
+        32, 64, 128, 256
+    ], hparams = None, has_cross_attention = False, img_dim = 512):
+        super().__init__()
+
+        self.img_dim = img_dim
+        self.vae_module = UNet(dimensions=dimensions, hparams=hparams, has_cross_attention=has_cross_attention)
+
+    def forward(self, x):
+        x = self.vae_module(x)
+        return x
+
+    def configure_optimizers(self):
+        """
+        configure_optimizers
+        This is the optimizer for the model.
+        """
+        optimizer = torch.optim.Adam(self.parameters(), lr=5e-6)
+        return optimizer
+
+    def training_step(self, train_batch, batch_idx):
+        """
+        training_step
+        """
+        # Grab batch
+        y, _ = train_batch
+
+        # Skip if image is None
+        if y is None:
+            return None
+
+        # Forward pass
+        y_hat = self.forward(y)
+        loss = F.l1_loss(y_hat, y)
+        self.log('train_loss', loss)
+
+        # Skip if resulting loss is NaN or Inf
+        if torch.isnan(loss) or torch.isinf(loss):
+            return None
+
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        """
+        validation_step
+        """
+        # Grab batch
+        y, _ = val_batch
 
-    def loadModel(self, cpu):
-        self.bert_model = BertModel.from_pretrained(
-            'bert-base-uncased')  # type: ignore
-        if torch.cuda.is_available() & (not cpu):
-            self.bert_model = self.bert_model.cuda()  # type: ignore
-        self.bert_tokenizer = BertTokenizer.from_pretrained(
-            'bert-base-uncased')  # type: ignore
-
-    def __init__(self, cpu=False):
-        self.loadModel(cpu)
-        pass
-
-    def model_output(self, input_ids: torch.Tensor):
-        with torch.no_grad():
-            if torch.cuda.is_available():
-                input_ids = input_ids.cuda()
-            output = self.bert_model(input_ids)  # type: ignore
-            q = output.last_hidden_state
-            return q.unsqueeze(0)
-
-    def inference(self, query):
-        with torch.no_grad():
-            # Encode the text using BERT
-            input_ids: Tensor = torch.tensor(self.bert_tokenizer.encode(query)) \
-                .unsqueeze(0)  # Add batch dimension
-            # Normalise so that all values are between 0 and 1
-            input_ids = (input_ids + 1) / 2
-            return self.model_output(input_ids)
+        # Forward pass
+        y_hat = self.forward(y)
+        loss = F.l1_loss(y_hat, y)
+        self.log('val_loss', loss)
+        return loss
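For context, the refactored CringeBERTModel above is an ordinary pl.LightningModule, so it can be driven by a stock Lightning Trainer with no manual .cuda() calls. The snippet below is a rough sketch, not part of this commit: the import path, the random stand-in dataset, and the 3-channel 512x512 tensors are assumptions for illustration only.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from model.bert.bert import CringeBERTModel  # hypothetical import path

# Dummy data: (image, caption) pairs; training_step only uses the image.
images = torch.rand(8, 3, 512, 512)   # assumed shape, matching img_dim=512
captions = torch.zeros(8)             # placeholder for the ignored second element
loader = DataLoader(TensorDataset(images, captions), batch_size=2)

model = CringeBERTModel()             # defaults: dimensions [32, 64, 128, 256]
# Lightning selects and moves to the device, which is the point of the de-cuda refactor.
trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)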
@@ -0,0 +1,44 @@
+import pytorch_lightning as pl
+import torch
+
+from torch import Tensor
+from transformers.models.bert.modeling_bert import BertModel
+from transformers.models.bert.tokenization_bert import BertTokenizer
+
+
+class OldCringeBERTWrapper:
+    """
+    BERT Wrapper
+    This is a wrapper for the BERT model. Ideally would be trained from the same
+    dataset as the LDM model, but for now we just use the pretrained BERT model.
+    """
+
+    def loadModel(self, cpu):
+        self.bert_model = BertModel.from_pretrained(
+            'bert-base-uncased')  # type: ignore
+        if torch.cuda.is_available() & (not cpu):
+            self.bert_model = self.bert_model.cuda()  # type: ignore
+        self.bert_tokenizer = BertTokenizer.from_pretrained(
+            'bert-base-uncased')  # type: ignore
+
+    def __init__(self, cpu=False):
+        self.loadModel(cpu)
+        pass
+
+    def model_output(self, input_ids: torch.Tensor):
+        with torch.no_grad():
+            if torch.cuda.is_available():
+                input_ids = input_ids.cuda()
+            output = self.bert_model(input_ids)  # type: ignore
+            q = output.last_hidden_state
+            return q.unsqueeze(0)
+
+    def inference(self, query):
+        with torch.no_grad():
+            # Encode the text using BERT
+            input_ids: Tensor = torch.tensor(self.bert_tokenizer.encode(query)) \
+                .unsqueeze(0)  # Add batch dimension
+            # Normalise so that all values are between 0 and 1
+            input_ids = (input_ids + 1) / 2
+            return self.model_output(input_ids)
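Per the commit message, the de-cuda pass is complete everywhere except this BERT path, which still hard-codes .cuda() on both the model and the input ids. One possible way to finish the job, sketched below purely as a suggestion (the class name and the device parameter are not from this repository), is to resolve a torch.device once and move tensors with .to() instead of branching on torch.cuda.is_available().

import torch
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.bert.tokenization_bert import BertTokenizer


class DeviceAgnosticBERTWrapper:
    # Sketch only: same interface idea as the wrapper above, minus hard-coded CUDA.
    def __init__(self, device=None):
        # Resolve the target device once; fall back to CPU when CUDA is absent.
        self.device = torch.device(device) if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased').to(self.device)
        self.bert_model.eval()
        self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def model_output(self, input_ids: torch.Tensor):
        with torch.no_grad():
            # Move the batch to wherever the model lives instead of calling .cuda().
            output = self.bert_model(input_ids.to(self.device))
            return output.last_hidden_state.unsqueeze(0)

    def inference(self, query: str):
        with torch.no_grad():
            # Tokenize and add a batch dimension, as in the original wrapper.
            input_ids = torch.tensor(self.bert_tokenizer.encode(query)).unsqueeze(0)
            return self.model_output(input_ids)

With this shape, callers can pass device='cpu' explicitly (the role the old cpu=True flag played) or leave it unset and let the wrapper pick CUDA only when it is actually available.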
@@ -1,8 +1,9 @@
+lightning-transformers
+matplotlib
+numpy
+open_clip_torch
+pytorch_lightning
 torch
 torchvision
 torchaudio
-pytorch_lightning
-lightning-transformers
-transformers
-numpy
-matplotlib
+transformers