Register grid buffers for Super_SloMo. (#2038)
Summary:
This PR registers both `gridX` and `gridY` as buffers for the `backWarp` module, which belongs to the `Super_SloMo` benchmark.

Without this PR, those grid tensors are plain attributes and are not moved to the target device by `model.to(device)`, which causes problems, specifically with XLA.
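
For context, a minimal sketch (not part of the commit; the class names `PlainAttr` and `WithBuffers` are made up for illustration) of why this matters: a plain tensor attribute is invisible to `module.to(device)`, whereas a registered buffer moves along with the module.

```python
import torch
import torch.nn as nn


class PlainAttr(nn.Module):
    """Grids stored as ordinary attributes: .to(device) does NOT move them."""
    def __init__(self, W, H):
        super().__init__()
        self.gridX, self.gridY = torch.meshgrid(
            torch.arange(W), torch.arange(H), indexing="xy")


class WithBuffers(nn.Module):
    """Grids registered as buffers: they follow the module across devices."""
    def __init__(self, W, H):
        super().__init__()
        gridX, gridY = torch.meshgrid(
            torch.arange(W), torch.arange(H), indexing="xy")
        self.register_buffer("gridX", gridX)
        self.register_buffer("gridY", gridY)


if torch.cuda.is_available():
    dev = torch.device("cuda")
    print(PlainAttr(4, 3).to(dev).gridX.device)    # cpu (left behind)
    print(WithBuffers(4, 3).to(dev).gridX.device)  # cuda:0 (moved with the module)
```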

Pull Request resolved: #2038

Test Plan:
```
PJRT_DEVICE=CPU python benchmarks/dynamo/torchbench.py --performance --trace-on-xla --backend openxla --inference --only Super_SloMo
```

Reviewed By: aaronenyeshi

Differential Revision: D51367773

Pulled By: xuzhao9

fbshipit-source-id: bf05b5320c9dd4f46ffa466611bdfee0c20ee90f
ysiraichi authored and facebook-github-bot committed Nov 15, 2023
1 parent c7cb5aa commit fa41cb2
1 changed file with 4 additions and 2 deletions: `torchbenchmark/models/Super_SloMo/slomo_model.py`

```diff
@@ -249,8 +249,10 @@ def __init__(self, W, H, device):
 
         # Use torch.meshgrid instead of np.meshgrid to imrpove performance
         # https://github.com/avinashpaliwal/Super-SloMo/pull/111
-        self.gridX, self.gridY = torch.meshgrid(torch.arange(W, requires_grad=False, device=device),
-                                                torch.arange(H, requires_grad=False, device=device), indexing='xy')
+        gridX, gridY = torch.meshgrid(torch.arange(W, requires_grad=False, device=device),
+                                      torch.arange(H, requires_grad=False, device=device), indexing='xy')
+        self.register_buffer("gridX", gridX)
+        self.register_buffer("gridY", gridY)
 
     def forward(self, img, flow):
         """
```
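
As a quick sanity check (a hedged sketch, assuming a local torchbenchmark checkout is importable and that `backWarp` is constructed with the `(W, H, device)` signature shown in the hunk above), the registered buffers should now be tracked by the module:

```python
import torch
from torchbenchmark.models.Super_SloMo.slomo_model import backWarp

warp = backWarp(64, 64, torch.device("cpu"))
# With the buffers registered, both grids appear in named_buffers()/state_dict(),
# so .to(device) and serialization handle them together with the module.
print([name for name, _ in warp.named_buffers()])  # expected: ['gridX', 'gridY']
print(sorted(warp.state_dict().keys()))            # expected: ['gridX', 'gridY']
```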
