Skip to content

Commit

Permalink
[cp] add option to choose kv shards rotation method
Browse files Browse the repository at this point in the history
ghstack-source-id: 030f53c4e1520715f29f35f8cacbe3f1c939b9e4
Pull Request resolved: #684
  • Loading branch information
XilunWu committed Dec 5, 2024
1 parent 7341d80 commit d9dcfda
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 2 deletions.
16 changes: 14 additions & 2 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,22 @@ def build_test_list():
[
[
"--experimental.context_parallel_degree=4",
"--experimental.context_parallel_rotate_method='allgather'",
]
],
"CP",
"cp",
"CP (allgather)",
"cp allgather",
ngpu=4,
),
OverrideDefinitions(
[
[
"--experimental.context_parallel_degree=4",
"--experimental.context_parallel_rotate_method='alltoall'",
]
],
"CP (alltoall)",
"cp alltoall",
ngpu=4,
),
OverrideDefinitions(
Expand Down
14 changes: 14 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,20 @@ def __init__(self):
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--experimental.context_parallel_rotate_method",
type=str,
default="allgather",
help="""
The collective to use in context parallel SDPA for kv shards exchange.
'allgather' means to all-gather all kv shards on ranks,
'alltoall' means to all-to-all shuffle the kv shards.
The default value is 'allgather'.
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
Expand Down
3 changes: 3 additions & 0 deletions torchtitan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,18 @@ def create_context_parallel_ctx(
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
cp_rotate_method: str,
):
try:
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
except ImportError:
print(
f"PyTorch version {torch.__version__} does not include the experimental "
"Context Parallel API. Please update to a newer version."
)

set_rotate_method(cp_rotate_method)
return context_parallel(
cp_mesh,
buffers=cp_buffers,
Expand Down
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def loss_fn(pred, labels):
cp_buffers=[input_ids, labels, model.freqs_cis],
cp_seq_dims=[1, 1, 0],
cp_no_restore_buffers={input_ids, labels},
cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
)
if parallel_dims.cp_enabled
else None
Expand Down

0 comments on commit d9dcfda

Please sign in to comment.