Memory optimization in async tp-linear #208

Merged: 29 commits into huggingface:main on Aug 5, 2024

Conversation

@AleHD (Contributor) commented Jul 18, 2024

This PR introduces the memory optimization methods implemented in #203 and allows them to be used in the async comm regime. It also includes the commits fixing the row-parallel tp-linear from #172, so both of those PRs should be merged first to make reviewing this one easier. Two modes are provided: recomputing the all_gather in the backward pass, or not. Recomputing is more memory-efficient, but slightly slower.
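
To make the two modes concrete, here is a minimal sketch (not the code in this PR) of a column-parallel linear whose sequence-sharded input is all-gathered in the forward pass. The class name and the `recompute_allgather` flag are illustrative only, and the async overlap of the communication with the gradient matmuls is omitted:

```python
import torch
import torch.distributed as dist


class _ColumnLinearSketch(torch.autograd.Function):
    """Column-parallel linear over an input sharded along the sequence dimension."""

    @staticmethod
    def forward(ctx, local_input, weight, group, recompute_allgather):
        # local_input: [s_local, h] shard on this rank; weight: [out_local, h].
        world_size = dist.get_world_size(group)
        local_input = local_input.contiguous()
        gathered = local_input.new_empty(world_size * local_input.shape[0], local_input.shape[1])
        dist.all_gather_into_tensor(gathered, local_input, group=group)

        ctx.group = group
        ctx.recompute_allgather = recompute_allgather
        if recompute_allgather:
            # Recompute mode: keep only the small local shard and redo the
            # all_gather during backward (less memory, a bit more time).
            ctx.save_for_backward(local_input, weight)
        else:
            # No-recompute mode: keep the full gathered activation alive
            # until backward (more memory, no extra communication).
            ctx.save_for_backward(gathered, weight)
        return gathered @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        saved, weight = ctx.saved_tensors
        world_size = dist.get_world_size(ctx.group)
        if ctx.recompute_allgather:
            gathered = saved.new_empty(world_size * saved.shape[0], saved.shape[1])
            dist.all_gather_into_tensor(gathered, saved, group=ctx.group)
        else:
            gathered = saved

        grad_gathered = grad_output @ weight      # [s, h]
        grad_weight = grad_output.t() @ gathered  # [out_local, h]
        # The input gradient must be summed over ranks and scattered back to
        # the local shard, since every rank consumed the full gathered input.
        grad_local = grad_gathered.new_empty(grad_gathered.shape[0] // world_size, grad_gathered.shape[1])
        dist.reduce_scatter_tensor(grad_local, grad_gathered, group=ctx.group)
        return grad_local, grad_weight, None, None
```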

Here is a table that summarizes my observations on a tp4 llama8b model on A100 GPUs:

| Method | Average throughput (tok/sec/gpu) | Max memory reserved (GB) |
|---|---|---|
| Baseline (current main implementation) | 3415 | 72.6 |
| Sync no-recompute | 3525 (+3%) | 68.5 (-6%) |
| Sync recompute | 3427 | 61.9 (-15%) |
| Async no-recompute | 3587 (+5%) | 68.8 (-5%) |
| Async recompute | 3526 (+3%) | 60.9 (-16%) |

These changes should make training noticeably more efficient. I recommend the async-recompute setting, though async-norecompute might make more sense for the extra throughput when memory is not a concern. In addition, as dp and pp increase and optimizer states and parameters become more sharded, activations make up a larger share of memory, so the savings this PR brings, which target activation memory, should only become more significant. This makes it very useful for scaling to larger models.
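
For reference, the memory column is presumably the peak reported by torch.cuda.max_memory_reserved(). A rough sketch of how comparable per-setting numbers can be collected (not the exact benchmarking setup used for the table above; `train_step` and `tokens_per_step_per_gpu` are placeholder names):

```python
import time
import torch


def measure_step_stats(train_step, num_steps, tokens_per_step_per_gpu):
    """Rough per-GPU throughput and peak reserved memory for one setting."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_steps):
        train_step()  # full forward/backward/optimizer step
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    tok_per_sec_per_gpu = num_steps * tokens_per_step_per_gpu / elapsed
    max_reserved_gb = torch.cuda.max_memory_reserved() / 1024**3
    return tok_per_sec_per_gpu, max_reserved_gb
```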

I attach the wandb logs for the llama8b experiments (top) and for a tiny 152M model (bottom) to study the effect on smaller models.

[wandb plots: llama8b (top) and 152M (bottom); blue = sync baseline, green = sync no-recompute, purple = sync recompute, yellow = async baseline, gray = async recompute, red = async no-recompute]

@@ -141,22 +142,27 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
ctx.save_for_backward(tensor, weight)
# ctx.save_for_backward(tensor, weight)
Member
remove comments

@xrsrke xrsrke self-requested a review July 29, 2024 16:10
@AleHD AleHD marked this pull request as draft July 30, 2024 08:41
@AleHD AleHD marked this pull request as ready for review July 30, 2024 17:43
@xrsrke (Member) left a comment

LGTM!

@3outeille (Member)

lgtm as well

@3outeille 3outeille merged commit 03d67f2 into huggingface:main Aug 5, 2024
2 of 3 checks passed