Skip to content

Commit

Permalink
FIX SAM for bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Jul 13, 2023
1 parent 2ea018e commit fcaf110
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
3 changes: 1 addition & 2 deletions torchbenchmark/models/sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def get_module(self):
]

multimask_output = False

return self.model, (example_input, multimask_output)

def train(self):
Expand All @@ -57,7 +56,7 @@ def train(self):
return NotImplementedError(error_msg)

def eval(self):
- predictor = SamPredictor(self.model)
+ predictor = SamPredictor(self.model.to(dtype=torch.bfloat16))

predictor.set_image(self.image)

Expand Down
1 change: 1 addition & 0 deletions torchbenchmark/models/sam/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def predict_masks(
b, c, h, w = src.shape

# Run the transformer
tokens = tokens.to(src.dtype)
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/models/sam/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def predict(
)

masks_np = masks[0].detach().cpu().numpy()
- iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
- low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
+ iou_predictions_np = iou_predictions[0].to(torch.float32).detach().cpu().numpy()
+ low_res_masks_np = low_res_masks[0].to(torch.float32).detach().cpu().numpy()
return masks_np, iou_predictions_np, low_res_masks_np

@torch.no_grad()
Expand Down
2 changes: 2 additions & 0 deletions torchbenchmark/models/sam/prompt_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)

coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
Expand Down

0 comments on commit fcaf110

Please sign in to comment.