From fa41cb26a39c112f72f55b3ccacc2bbc502e2649 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Wed, 15 Nov 2023 13:48:00 -0800
Subject: [PATCH] Register grid buffers for Super_SloMo. (#2038)

Summary:
This PR registers both `gridX` and `gridY` as buffers for the `backWarp` module, which belongs to the `Super_SloMo` benchmark.

Without this PR, those buffers wouldn't be moved to a device on `model.to(device)`. This causes problems, specifically with XLA.

Pull Request resolved: https://github.com/pytorch/benchmark/pull/2038

Test Plan:
```
PJRT_DEVICE=CPU python benchmarks/dynamo/torchbench.py --performance --trace-on-xla --backend openxla --inference --only Super_SloMo
```

Reviewed By: aaronenyeshi

Differential Revision: D51367773

Pulled By: xuzhao9

fbshipit-source-id: bf05b5320c9dd4f46ffa466611bdfee0c20ee90f
---
 torchbenchmark/models/Super_SloMo/slomo_model.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchbenchmark/models/Super_SloMo/slomo_model.py b/torchbenchmark/models/Super_SloMo/slomo_model.py
index a06e51eb74..60160c18ef 100644
--- a/torchbenchmark/models/Super_SloMo/slomo_model.py
+++ b/torchbenchmark/models/Super_SloMo/slomo_model.py
@@ -249,8 +249,10 @@ def __init__(self, W, H, device):
 
         # Use torch.meshgrid instead of np.meshgrid to imrpove performance
         # https://github.com/avinashpaliwal/Super-SloMo/pull/111
-        self.gridX, self.gridY = torch.meshgrid(torch.arange(W, requires_grad=False, device=device),
-                                                torch.arange(H, requires_grad=False, device=device), indexing='xy')
+        gridX, gridY = torch.meshgrid(torch.arange(W, requires_grad=False, device=device),
+                                      torch.arange(H, requires_grad=False, device=device), indexing='xy')
+        self.register_buffer("gridX", gridX)
+        self.register_buffer("gridY", gridY)
 
     def forward(self, img, flow):
         """
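
Note (not part of the patch): `register_buffer` makes the grid tensors part of the module's state, so they follow `model.to(device)` and show up in `state_dict()`, whereas plain tensor attributes do not. A minimal standalone sketch of that pattern; the module name and sizes below are illustrative, not taken from the benchmark:

```python
import torch
import torch.nn as nn


class GridWarp(nn.Module):
    """Toy module mirroring the backWarp change: grids registered as buffers."""

    def __init__(self, W, H):
        super().__init__()
        gridX, gridY = torch.meshgrid(
            torch.arange(W, requires_grad=False),
            torch.arange(H, requires_grad=False),
            indexing="xy",
        )
        # As plain attributes (self.gridX = gridX) these tensors would be
        # invisible to nn.Module and stay on their original device after
        # .to(device) / .cuda(); registering them as buffers fixes that.
        self.register_buffer("gridX", gridX)
        self.register_buffer("gridY", gridY)


warp = GridWarp(W=8, H=4)
print(warp.gridX.device)              # cpu
print("gridX" in warp.state_dict())   # True: buffers are serialized
# On a CUDA or XLA device the buffers move with the module, e.g.:
# warp = warp.to("cuda"); warp.gridX.device -> cuda:0
```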