From d9b02f070be486fa5fa51d5f2543625d46158501 Mon Sep 17 00:00:00 2001 From: Chetan Kumar Verma <39086835+ckvermaAI@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:21:04 +0530 Subject: [PATCH] Change the device_id for FSDP plugin (#1086) --- optimum/habana/accelerate/accelerator.py | 2 +- optimum/habana/accelerate/utils/dataclasses.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index d826908ee5..f324aebd6a 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -412,7 +412,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "param_init_fn": fsdp_plugin.param_init_fn, "ignored_modules": fsdp_plugin.ignored_modules, "limit_all_gathers": fsdp_plugin.limit_all_gathers, - "device_id": torch.device("hpu"), + "device_id": torch.device("hpu", torch.hpu.current_device()), } model = FSDP(model, **kwargs) if fsdp_plugin.activation_checkpointing: diff --git a/optimum/habana/accelerate/utils/dataclasses.py b/optimum/habana/accelerate/utils/dataclasses.py index fce2c06c8c..1db6980ee7 100644 --- a/optimum/habana/accelerate/utils/dataclasses.py +++ b/optimum/habana/accelerate/utils/dataclasses.py @@ -142,7 +142,7 @@ def __post_init__(self): self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1 if self.sync_module_states: - device = torch.device("hpu") + device = torch.device("hpu", torch.hpu.current_device()) self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)