From 612b3c8c2b81b8da3506aeeccc9e1162e09ca904 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Wed, 19 Jun 2024 10:21:44 -0700
Subject: [PATCH] Add XLA support to `moco` benchmark. (#2292)

Summary:
This PR tweaks the `moco` benchmark so that it also runs on XLA devices.

Previously, `moco` hardcoded the CUDA device in two ways:

- Initializing the `ProcessGroup` with the `nccl` backend only
- Explicitly moving intermediate tensors to `cuda`

In order to add XLA support, this PR:

- Also checks for `xla*` devices and, if detected, initializes the `ProcessGroup` with the `xla` backend
- Moves intermediate tensors to the appropriate device

cc lezcano

Pull Request resolved: https://github.com/pytorch/benchmark/pull/2292

Reviewed By: aaronenyeshi

Differential Revision: D58787062

Pulled By: xuzhao9

fbshipit-source-id: 545d2c71296cc3e80958a51b3e48b335a1a72b59
---
 torchbenchmark/models/moco/__init__.py     | 33 ++++++++++++++--------
 torchbenchmark/models/moco/moco/builder.py |  4 +--
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/torchbenchmark/models/moco/__init__.py b/torchbenchmark/models/moco/__init__.py
index d6290a895e..1faca2cf48 100644
--- a/torchbenchmark/models/moco/__init__.py
+++ b/torchbenchmark/models/moco/__init__.py
@@ -56,18 +56,29 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
                 "distributed": True,
             }
         )
-        try:
-            dist.init_process_group(
-                backend="nccl",
-                init_method="tcp://localhost:10001",
-                world_size=1,
-                rank=0,
-            )
-        except RuntimeError:
-            pass  # already initialized?
         if device == "cpu":
             raise NotImplementedError("DistributedDataParallel/allgather requires cuda")
+        elif device == "cuda":
+            try:
+                dist.init_process_group(
+                    backend="nccl",
+                    init_method="tcp://localhost:10001",
+                    world_size=1,
+                    rank=0,
+                )
+            except RuntimeError:
+                pass  # already initialized?
+        elif device == "xla":
+            import torch_xla.distributed.xla_backend
+
+            try:
+                dist.init_process_group(backend="xla", init_method="xla://")
+            except RuntimeError:
+                pass  # already initialized?
+        else:
+            raise NotImplementedError(f"{device} not supported")
+
         self.model = MoCo(
             models.__dict__[self.opt.arch],
@@ -102,8 +113,8 @@ def collate_train_fn(data):
             range(2), collate_fn=collate_train_fn
         )
         for i, (images, _) in enumerate(self.example_inputs):
-            images[0] = images[0].cuda(device=0, non_blocking=True)
-            images[1] = images[1].cuda(device=0, non_blocking=True)
+            images[0] = images[0].to(device, non_blocking=True)
+            images[1] = images[1].to(device, non_blocking=True)
 
     def get_module(self):
         """Recommended
diff --git a/torchbenchmark/models/moco/moco/builder.py b/torchbenchmark/models/moco/moco/builder.py
index 295e22a7e7..a0a0329134 100644
--- a/torchbenchmark/models/moco/moco/builder.py
+++ b/torchbenchmark/models/moco/moco/builder.py
@@ -79,7 +79,7 @@ def _batch_shuffle_ddp(self, x):
         num_gpus = batch_size_all // batch_size_this
 
         # random shuffle index
-        idx_shuffle = torch.randperm(batch_size_all).cuda()
+        idx_shuffle = torch.randperm(batch_size_all, device=x_gather.device)
 
         # broadcast to all gpus
         torch.distributed.broadcast(idx_shuffle, src=0)
@@ -152,7 +152,7 @@ def forward(self, im_q, im_k):
         logits /= self.T
 
         # labels: positive key indicators
-        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
+        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
 
         # dequeue and enqueue
         self._dequeue_and_enqueue(k)
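
For reference, below is a minimal, self-contained sketch of the device-dispatching `ProcessGroup` initialization that the Summary describes and the first hunk implements. The helper name `init_process_group_for` is hypothetical (it does not exist in the benchmark); the single-process settings (`world_size=1`, `rank=0`), the `xla` backend registration via the `torch_xla.distributed.xla_backend` import, and the `xla://` init method are taken directly from the patch.

```python
# Hypothetical standalone sketch of the patch's per-device ProcessGroup setup.
# Assumes a single-process run (world_size=1, rank=0), as in the benchmark.
import torch.distributed as dist


def init_process_group_for(device: str) -> None:
    if device == "cuda":
        try:
            dist.init_process_group(
                backend="nccl",
                init_method="tcp://localhost:10001",
                world_size=1,
                rank=0,
            )
        except RuntimeError:
            pass  # already initialized
    elif device.startswith("xla"):
        # Importing torch_xla's backend module registers the "xla" backend
        # with torch.distributed before it is requested below.
        import torch_xla.distributed.xla_backend  # noqa: F401

        try:
            dist.init_process_group(backend="xla", init_method="xla://")
        except RuntimeError:
            pass  # already initialized
    else:
        # The benchmark raises NotImplementedError for unsupported devices
        # (including cpu, which it rejects before this dispatch).
        raise NotImplementedError(f"{device} not supported")
```

The tensor-side hunks follow the same idea: `.cuda()` calls are replaced with `.to(device, ...)` or `device=...` keyword arguments, so the same code path produces tensors on whichever backend (CUDA or XLA) the benchmark was asked to run on.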