[Bugs] Fix bugs caused by sequence parallel when deepspeed is not used. (#752)

* fix sequence parallel bugs when training without deepspeed

* check deepspeed usage when setting sequence_parallel_size > 1
HIT-cwh authored Jun 11, 2024
1 parent 4910476 commit a3e11b9
Showing 2 changed files with 11 additions and 4 deletions.
4 changes: 2 additions & 2 deletions xtuner/parallel/sequence/setup_distributed.py
@@ -59,7 +59,7 @@ def get_sequence_parallel_world_size():
     global _SEQUENCE_PARALLEL_WORLD_SIZE
     if _SEQUENCE_PARALLEL_WORLD_SIZE is not None:
         return _SEQUENCE_PARALLEL_WORLD_SIZE
-    if not dist.is_initialized():
+    if not dist.is_initialized() or (_SEQUENCE_PARALLEL_GROUP is None):
         _SEQUENCE_PARALLEL_WORLD_SIZE = 1
     else:
         _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
@@ -72,7 +72,7 @@ def get_sequence_parallel_rank():
     global _SEQUENCE_PARALLEL_RANK
     if _SEQUENCE_PARALLEL_RANK is not None:
         return _SEQUENCE_PARALLEL_RANK
-    if not dist.is_initialized():
+    if not dist.is_initialized() or (_SEQUENCE_PARALLEL_GROUP is None):
         _SEQUENCE_PARALLEL_RANK = 0
     else:
         _SEQUENCE_PARALLEL_RANK = dist.get_rank(
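For reference, here is a minimal, self-contained sketch of the patched helpers (paraphrasing the diff above; the xtuner setup code that would normally create the group and assign it to `_SEQUENCE_PARALLEL_GROUP` is omitted). The extra `_SEQUENCE_PARALLEL_GROUP is None` guard matters because `dist.get_world_size(group=None)` and `dist.get_rank(group=None)` fall back to the default process group, so a run that initialized torch.distributed without ever setting up sequence parallel (e.g. training without DeepSpeed) would otherwise report the global world size and rank as the sequence-parallel ones.

# Minimal sketch of the patched fallback logic (paraphrased from the diff;
# the function that creates the group and populates _SEQUENCE_PARALLEL_GROUP
# is omitted here).
import torch.distributed as dist

_SEQUENCE_PARALLEL_GROUP = None        # set only when sequence parallel is initialized
_SEQUENCE_PARALLEL_WORLD_SIZE = None   # lazily cached
_SEQUENCE_PARALLEL_RANK = None         # lazily cached


def get_sequence_parallel_world_size():
    global _SEQUENCE_PARALLEL_WORLD_SIZE
    if _SEQUENCE_PARALLEL_WORLD_SIZE is not None:
        return _SEQUENCE_PARALLEL_WORLD_SIZE
    if not dist.is_initialized() or (_SEQUENCE_PARALLEL_GROUP is None):
        # No distributed context, or the sequence-parallel group was never
        # created: behave as if sequence parallel is disabled.
        _SEQUENCE_PARALLEL_WORLD_SIZE = 1
    else:
        _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
            group=_SEQUENCE_PARALLEL_GROUP)
    return _SEQUENCE_PARALLEL_WORLD_SIZE


def get_sequence_parallel_rank():
    global _SEQUENCE_PARALLEL_RANK
    if _SEQUENCE_PARALLEL_RANK is not None:
        return _SEQUENCE_PARALLEL_RANK
    if not dist.is_initialized() or (_SEQUENCE_PARALLEL_GROUP is None):
        _SEQUENCE_PARALLEL_RANK = 0
    else:
        _SEQUENCE_PARALLEL_RANK = dist.get_rank(group=_SEQUENCE_PARALLEL_GROUP)
    return _SEQUENCE_PARALLEL_RANK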
11 changes: 9 additions & 2 deletions xtuner/tools/train.py
@@ -77,7 +77,7 @@ def register_function(cfg_dict):
             register_function(value)


-def check_cfg(cfg):
+def check_cfg(cfg, args):
     if getattr(cfg, 'use_varlen_attn',
                False) and cfg.train_dataloader.batch_size > 1:
         raise NotImplementedError(
@@ -116,6 +116,13 @@ def check_cfg(cfg):
             ' attn_implementation to `flash_attention_2` or do not '
             f'set this attribute. Got `{attn_implementation}` .')

+    if args.deepspeed is None:
+        assert getattr(cfg, 'sequence_parallel_size', 1) == 1, \
+            ('Sequence parallel training without DeepSpeed lacks validation.'
+             'Please use DeepSpeed to optimize the training phase by '
+             '`--deepspeed deepspeed_zero1 (deepspeed_zero2 or '
+             'deepspeed_zero3)`.')
+

 def main():
     args = parse_args()
@@ -137,7 +144,7 @@ def main():
     # change these FunctionType object to str
     register_function(cfg._cfg_dict)

-    check_cfg(cfg)
+    check_cfg(cfg, args)

     if cfg.get('framework', 'mmengine').lower() == 'huggingface':
         # set default training_args
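The train.py change threads the parsed command-line arguments into check_cfg so the config check can reject `sequence_parallel_size > 1` whenever `--deepspeed` was not passed. Below is a small, hypothetical stand-alone illustration (not xtuner code; SimpleNamespace stands in for the argparse namespace and the mmengine Config) of how the added assertion behaves:

# Hypothetical driver illustrating the guard added to check_cfg();
# SimpleNamespace stands in for the parsed args and the config object.
from types import SimpleNamespace


def check_sequence_parallel(cfg, args):
    # Same condition as the assertion added in xtuner/tools/train.py.
    if args.deepspeed is None:
        assert getattr(cfg, 'sequence_parallel_size', 1) == 1, (
            'Sequence parallel training requires DeepSpeed; pass '
            '--deepspeed deepspeed_zero1 / deepspeed_zero2 / deepspeed_zero3.')


cfg = SimpleNamespace(sequence_parallel_size=2)

check_sequence_parallel(cfg, SimpleNamespace(deepspeed='deepspeed_zero2'))  # passes
try:
    check_sequence_parallel(cfg, SimpleNamespace(deepspeed=None))
except AssertionError as exc:
    print(f'rejected: {exc}')  # the patched check_cfg now fails fast like this

With the patched check in place, a config that sets `sequence_parallel_size > 1` is expected to be launched with a DeepSpeed strategy, e.g. `xtuner train CONFIG --deepspeed deepspeed_zero2`.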
