
Fix tp mem cache #203

Merged
merged 20 commits into huggingface:main
Aug 2, 2024

Conversation

AleHD
Contributor

@AleHD AleHD commented Jun 28, 2024

Nanotron seems to consume disproportionately more activation memory than Megatron. This is due to at least the following factors:

  • The GLU activation, which is not fused, allocates two tensors: the activation output and the element-wise multiplication. Fusing this operation provides (relatively small) memory gains.
  • During the differentiable operations (specifically, DifferentiableAllGather and DifferentiableReduceScatterSum), a new tensor is allocated via torch.empty and kept alive from the forward pass until the backward pass. This caching is unnecessary, as these tensors are not used at all during the backward. Getting rid of these allocations provides significant memory gains. To fix this, the PR introduces a global memory buffer (a MemoryBuffer singleton) that recycles the allocated space, similar to Megatron. A sketch of both ideas follows this list.

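For illustration, minimal sketches of the two ideas follow. The class names, the keyed get API and the growth policy are assumptions made for these sketches, not necessarily what this PR or Megatron implement.

A fused GLU written as a custom autograd function that saves only the merged projection and recomputes the split and activation in the backward, instead of keeping the two intermediate tensors alive:

import torch
import torch.nn.functional as F

class FusedGLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, merged):
        # Only the merged projection is saved; the silu output and the
        # element-wise product are not kept around for the backward.
        ctx.save_for_backward(merged)
        gate, up = merged.chunk(2, dim=-1)
        return F.silu(gate) * up

    @staticmethod
    def backward(ctx, grad_out):
        (merged,) = ctx.saved_tensors
        gate, up = merged.chunk(2, dim=-1)
        sig = torch.sigmoid(gate)
        silu = gate * sig
        d_silu = sig * (1 + gate * (1 - sig))  # derivative of x * sigmoid(x)
        return torch.cat([grad_out * up * d_silu, grad_out * silu], dim=-1)

A recycling buffer singleton that hands out views into pre-allocated storage keyed by name and dtype, so the collectives reuse the same memory instead of calling torch.empty on every forward:

import math
import torch

class MemoryBuffer:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.buffers = {}
        return cls._instance

    def get(self, name, shape, dtype, device=None):
        numel = math.prod(shape)
        key = (name, dtype)
        buf = self.buffers.get(key)
        if buf is None or buf.numel() < numel:
            # Grow (or create) the backing storage. No gradient is needed here:
            # the custom differentiable ops handle their backward themselves.
            buf = torch.empty(numel, dtype=dtype, device=device, requires_grad=False)
            self.buffers[key] = buf
        return buf[:numel].view(shape)
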
Attached: memory traces of the default nanotron implementation (which OOMs), the implementation in this PR, and Megatron. The traces are from the first rank of a tp8 pp4 dp1 llama70b run of 5 iterations (sequence length 8k, micro-batch size 1, accumulation 4, synchronous TP, reduce_scatter mode).

[Attached: memory trace screenshots for default nanotron (OOM), this PR, and Megatron]

I think these changes are important, as they allow training larger models with significantly lower memory requirements.

Let me know if you have any suggestions, and I'd be happy to make adjustments to upstream this feature! :)

@3outeille 3outeille self-assigned this Jun 28, 2024
@3outeille 3outeille requested a review from xrsrke June 28, 2024 13:29
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
unsharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype)
Member
pass device=tensor.device and requires_grad=tensor.requires_grad as well ?
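
For illustration, with a buffer get that also accepts a device argument (an assumption about the API, mirroring the sketch earlier in this thread), the call might become:

unsharded_tensor = MemoryBuffer().get(
    "dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype, device=tensor.device
)

(requires_grad is deliberately left out here; see the discussion below.)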

dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
sharded_tensor = MemoryBuffer().get(
"dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype
Member
same as above

@3outeille
Member

nice catch ! lgtm

@AleHD
Contributor Author

AleHD commented Jun 28, 2024

Thanks for the comments! I have one question regarding requires_grad. In principle we shouldn't require gradients on the gathered tensors, right? The custom backward handles the gradient computation for the parameters anyway. At least, training runs seamlessly without setting requires_grad on those tensors.

@3outeille
Member

Yeah, I think we can set it to False here (it seems Megatron does the same here).

@zzhhjjj
Collaborator

zzhhjjj commented Jul 2, 2024

Hi,

Thanks for the PR, it's really nice! I tested it by training for 100 steps on the TinyStories dataset and compared the loss against our current code, and I found an abnormal difference. Could you check whether you observe the same thing on your side? Below is my config file; you may have to change it a little, but the idea is to compare the loss before and after the change with the same hyperparameters. Thanks a lot for the work.

checkpoints:
  checkpoint_interval: 1000
  checkpoints_path: null
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  save_initial_state: false

data_stages:
- data:
    dataset:
      dataset_overwrite_cache: false
      dataset_processing_num_proc_per_process: 1
      hf_dataset_config_name: null
      hf_dataset_or_datasets: roneneldan/TinyStories
      hf_dataset_splits: train
      text_column_name: text
    num_loading_workers: 0
    seed: 42
  name: Stable Training Stage
  start_training_step: 1

general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: ring-attention
  run: Llama3-128k-lr-1e-5-5000
  seed: 42
  step: null

lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info

model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 768
    initializer_range: 0.02
    intermediate_size: 3072
    is_llama_config: true
    max_position_embeddings: 512
    num_attention_heads: 16
    num_hidden_layers: 12
    num_key_value_heads: 16
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_theta: 1000000.0
    rope_interleaved: false
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 50272

optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 198
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0

tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2 
  tokenizer_revision: null

parallelism:
  dp: 1
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 4
  tp_linear_async_communication: false
  tp_mode: REDUCE_SCATTER
profiler: null

tokens:
  batch_accumulation_per_replica: 2
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 64
  sequence_length: 512
  train_steps: 100
  val_check_interval: -1

s3_upload: null
experiment_logger: null

[Attached: loss comparison plot]

@AleHD
Contributor Author

AleHD commented Jul 8, 2024

That's interesting. Thanks for letting me know, I will investigate further and come back with some results soon :)

@3outeille 3outeille reopened this Jul 15, 2024
@AleHD
Contributor Author

AleHD commented Jul 17, 2024

I was able to reproduce the error. To fix the issue, I followed Megatron's design and fused the all-gather and the linear operation into a single module; the loss progression now matches the main branch. A sketch of the fused design is shown below.
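
For context, here is a minimal sketch of that design, assuming a 2-D sequence-sharded input, synchronous TP, and no bias; the class name and exact collectives are illustrative, not the PR's actual module:

import torch
import torch.distributed as dist

class GatherLinear(torch.autograd.Function):
    # Fuses the input all-gather with the column-parallel matmul so the gathered
    # activation only sits in autograd's saved tensors if we choose to cache it.

    @staticmethod
    def forward(ctx, shard, weight, group, recompute_allgather):
        world = dist.get_world_size(group)
        full = torch.empty(world * shard.shape[0], shard.shape[1],
                           dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(full, shard.contiguous(), group=group)
        ctx.group, ctx.recompute = group, recompute_allgather
        # Either keep only the small shard (and re-gather in backward) or cache the full copy.
        ctx.save_for_backward(shard if recompute_allgather else full, weight)
        return full @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        saved, weight = ctx.saved_tensors
        world = dist.get_world_size(ctx.group)
        if ctx.recompute:
            full = torch.empty(world * saved.shape[0], saved.shape[1],
                               dtype=saved.dtype, device=saved.device)
            dist.all_gather_into_tensor(full, saved.contiguous(), group=ctx.group)
        else:
            full = saved
        grad_full = (grad_out @ weight).contiguous()
        grad_weight = grad_out.t() @ full
        # Each rank's input shard contributed to every rank's gathered copy, so its
        # gradient is the sum of the matching slices across ranks: a reduce-scatter.
        grad_shard = torch.empty(full.shape[0] // world, full.shape[1],
                                 dtype=grad_full.dtype, device=grad_full.device)
        dist.reduce_scatter_tensor(grad_shard, grad_full, group=ctx.group)
        return grad_shard, grad_weight, None, None

In the actual PR, the torch.empty calls above would presumably be served by the memory buffer discussed earlier.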

I added two configurations for this optimization. Since the all-gather and the linear now live in the same module, we can control whether to recompute the all-gather during the backward or to cache it instead. As expected, recomputing yields larger memory savings at the cost of some throughput (still, both modes are more memory-efficient than the current implementation, and both provide at least comparable tok/sec to the current main). The option is parallelism.tp_recompute_allgather. I chose true as the default because it provides the best trade-off, but it is probably advisable to set it to false if memory is not a concern. A config snippet is shown below.
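
For reference, assuming the flag sits under the parallelism section of the config (as the dotted option name suggests), enabling it would look like:

parallelism:
  dp: 1
  pp: 1
  tp: 4
  tp_mode: REDUCE_SCATTER
  tp_linear_async_communication: false
  tp_recompute_allgather: true  # set to false to cache the all-gather and trade memory for speed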

I attach wandb logs of four runs on two different configurations that validate these claims. In both, blue is the baseline (main branch) implementation, red is the incorrect first version of this PR, green is the no-recompute mode (moderate memory savings and slightly faster than the baseline), and purple is the recompute mode (large memory savings and on average as fast as the baseline). The first set of plots corresponds to the tiny llama configuration you shared above; the second corresponds to a llama8b run. Except for the incorrect version, all lines are virtually identical in the lm_loss graph.
[Attached: wandb plots for the tiny llama and llama8b runs]

Let me know if you have any suggestions.

@AleHD AleHD marked this pull request as draft July 30, 2024 08:41
@AleHD AleHD marked this pull request as ready for review July 30, 2024 17:44
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
Member
Seems like dist.reduce_scatter needs grad_input to be contiguous (cf. https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305).

I am not sure whether grad_input = grad_output @ weight is contiguous (although you do have grad_output = grad_output.contiguous()). Maybe, to be safe, we should call grad_input = grad_input.contiguous() before running the reduce_scatter? What do you think?
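
A sketch of the suggested guard, reusing the names from the diff above:

grad_input = grad_output @ weight
grad_input = grad_input.contiguous()  # ensure a dense layout before the collective
sub_grad_input = torch.empty(
    input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)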

@3outeille
Member

I chose true as the default because it provides the best trade-off, but it is probably advisable to set it to false if memory is not a concern.

Hello, very nice PR ! I left some comments as well before merging

I was able to reproduce the results, and found that the version with tp_recompute_allgather enabled is actually faster than the baseline (TP on the main branch). So I am okay with setting it to True by default as well.

[Attached: wandb comparison plot]

@AleHD
Contributor Author

AleHD commented Aug 2, 2024

Updated the PR with the suggestions mentioned! Let me know if I'm missing something.

@3outeille
Member

all points were addressed, LGTM !

@3outeille 3outeille merged commit 4eb520f into huggingface:main Aug 2, 2024
3 checks passed