diff --git a/torchbenchmark/models/sam_fast/__init__.py b/torchbenchmark/models/sam_fast/__init__.py
new file mode 100644
index 0000000000..f4a8aeadfd
--- /dev/null
+++ b/torchbenchmark/models/sam_fast/__init__.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from ...util.model import BenchmarkModel
+from segment_anything_fast.build_sam import sam_model_fast_registry
+from segment_anything_fast.predictor import SamPredictor
+import numpy as np
+import cv2
+from torchbenchmark.tasks import COMPUTER_VISION
+import torch
+import os
+
+
+class Model(BenchmarkModel):
+    task = COMPUTER_VISION.SEGMENTATION
+    DEFAULT_EVAL_BSIZE = 32
+
+    def __init__(self, test, device, batch_size=1, extra_args=[]):
+        super().__init__(
+            test=test, device=device, batch_size=batch_size, extra_args=extra_args
+        )
+
+        # Checkpoint options are listed at https://github.com/facebookresearch/segment-anything#model-checkpoints
+        data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data")
+        sam_checkpoint = os.path.join(data_folder, "sam_vit_h_4b8939.pth")
+        model_type = "vit_h"
+
+        self.model = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
+        self.model.to(device=device)
+
+        image_path = os.path.join(data_folder, "truck.jpg")
+        self.image = cv2.imread(image_path)
+        self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
+        self.sample_image = torch.randn((3, 256, 256)).to(device)
+
+    def get_module(self):
+        example_input = [
+            {
+                "image": self.sample_image,
+                "original_size": (256, 256),
+                "point_coords": torch.tensor([[[1, 1]]], device=self.device),
+                "point_labels": torch.tensor([[1]], device=self.device),
+            }
+        ]
+
+        multimask_output = False
+        return self.model, (example_input, multimask_output)
+
+    def train(self):
+        error_msg = """
+        As of May 17, 2023
+        Some base ViT checkpoints are available for SAM, but getting the dataset
+        requires a research license. It's easy to make up a training loop on random
+        data; if that's interesting, please let @msaroufim know.
+        https://github.com/facebookresearch/segment-anything#dataset
+        """
+        raise NotImplementedError(error_msg)
+
+    def eval(self):
+        # To test bfloat16, uncomment the line below
+        # predictor = SamPredictor(self.model.to(dtype=torch.bfloat16))
+
+        predictor = SamPredictor(self.model)
+
+        predictor.set_image(self.image)
+
+        input_point = np.array([[500, 375]])
+        input_label = np.array([1])
+        masks, scores, logits = predictor.predict(
+            point_coords=input_point, point_labels=input_label, multimask_output=True
+        )
+        return (masks,)
diff --git a/torchbenchmark/models/sam_fast/install.py b/torchbenchmark/models/sam_fast/install.py
new file mode 100644
index 0000000000..0646de166e
--- /dev/null
+++ b/torchbenchmark/models/sam_fast/install.py
@@ -0,0 +1,23 @@
+import os
+import subprocess
+import sys
+
+def pip_install_requirements():
+    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
+
+def download_checkpoint():
+    subprocess.check_call(['wget', '-P', '.data', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'])
+
+def download_data():
+    subprocess.check_call(['wget', '-P', '.data', 'https://github.com/facebookresearch/segment-anything/raw/main/notebooks/images/truck.jpg'])
+
+if __name__ == '__main__':
+    pip_install_requirements()
+
+    # Create the .data folder in the script's directory
+    data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.data')
+    os.makedirs(data_folder, exist_ok=True)
+
+    # Download the checkpoint and sample image into the .data folder
+    download_checkpoint()
+    download_data()
diff --git a/torchbenchmark/models/sam_fast/metadata.yaml b/torchbenchmark/models/sam_fast/metadata.yaml
new file mode 100644
index 0000000000..3968b5ea5e
--- /dev/null
+++ b/torchbenchmark/models/sam_fast/metadata.yaml
@@ -0,0 +1,12 @@
+devices:
+  NVIDIA A100-SXM4-40GB:
+    eval_batch_size: 32
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+train_benchmark: false
+train_deterministic: false
+not_implemented:
+- device: cpu
+- device: cuda
+  test: example
diff --git a/torchbenchmark/models/sam_fast/requirements.txt b/torchbenchmark/models/sam_fast/requirements.txt
new file mode 100644
index 0000000000..7e8413141a
--- /dev/null
+++ b/torchbenchmark/models/sam_fast/requirements.txt
@@ -0,0 +1,3 @@
+git+https://github.com/pytorch-labs/segment-anything-fast.git
+opencv-python
+pycocotools
diff --git a/torchbenchmark/util/env_check.py b/torchbenchmark/util/env_check.py
index 956fdb4f98..649f07f899 100644
--- a/torchbenchmark/util/env_check.py
+++ b/torchbenchmark/util/env_check.py
@@ -25,6 +25,7 @@
     "pytorch_CycleGAN_and_pix2pix",
     "pytorch_unet",
     "sam",
+    "sam_fast",
     "Super_SloMo",
     "vgg16",
     "mtml_ctr_instagram_model",