-
Notifications
You must be signed in to change notification settings - Fork 278
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: The torchmultimodal CLIP model is in canary because of its torchtext dependency, so this PR adds the Hugging Face (HF) version. Notably, this PR also makes it possible to support dict-based inputs to `get_module()`, which is very common in HF code. Pull Request resolved: #1921. Reviewed By: kartikayk, xuzhao9. Differential Revision: D49584110. Pulled By: msaroufim. fbshipit-source-id: 34bc581515c860b05018413004b9b8067709fedc
- Loading branch information
1 parent
b7e9404
commit 6fef32d
Showing
9 changed files
with
121 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 2 additions & 0 deletions
2
torchbenchmark/canary_models/torchmultimodal_clip/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
git+https://github.com/facebookresearch/multimodal.git | ||
torchtext |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# This software may be used and distributed according to the terms of the GNU General Public License version 3. | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
import os | ||
from ...util.model import BenchmarkModel | ||
from PIL import Image | ||
import requests | ||
|
||
from transformers import CLIPProcessor, CLIPModel | ||
|
||
|
||
|
||
class ContrastiveLossWithTemperature(nn.Module):
    """Symmetric cross-entropy loss over an image/text similarity matrix.

    The similarity matrix is the matrix product of the image embeddings with
    the transposed text embeddings, divided by a fixed temperature. Row ``i``
    of the image batch is treated as the positive match for row ``i`` of the
    text batch; every other pairing in the batch is a negative.

    Args:
        temperature: Divisor applied to the similarity logits. It is stored
            as a plain attribute (not an ``nn.Parameter``), so this module
            exposes no trainable parameters.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, image_embeddings, text_embeddings):
        """Return the mean of the image-to-text and text-to-image losses.

        Args:
            image_embeddings: Tensor whose first dimension is the batch.
                Expected to be 2-D, since it is multiplied with
                ``text_embeddings.T``.
            text_embeddings: Tensor with the same batch size as
                ``image_embeddings``.

        Returns:
            A scalar tensor: the average of the two cross-entropy terms.

        Raises:
            AssertionError: If the two batch sizes differ.
        """
        batch_size = image_embeddings.size(0)
        # Raised explicitly (instead of via `assert`) so the check still runs
        # when Python is started with -O; the exception type is unchanged.
        if batch_size != text_embeddings.size(0):
            raise AssertionError(
                "Batch sizes of image and text embeddings should be the same"
            )

        # Pairwise similarities between every image and every text.
        logits = torch.matmul(image_embeddings, text_embeddings.T) / self.temperature

        # Positive pairs sit on the diagonal. Build the labels directly on the
        # target device to avoid a CPU allocation followed by a copy.
        labels = torch.arange(batch_size, device=image_embeddings.device)

        # Cross-entropy in both directions: rows are images, columns are texts.
        loss = F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)
        return loss / 2
|
||
|
||
class Model(BenchmarkModel):
    """Benchmark wrapper around the Hugging Face CLIP model.

    Loads the ``openai/clip-vit-base-patch32`` checkpoint, builds one fixed
    batch by repeating a single downloaded image and a single caption
    ``batch_size`` times, and exposes the train/eval/get_module entry points.
    """

    DEFAULT_EVAL_BSIZE = 32
    DEFAULT_TRAIN_BSIZE = 32

    def __init__(self, test, device, batch_size=1, extra_args=[]):
        """Load the model, prepare the input batch, and set up the optimizer.

        Args:
            test: Forwarded unchanged to ``BenchmarkModel.__init__``.
            device: Forwarded unchanged to ``BenchmarkModel.__init__``.
            batch_size: Forwarded unchanged to ``BenchmarkModel.__init__``.
            extra_args: Forwarded unchanged to ``BenchmarkModel.__init__``.
                The list default is kept for signature compatibility; it is
                never mutated here.

        Raises:
            requests.HTTPError: If the sample image download returns an
                error status.
            requests.Timeout: If the download does not respond in time.
        """
        super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
        checkpoint = "openai/clip-vit-base-patch32"
        self.model = CLIPModel.from_pretrained(checkpoint).to(self.device)
        processor = CLIPProcessor.from_pretrained(checkpoint)

        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        # Bound the wait and fail loudly on HTTP errors, so a bad download
        # does not surface later as an opaque image-decoding error. The
        # response is closed as soon as the pixels are in memory.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            image = Image.open(response.raw)
            # Image.open is lazy; force the decode while the stream is open.
            image.load()

        text = "the dog is here"
        images = [image] * self.batch_size
        texts = [text] * self.batch_size
        self.inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
        for key in self.inputs:
            self.inputs[key] = self.inputs[key].to(self.device)

        # Add the loss function and optimizer
        self.loss_fn = ContrastiveLossWithTemperature()
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.loss_fn.parameters()),
            lr=5.0e-4,
            weight_decay=1.0e-4,
            eps=1.0e-6,
        )

    def get_module(self):
        """Return the model together with its dict-style keyword inputs."""
        return self.model, self.inputs

    def train(self):
        """Run one optimization step on the fixed batch.

        Returns:
            The step's loss value divided by the number of rows in
            ``input_ids``.
        """
        text_tensor = self.inputs["input_ids"]
        self.optimizer.zero_grad()

        # Forward pass
        outputs = self.model(**self.inputs)
        image_embedding = outputs.image_embeds
        text_embedding = outputs.text_embeds

        # Compute the loss, backpropagate, and update the weights
        loss = self.loss_fn(image_embedding, text_embedding)
        loss.backward()
        self.optimizer.step()

        total_loss = loss.item()

        # NOTE(review): F.cross_entropy already averages over the batch by
        # default, so this division scales the mean loss down by the batch
        # size again. Kept as-is to preserve the returned value; confirm
        # whether any caller relies on it.
        return total_loss / len(text_tensor)

    def eval(self):
        """Run a single forward pass on the fixed batch and return the model output."""
        return self.model(**self.inputs)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
devices: | ||
NVIDIA A100-SXM4-40GB: | ||
eval_batch_size: 32 | ||
eval_benchmark: false | ||
eval_deterministic: false | ||
eval_nograd: true | ||
train_benchmark: false | ||
train_deterministic: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters