diff --git a/test_runner.py b/test_runner.py index 7aa3609d..82ec9c4f 100755 --- a/test_runner.py +++ b/test_runner.py @@ -313,10 +313,22 @@ def build_test_list(): [ [ "--experimental.context_parallel_degree=4", + "--experimental.context_parallel_rotate_method='allgather'", ] ], - "CP", - "cp", + "CP (allgather)", + "cp allgather", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--experimental.context_parallel_degree=4", + "--experimental.context_parallel_rotate_method='alltoall'", + ] + ], + "CP (alltoall)", + "cp alltoall", ngpu=4, ), OverrideDefinitions( diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index e7bca6f1..c530e4a5 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -355,6 +355,20 @@ def __init__(self): default=1, help="Context parallelism degree. 1 means disabled.", ) + self.parser.add_argument( + "--experimental.context_parallel_rotate_method", + type=str, + default="allgather", + help=""" + The collective to use in context parallel SDPA for kv shards exchange. + + 'allgather' means to all-gather all kv shards on ranks, + + 'alltoall' means to all-to-all shuffle the kv shards. + + The default value is 'allgather'. + """, + ) self.parser.add_argument( "--training.mixed_precision_param", type=str, diff --git a/torchtitan/utils.py b/torchtitan/utils.py index f663cc5c..65a4a4bd 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -126,15 +126,18 @@ def create_context_parallel_ctx( cp_buffers: List[torch.Tensor], cp_seq_dims: List[int], cp_no_restore_buffers: Set[torch.Tensor], + cp_rotate_method: str, ): try: from torch.distributed.tensor.experimental import context_parallel + from torch.distributed.tensor.experimental._attention import set_rotate_method except ImportError: print( f"PyTorch version {torch.__version__} does not include the experimental " "Context Parallel API. Please update to a newer version." ) + set_rotate_method(cp_rotate_method) return context_parallel( cp_mesh, buffers=cp_buffers, diff --git a/train.py b/train.py index 9e8b1fa8..6397f3e1 100644 --- a/train.py +++ b/train.py @@ -280,6 +280,7 @@ def loss_fn(pred, labels): cp_buffers=[input_ids, labels, model.freqs_cis], cp_seq_dims=[1, 1, 0], cp_no_restore_buffers={input_ids, labels}, + cp_rotate_method=job_config.experimental.context_parallel_rotate_method, ) if parallel_dims.cp_enabled else None