From fcaf1103d297a57207921028bcffa4a5b9a07248 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Thu, 13 Jul 2023 01:58:13 +0000
Subject: [PATCH] FIX SAM for bfloat16

---
 torchbenchmark/models/sam/__init__.py       | 3 +--
 torchbenchmark/models/sam/mask_decoder.py   | 1 +
 torchbenchmark/models/sam/predictor.py      | 4 ++--
 torchbenchmark/models/sam/prompt_encoder.py | 2 ++
 4 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/torchbenchmark/models/sam/__init__.py b/torchbenchmark/models/sam/__init__.py
index d11881cf41..214f641e7a 100644
--- a/torchbenchmark/models/sam/__init__.py
+++ b/torchbenchmark/models/sam/__init__.py
@@ -43,7 +43,6 @@ def get_module(self):
         ]
         multimask_output = False
 
-
         return self.model, (example_input, multimask_output)
 
     def train(self):
@@ -57,7 +56,7 @@ def train(self):
         return NotImplementedError(error_msg)
 
     def eval(self):
-        predictor = SamPredictor(self.model)
+        predictor = SamPredictor(self.model.to(dtype=torch.bfloat16))
 
         predictor.set_image(self.image)
 
diff --git a/torchbenchmark/models/sam/mask_decoder.py b/torchbenchmark/models/sam/mask_decoder.py
index 5d2fdb03d5..10471bd190 100644
--- a/torchbenchmark/models/sam/mask_decoder.py
+++ b/torchbenchmark/models/sam/mask_decoder.py
@@ -129,6 +129,7 @@ def predict_masks(
         b, c, h, w = src.shape
 
         # Run the transformer
+        tokens = tokens.to(src.dtype)
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
         mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
diff --git a/torchbenchmark/models/sam/predictor.py b/torchbenchmark/models/sam/predictor.py
index 24920c046b..3d7aee74fd 100644
--- a/torchbenchmark/models/sam/predictor.py
+++ b/torchbenchmark/models/sam/predictor.py
@@ -160,8 +160,8 @@ def predict(
         )
 
         masks_np = masks[0].detach().cpu().numpy()
-        iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
-        low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
+        iou_predictions_np = iou_predictions[0].to(torch.float32).detach().cpu().numpy()
+        low_res_masks_np = low_res_masks[0].to(torch.float32).detach().cpu().numpy()
         return masks_np, iou_predictions_np, low_res_masks_np
 
     @torch.no_grad()
diff --git a/torchbenchmark/models/sam/prompt_encoder.py b/torchbenchmark/models/sam/prompt_encoder.py
index c3143f4f8e..70c9a8267a 100644
--- a/torchbenchmark/models/sam/prompt_encoder.py
+++ b/torchbenchmark/models/sam/prompt_encoder.py
@@ -186,6 +186,8 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
         """Positionally encode points that are normalized to [0,1]."""
         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
         coords = 2 * coords - 1
+        coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
+
         coords = coords @ self.positional_encoding_gaussian_matrix
         coords = 2 * np.pi * coords
         # outputs d_1 x ... x d_n x C shape