Update on "[PoC][MoE & EP] model code and various parallelisms"
The expert-choice MoE implementation is mostly from torchtune: pytorch/torchtune#1902
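For background, a minimal, self-contained sketch of expert-choice routing with sigmoid scoring (the scheme this PoC currently uses). The class name, shapes, and hyperparameters are illustrative only, not the torchtune/torchtitan implementation:

```python
# Illustrative sketch of expert-choice routing; not the actual torchtune/torchtitan code.
import torch
import torch.nn as nn


class ExpertChoiceRouter(nn.Module):
    """Each expert picks its own top-k tokens, so expert load is balanced by construction."""

    def __init__(self, dim: int, num_experts: int, tokens_per_expert: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)  # router gate (plain linear scorer)
        self.tokens_per_expert = tokens_per_expert

    def forward(self, x: torch.Tensor):
        # x: (num_tokens, dim)
        scores = torch.sigmoid(self.gate(x))                  # (num_tokens, num_experts)
        # expert choice: top-k over the *token* dimension, taken per expert
        top_scores, token_idx = torch.topk(
            scores.transpose(0, 1), k=self.tokens_per_expert, dim=-1
        )                                                      # both (num_experts, tokens_per_expert)
        routed_tokens = x[token_idx]                           # (num_experts, tokens_per_expert, dim)
        # each expert processes its routed_tokens, weighted by top_scores;
        # results are scattered back to the original token positions (omitted here)
        return routed_tokens, top_scores, token_idx
```

Token-choice routing (listed under "Haven't worked on" below) would instead take the top-k over experts for each token.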

Temporary changes to unblock exploration
- [pytorch] comment out the check at https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/api.py#L66
- [torchtitan] for dp2ep, turn off the optimizers' `foreach` option and disable `clip_grad_norm_`, since not all parameters are DTensors on the same mesh (e.g. in dp2ep, `moe.router.gate` is a replicated plain torch.Tensor); see the sketch after this list
- [torchtitan] for dp2ep, comment out `apply_fsdp`, which leaves the non-expert parameters replicated
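One possible longer-term workaround (purely hypothetical, not part of this commit) is to group parameters by placement before clipping, so a single `foreach` kernel call never mixes plain Tensors with DTensors on different meshes. The helper name and grouping key below are assumptions for illustration:

```python
# Hypothetical workaround sketch, not torchtitan code and not part of this commit.
from collections import defaultdict

import torch
from torch.distributed.tensor import DTensor


def clip_grads_by_group(parameters, max_norm: float) -> None:
    """Clip gradients group-by-group so each foreach call only sees tensors of one
    kind: plain Tensors, or DTensors sharing a device mesh.

    Caveat: each group is clipped against max_norm independently, i.e. this is not
    the exact global-norm clipping that utils.clip_grad_norm_ performs.
    """
    groups = defaultdict(list)
    for p in parameters:
        if p.grad is None:
            continue
        key = p.device_mesh if isinstance(p, DTensor) else None  # None = plain Tensor
        groups[key].append(p)
    for group in groups.values():
        torch.nn.utils.clip_grad_norm_(group, max_norm, foreach=True)
```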

Todo
- FSDP / CP integration

Haven't worked on
- softmax scoring when Router Parallel is used (currently only sigmoid)
- token-choice MoE
- shared expert overlapping

[ghstack-poisoned]
tianyu-l committed Dec 10, 2024
1 parent 4427dc2 commit 6ca7c10
Showing 3 changed files with 20 additions and 19 deletions.
2 changes: 1 addition & 1 deletion torchtitan/optimizer.py
@@ -86,7 +86,7 @@ def build_optimizers(model_parts, job_config: JobConfig):
"betas": (0.9, 0.95),
"weight_decay": 0.1,
"fused": fused,
"foreach": not fused,
"foreach": False,
}

return (
23 changes: 12 additions & 11 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -67,17 +67,18 @@ def parallelize_llama(
)

ep_mode = job_config.experimental.expert_parallel_mode
-    apply_ep(
-        model,
-        ep_mode=ep_mode,
-        dp_mesh=world_mesh["dp"] if parallel_dims.dp_shard_enabled else None,
-        tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
-        dp_tp_mesh=(
-            world_mesh["dp", "tp"]
-            if parallel_dims.dp_shard_enabled and parallel_dims.tp_enabled
-            else None
-        ),
-    )
+    if ep_mode != "none":
+        apply_ep(
+            model,
+            ep_mode=ep_mode,
+            dp_mesh=world_mesh["dp"] if parallel_dims.dp_shard_enabled else None,
+            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+            dp_tp_mesh=(
+                world_mesh["dp", "tp"]
+                if parallel_dims.dp_shard_enabled and parallel_dims.tp_enabled
+                else None
+            ),
+        )

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
14 changes: 7 additions & 7 deletions train.py
@@ -298,13 +298,13 @@ def loss_fn(pred, labels):
del pred
loss.backward()

-        # clip gradients
-        utils.clip_grad_norm_(
-            [p for m in model_parts for p in m.parameters()],
-            job_config.training.max_norm,
-            foreach=True,
-            pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
-        )
+        # # clip gradients
+        # utils.clip_grad_norm_(
+        #     [p for m in model_parts for p in m.parameters()],
+        #     job_config.training.max_norm,
+        #     foreach=True,
+        #     pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
+        # )

# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model_parts)
