Change the device_id for FSDP plugin (huggingface#1086)
ckvermaAI authored Aug 1, 2024
1 parent da52af7 commit d9b02f0
Showing 2 changed files with 2 additions and 2 deletions.
optimum/habana/accelerate/accelerator.py (2 changes: 1 addition & 1 deletion)

@@ -412,7 +412,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
"param_init_fn": fsdp_plugin.param_init_fn,
"ignored_modules": fsdp_plugin.ignored_modules,
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
"device_id": torch.device("hpu"),
"device_id": torch.device("hpu", torch.hpu.current_device()),
}
model = FSDP(model, **kwargs)
if fsdp_plugin.activation_checkpointing:
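The one-line change above matters because FSDP's device_id controls where shards and buffers are placed, and a bare torch.device("hpu") carries no device index. A minimal sketch of the distinction, assuming a Gaudi environment where importing habana_frameworks.torch.hpu exposes the torch.hpu namespace:

import torch
import habana_frameworks.torch.hpu  # registers the HPU backend and the torch.hpu namespace

# Bare device type: index is None, so downstream code has to pick a device itself.
bare = torch.device("hpu")
print(bare.index)  # None

# Indexed device: explicitly pins placement to this process's current HPU,
# which is what the patched line now passes as device_id.
pinned = torch.device("hpu", torch.hpu.current_device())
print(pinned.index)  # e.g. 0..7 on an eight-card Gaudi node

With one process per card, each rank then materializes its shard on its own local device instead of relying on a backend default.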
optimum/habana/accelerate/utils/dataclasses.py (2 changes: 1 addition & 1 deletion)

@@ -142,7 +142,7 @@ def __post_init__(self):
 self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
 
 if self.sync_module_states:
-    device = torch.device("hpu")
+    device = torch.device("hpu", torch.hpu.current_device())
     self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
 

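The dataclasses.py change applies the same fix to the sync_module_states path, where param_init_fn is the function FSDP calls to materialize meta-device parameters. A hedged sketch of that pattern (the index is hardcoded to 0 here purely for illustration; the patch queries torch.hpu.current_device()):

import torch
import torch.nn as nn

# Modules built under the meta device have shapes but no real storage yet.
with torch.device("meta"):
    module = nn.Linear(16, 16)

# to_empty() allocates uninitialized storage on the given device; passing an
# indexed device keeps each rank's parameters on its own local accelerator.
device = torch.device("hpu", 0)  # illustrative index; the patch uses the current device
param_init_fn = lambda m: m.to_empty(device=device, recurse=False)

Calling param_init_fn(module) would then leave the module's weights on that specific HPU, ready for FSDP to broadcast the authoritative values from rank 0 when sync_module_states is enabled.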
