-
Notifications
You must be signed in to change notification settings - Fork 278
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #2182 Took the same benchmark used by torchbench/sam, just using the https://github.com/pytorch-labs/segment-anything-fast/tree/main/segment_anything_fast registry instead which invokes dynamo and inductor. imported-using-ghimport Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D54589585 Pulled By: jamesjwu fbshipit-source-id: 379e7775db9155a70976ae95144e6e0edae5e168
- Loading branch information
1 parent
611d446
commit d6015d4
Showing
5 changed files
with
112 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# This software may be used and distributed according to the terms of the GNU General Public License version 3. | ||
|
||
from ...util.model import BenchmarkModel | ||
from segment_anything_fast.build_sam import sam_model_fast_registry | ||
from segment_anything_fast.predictor import SamPredictor | ||
import numpy as np | ||
import cv2 | ||
from torchbenchmark.tasks import COMPUTER_VISION | ||
import torch | ||
import os | ||
|
||
|
||
class Model(BenchmarkModel): | ||
task = COMPUTER_VISION.SEGMENTATION | ||
DEFAULT_EVAL_BSIZE = 32 | ||
|
||
def __init__(self, test, device, batch_size=1, extra_args=[]): | ||
super().__init__( | ||
test=test, device=device, batch_size=batch_size, extra_args=extra_args | ||
) | ||
|
||
# Checkpoint options are here https://github.com/facebookresearch/segment-anything#model-checkpoints | ||
data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data") | ||
sam_checkpoint = os.path.join(data_folder, "sam_vit_h_4b8939.pth") | ||
model_type = "vit_h" | ||
|
||
self.model = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint) | ||
self.model.to(device=device) | ||
data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data") | ||
|
||
image_path = os.path.join(data_folder, "truck.jpg") | ||
self.image = cv2.imread(image_path) | ||
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) | ||
self.sample_image = torch.randn((3, 256, 256)).to(device) | ||
|
||
def get_module(self): | ||
example_input = [ | ||
{ | ||
"image": self.sample_image, | ||
"original_size": (256, 256), | ||
"point_coords": torch.tensor([[[1,1]]], device=self.device), | ||
"point_labels": torch.tensor([[1]], device=self.device), | ||
} | ||
] | ||
|
||
multimask_output = False | ||
return self.model, (example_input, multimask_output) | ||
|
||
def train(self): | ||
error_msg = """ | ||
As of May 17, 2023 | ||
Some base VIT checkpoints are available for SAM but getting the dataset | ||
requires a research license. It's easy to make up a training loop on random | ||
data and if that's interesting please let @msaroufim know | ||
https://github.com/facebookresearch/segment-anything#dataset | ||
""" | ||
return NotImplementedError(error_msg) | ||
|
||
def eval(self): | ||
# To test for bfloat16 uncomment the below line | ||
# predictor = SamPredictor(self.model.to(dtype=torch.bfloat16)) | ||
|
||
predictor = SamPredictor(self.model) | ||
|
||
predictor.set_image(self.image) | ||
|
||
input_point = np.array([[500, 375]]) | ||
input_label = np.array([1]) | ||
masks, scores, logits = predictor.predict( | ||
point_coords=input_point, point_labels=input_label, multimask_output=True | ||
) | ||
return (masks,) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os | ||
import subprocess | ||
import sys | ||
|
||
def pip_install_requirements(): | ||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt']) | ||
|
||
def download_checkpoint(): | ||
subprocess.check_call(['wget', '-P', '.data', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth']) | ||
|
||
def download_data(): | ||
subprocess.check_call(['wget', '-P', '.data', 'https://github.com/facebookresearch/segment-anything/raw/main/notebooks/images/truck.jpg']) | ||
|
||
if __name__ == '__main__': | ||
pip_install_requirements() | ||
|
||
# Create .data folder in the script's directory | ||
data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.data') | ||
os.makedirs(data_folder, exist_ok=True) | ||
|
||
# Download checkpoint and data files to the .data folder | ||
download_checkpoint() | ||
download_data() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
devices: | ||
NVIDIA A100-SXM4-40GB: | ||
eval_batch_size: 32 | ||
eval_benchmark: false | ||
eval_deterministic: false | ||
eval_nograd: true | ||
train_benchmark: false | ||
train_deterministic: false | ||
not_implemented: | ||
- device: cpu | ||
- device: cuda | ||
test: example |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
git+https://github.com/pytorch-labs/segment-anything-fast.git | ||
opencv-python | ||
pycocotools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters