[TP][2D][EZ] Fix Error in FSDP 2D test (pytorch#107975)
As title, the TP dimension should be the second dim, so we need to pass tp_degree to the second rather than the first dim of the mesh tensor.

Pull Request resolved: pytorch#107975
Approved by: https://github.com/wz337, https://github.com/awgu
fduwjj authored and pytorchmergebot committed Aug 25, 2023
1 parent 08e49fe commit 7ef13b1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
@@ -96,7 +96,7 @@ def init_model(
 # 2-D mesh is [dp, tp]
 twod_mesh = DeviceMesh(
     device_type="cuda",
-    mesh=torch.arange(0, world_size).view(model_parallel_size, -1),
+    mesh=torch.arange(0, world_size).view(-1, model_parallel_size),
 )
 
 fsdp_pg = twod_mesh.get_dim_groups()[0]
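
To illustrate why the order of the `view` arguments matters, here is a minimal sketch in plain Python that emulates `torch.arange(0, world_size).view(...)` with nested lists. The helpers `view_2d` and `dim_groups` are hypothetical names introduced only for this illustration and are not part of PyTorch; the values `world_size = 8` and `model_parallel_size = 2` are assumed. It relies on the convention that the group for a mesh dimension contains the ranks that differ only along that dimension.

```python
def view_2d(world_size, rows, cols):
    """Hypothetical helper: emulate torch.arange(0, world_size).view(rows, cols).

    A value of -1 for one argument is inferred from the other, as in Tensor.view.
    """
    if rows == -1:
        rows = world_size // cols
    if cols == -1:
        cols = world_size // rows
    assert rows * cols == world_size, "shape must cover every rank exactly once"
    return [list(range(r * cols, (r + 1) * cols)) for r in range(rows)]


def dim_groups(mesh, dim):
    """Hypothetical helper: rank groups formed along mesh dimension `dim`.

    Ranks in one group differ only in their coordinate along `dim`.
    """
    if dim == 1:
        return [list(row) for row in mesh]   # each row varies along dim 1
    return [list(col) for col in zip(*mesh)]  # each column varies along dim 0


world_size = 8           # assumed value for illustration
model_parallel_size = 2  # assumed TP degree for illustration

# Before the fix: TP degree is passed as the first dim, giving shape [2, 4].
old_mesh = view_2d(world_size, model_parallel_size, -1)
# After the fix: TP degree is the second dim, giving shape [4, 2] == [dp, tp].
new_mesh = view_2d(world_size, -1, model_parallel_size)

old_tp_groups = dim_groups(old_mesh, 1)  # groups along the second dim
new_tp_groups = dim_groups(new_mesh, 1)
new_dp_groups = dim_groups(new_mesh, 0)  # groups along the first dim

print("old mesh:", old_mesh)
print("new mesh:", new_mesh)
print("old groups along dim 1:", old_tp_groups)
print("new groups along dim 1:", new_tp_groups)
print("new groups along dim 0:", new_dp_groups)
```

With the old shape, the groups along the second dimension have `world_size // model_parallel_size` ranks each, which does not match the intended TP degree. With the new shape, each group along the second dimension has exactly `model_parallel_size` ranks, consistent with the `[dp, tp]` comment in the test.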
