Add sam_fast torchbench (#2182)
Summary:
Pull Request resolved: #2182

Takes the same benchmark used by torchbench/sam, but builds the model from the https://github.com/pytorch-labs/segment-anything-fast/tree/main/segment_anything_fast registry instead, which invokes dynamo and inductor.
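
For context, the only substantive difference from the existing sam benchmark is which registry builds the model. A minimal sketch of the swap (assuming the registry names exposed by the segment_anything and segment_anything_fast packages):

# torchbench/sam builds the model from the stock registry:
#   from segment_anything import sam_model_registry
#   sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
# sam_fast is a drop-in replacement whose registry applies the segment-anything-fast
# optimizations and routes execution through dynamo/inductor:
from segment_anything_fast.build_sam import sam_model_fast_registry
sam_fast = sam_model_fast_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")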

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: msaroufim

Differential Revision: D54589585

Pulled By: jamesjwu

fbshipit-source-id: 379e7775db9155a70976ae95144e6e0edae5e168
jamesjwu authored and facebook-github-bot committed Mar 6, 2024
1 parent 611d446 commit d6015d4
Showing 5 changed files with 112 additions and 0 deletions.
73 changes: 73 additions & 0 deletions torchbenchmark/models/sam_fast/__init__.py
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from ...util.model import BenchmarkModel
from segment_anything_fast.build_sam import sam_model_fast_registry
from segment_anything_fast.predictor import SamPredictor
import numpy as np
import cv2
from torchbenchmark.tasks import COMPUTER_VISION
import torch
import os


class Model(BenchmarkModel):
    task = COMPUTER_VISION.SEGMENTATION
    DEFAULT_EVAL_BSIZE = 32

    def __init__(self, test, device, batch_size=1, extra_args=[]):
        super().__init__(
            test=test, device=device, batch_size=batch_size, extra_args=extra_args
        )

        # Checkpoint options are here https://github.com/facebookresearch/segment-anything#model-checkpoints
        data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data")
        sam_checkpoint = os.path.join(data_folder, "sam_vit_h_4b8939.pth")
        model_type = "vit_h"

        self.model = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
        self.model.to(device=device)
        data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data")

        image_path = os.path.join(data_folder, "truck.jpg")
        self.image = cv2.imread(image_path)
        self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
        self.sample_image = torch.randn((3, 256, 256)).to(device)

    def get_module(self):
        example_input = [
            {
                "image": self.sample_image,
                "original_size": (256, 256),
                "point_coords": torch.tensor([[[1, 1]]], device=self.device),
                "point_labels": torch.tensor([[1]], device=self.device),
            }
        ]

        multimask_output = False
        return self.model, (example_input, multimask_output)

    def train(self):
        error_msg = """
            As of May 17, 2023, some base ViT checkpoints are available for SAM,
            but obtaining the dataset requires a research license. A training loop
            over random data would be easy to add; if that is of interest, please
            let @msaroufim know.
            https://github.com/facebookresearch/segment-anything#dataset
        """
        raise NotImplementedError(error_msg)

    def eval(self):
        # To test bfloat16, uncomment the line below:
        # predictor = SamPredictor(self.model.to(dtype=torch.bfloat16))

        predictor = SamPredictor(self.model)

        predictor.set_image(self.image)

        input_point = np.array([[500, 375]])
        input_label = np.array([1])
        masks, scores, logits = predictor.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )
        return (masks,)
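
As an aside, get_module() above follows the torchbench convention of returning the module together with a tuple of positional example inputs. A minimal sketch of exercising it directly (assuming a CUDA device and that install.py has already fetched the checkpoint and image):

# Hypothetical standalone usage of the benchmark model defined above.
from torchbenchmark.models.sam_fast import Model

m = Model(test="eval", device="cuda")
module, example_inputs = m.get_module()
# SAM's forward takes (batched_input, multimask_output), matching the tuple above.
out = module(*example_inputs)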
23 changes: 23 additions & 0 deletions torchbenchmark/models/sam_fast/install.py
@@ -0,0 +1,23 @@
import os
import subprocess
import sys

def pip_install_requirements():
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])

def download_checkpoint():
    subprocess.check_call(['wget', '-P', '.data', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'])

def download_data():
    subprocess.check_call(['wget', '-P', '.data', 'https://github.com/facebookresearch/segment-anything/raw/main/notebooks/images/truck.jpg'])

if __name__ == '__main__':
    pip_install_requirements()

    # Create .data folder in the script's directory
    data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.data')
    os.makedirs(data_folder, exist_ok=True)

    # Download checkpoint and data files to the .data folder
    download_checkpoint()
    download_data()
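
Note that the two download helpers above shell out to wget and write into a .data path relative to the current working directory, so install.py is expected to be run from the model's own folder. A minimal stdlib-only alternative in case wget is unavailable (a hypothetical helper, not part of this commit):

import os
import urllib.request

def download_to_data(url, data_folder):
    # Download url into data_folder, skipping files that already exist.
    os.makedirs(data_folder, exist_ok=True)
    filename = os.path.join(data_folder, url.split("/")[-1])
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)
    return filename

# Example with the same checkpoint URL used above:
# download_to_data(
#     "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
#     os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data"),
# )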
12 changes: 12 additions & 0 deletions torchbenchmark/models/sam_fast/metadata.yaml
@@ -0,0 +1,12 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
  - device: cpu
  - device: cuda
    test: example
3 changes: 3 additions & 0 deletions torchbenchmark/models/sam_fast/requirements.txt
@@ -0,0 +1,3 @@
git+https://github.com/pytorch-labs/segment-anything-fast.git
opencv-python
pycocotools
1 change: 1 addition & 0 deletions torchbenchmark/util/env_check.py
@@ -25,6 +25,7 @@
"pytorch_CycleGAN_and_pix2pix",
"pytorch_unet",
"sam",
"sam_fast",
"Super_SloMo",
"vgg16",
"mtml_ctr_instagram_model",
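
The env_check.py change above adds sam_fast next to the existing sam entry in the same model list. Beyond the usual torchbench harness (run.py), a minimal Python-level smoke test of the new benchmark's eval path, assuming a CUDA device and downloaded checkpoint and image, might look like:

# Hypothetical smoke test of the eval path added in this PR.
from torchbenchmark.models.sam_fast import Model

m = Model(test="eval", device="cuda")
(masks,) = m.eval()
print(masks.shape)  # (num_masks, H, W) boolean masks from SamPredictor.predict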
