Commit 7274bd9

Update

[ghstack-poisoned]

wconstab committed May 17, 2024
2 parents: b889f5b + 501a2fb
Showing 1 changed file with 0 additions and 5 deletions.
train.py

@@ -20,7 +20,6 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
-from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
@@ -240,10 +239,6 @@ def loss_fn(pred, labels):
         f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
     )

-    if isinstance(model, FSDPModule) and parallel_dims.pp_enabled:
-        # reshard now to counteract an issue where FSDP's states got advanced during PP stage shape inference
-        model.reshard()
-
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
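For context, the deleted block called FSDP2's reshard() on the model when pipeline parallelism was enabled. Below is a minimal sketch of that pattern, not torchtitan's code: it assumes PyTorch's experimental fully_shard (FSDP2) API and an already-initialized process group (e.g. via torchrun); the toy model and sizes are placeholders.

# Minimal sketch, assuming torch with the experimental FSDP2 ("fully_shard")
# API and a default process group already initialized via torchrun.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp.fully_shard import FSDPModule

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
fully_shard(model)  # applied in place; model now also passes isinstance(model, FSDPModule)

# Pipeline-stage shape inference runs a forward pass, which all-gathers
# (unshards) FSDP-managed parameters. The removed workaround called
# reshard() afterwards to free the unsharded parameters again before
# building the optimizer.
if isinstance(model, FSDPModule):
    model.reshard()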