hf and multimodal clip (#1921)
Summary:
The multimodal CLIP model lives in canary because of its torchtext dependency, so this adds the HF version.

Notably, this PR also adds support for dict-based inputs to `get_module()`, which are very common in HF code.
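
Below is a minimal, hypothetical sketch (not part of this commit) of the convention this enables: a benchmark's `get_module()` returns a dict of keyword tensors, and the harness dispatches with `**` when the inputs are a Mapping, mirroring the check added to `env_check.py` further down. The toy model and input names are illustrative only.

from collections.abc import Mapping

import torch
import torch.nn as nn


class ToyTextModel(nn.Module):
    """Hypothetical stand-in for an HF-style model that takes keyword tensors."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)

    def forward(self, input_ids, attention_mask):
        hidden = self.embed(input_ids)
        # Zero out padded positions using the attention mask.
        return hidden * attention_mask.unsqueeze(-1)


def get_module():
    # Dict-based example inputs, as this PR now allows get_module() to return.
    model = ToyTextModel()
    inputs = {
        "input_ids": torch.randint(0, 100, (2, 8)),
        "attention_mask": torch.ones(2, 8),
    }
    return model, inputs


mod, example_inputs = get_module()
# Dispatch the same way forward_pass() in env_check.py does after this change.
if isinstance(example_inputs, Mapping):
    out = mod(**example_inputs)
else:
    out = mod(*example_inputs)
print(out.shape)  # torch.Size([2, 8, 16])

Dict inputs keep the tensor names that HF forward() signatures expect, so the harness does not need to know positional argument order.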

Pull Request resolved: #1921

Reviewed By: kartikayk, xuzhao9

Differential Revision: D49584110

Pulled By: msaroufim

fbshipit-source-id: 34bc581515c860b05018413004b9b8067709fedc
msaroufim authored and facebook-github-bot committed Sep 26, 2023
1 parent b7e9404 commit 6fef32d
Showing 9 changed files with 121 additions and 4 deletions.
3 changes: 3 additions & 0 deletions torchbenchmark/__init__.py
@@ -284,6 +284,9 @@ def worker(self) -> subprocess_worker.SubprocessWorker:
    def model_details(self) -> bool:
        return self._details

    def __str__(self) -> str:
        return f"ModelTask(Model Path: {self._model_path}, Metadata: {self._details.metadata})"

    # =========================================================================
    # == Import Model in the child process ====================================
    # =========================================================================
1 change: 0 additions & 1 deletion torchbenchmark/canary_models/clip/requirements.txt

This file was deleted.

@@ -0,0 +1,2 @@
git+https://github.com/facebookresearch/multimodal.git
torchtext
95 changes: 95 additions & 0 deletions torchbenchmark/models/hf_clip/__init__.py
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

import torch
import torch.nn as nn
import torch.nn.functional as F

import os
from ...util.model import BenchmarkModel
from PIL import Image
import requests

from transformers import CLIPProcessor, CLIPModel


class ContrastiveLossWithTemperature(nn.Module):
    def __init__(self, temperature=0.07):
        super(ContrastiveLossWithTemperature, self).__init__()
        self.temperature = temperature

    def forward(self, image_embeddings, text_embeddings):
        # Ensure batch sizes are equal
        assert image_embeddings.size(0) == text_embeddings.size(0), "Batch sizes of image and text embeddings should be the same"

        # Compute the similarity between image and text embeddings
        logits = torch.matmul(image_embeddings, text_embeddings.T) / self.temperature

        # Compute the labels for the positive pairs
        labels = torch.arange(logits.size(0)).to(image_embeddings.device)

        # Compute the contrastive loss
        loss = F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)
        return loss / 2


class Model(BenchmarkModel):
    DEFAULT_EVAL_BSIZE = 32
    DEFAULT_TRAIN_BSIZE = 32

    def __init__(self, test, device, batch_size=1, extra_args=[]):
        super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        text = "the dog is here"
        images = [image] * self.batch_size
        texts = [text] * self.batch_size
        self.inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
        for key in self.inputs:
            self.inputs[key] = self.inputs[key].to(self.device)

        # Add the loss function and optimizer
        self.loss_fn = ContrastiveLossWithTemperature()
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.loss_fn.parameters()),
            lr=5.0e-4,
            weight_decay=1.0e-4,
            eps=1.0e-6,
        )

    def get_module(self):
        return self.model, self.inputs

    def train(self):
        image_tensor = self.inputs["pixel_values"]
        text_tensor = self.inputs["input_ids"]
        total_loss = 0
        self.optimizer.zero_grad()

        # Forward pass
        outputs = self.model(**self.inputs)
        image_embedding = outputs.image_embeds
        text_embedding = outputs.text_embeds

        # Compute the loss
        loss = self.loss_fn(image_embedding, text_embedding)
        loss.backward()
        self.optimizer.step()

        total_loss += loss.item()

        # Return the average loss
        return total_loss / len(text_tensor)

    def eval(self):
        return self.model(**self.inputs)
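
As a usage illustration only (not part of the file above, and assuming a CUDA device plus network access for the pretrained weights and the COCO image), the new model can be exercised like this; note that get_module() returns dict inputs, which are splatted with **:

from torchbenchmark.models.hf_clip import Model

m = Model(test="eval", device="cuda")
module, example_inputs = m.get_module()  # example_inputs is a dict of keyword tensors
outputs = module(**example_inputs)       # dict-based inputs are passed as keyword arguments
print(outputs.logits_per_image.shape)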

8 changes: 8 additions & 0 deletions torchbenchmark/models/hf_clip/metadata.yaml
@@ -0,0 +1,8 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
16 changes: 13 additions & 3 deletions torchbenchmark/util/env_check.py
@@ -9,6 +9,8 @@
 import logging
 from contextlib import contextmanager, ExitStack
 from typing import Any, Dict, List, Optional
+from collections.abc import Mapping
+

 MAIN_RANDOM_SEED = 1337
 # rounds for stableness tests
@@ -292,7 +294,7 @@ def torch_clone(x):

 def clone_inputs(example_inputs):
     import torch
-    if type(example_inputs) is dict:
+    if isinstance(example_inputs, Mapping):
         res = dict(example_inputs)
         for key, value in res.items():
             assert isinstance(value, torch.Tensor)
@@ -351,14 +353,22 @@ def optimizer_step(optimizer):
         optimizer.step()

 def forward_pass(mod, inputs, contexts, _collect_outputs=True):
+    cloned_inputs = clone_inputs(inputs)
     with nested(*contexts):
-        return mod(*inputs)
+        if isinstance(cloned_inputs, Mapping):
+            return mod(**inputs)
+        else:
+            return mod(*inputs)


 def forward_and_backward_pass(mod, inputs, contexts, optimizer, collect_outputs=True):
     cloned_inputs = clone_inputs(inputs)
     optimizer_zero_grad(optimizer, mod)
     with nested(*contexts):
-        pred = mod(*cloned_inputs)
+        if isinstance(cloned_inputs, Mapping):
+            pred = mod(**cloned_inputs)
+        else:
+            pred = mod(*cloned_inputs)
         loss = compute_loss(pred)
     loss.backward(retain_graph=True)
     optimizer_step(optimizer)
