Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update theoretical memory footprint formula #1345

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

okoge-kaz
Copy link

number of parameters

- + (2 / args.hidden_size)
+ + (1 / args.hidden_size)

Because there is 2 * args.num_layers * args.hidden_size * args.hidden_size, the coefficient of layernorms should be 1, not 2.

final layernorm

- (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size)
+ (num_parameters_in_transformer_layers - args.hidden_size) / args.pipeline_model_parallel_size

Since the final layernorm is not relevant for stages other than the last pipeline stage, it is necessary to subtract the number of parameters for the final layernorm using - args.hidden_size.

gradients' type

- 6 + (12 / args.data_parallel_size)
+ (2 + gradient_accumulation_factor) + (12 / args.data_parallel_size

When args.accumulate_allreduce_grads_in_fp32 is True, the coefficient can be set to 6 from parameters(bf16) + gradients(fp32), but when it is False, it becomes 4, so case separation is necessary.

flash attention

- if not args.sequence_parallel or args.recompute_granularity != 'selective':
+ if not args.sequence_parallel or not (
+      args.recompute_granularity == 'selective' or args.use_flash_attn is True
+  ):

Since it is possible to calculate not only when args.recompute_granularity is selective, but also when args.use_flash_attn is True, it is added to the conditions.

SwiGLU

            (
                # SwiGLU
                2 * b * s * h  # input
                + 2 * b * s * args.ffn_hidden_size  # up_proj
                + 2 * b * s * args.ffn_hidden_size  # gate_proj
                + 2 * b * s * args.ffn_hidden_size  # act_fn
                + 2 * b * s * args.ffn_hidden_size  # down_proj
            ) if args.swiglu else (
                2 * b * s * h  # h -> ffn_h
                + 2 * b * s * args.ffn_hidden_size  # act
                + 2 * b * s * args.ffn_hidden_size  # ffn_h  -> h
            )

Added conditional branching to support both GPT and Llama architectures

other changes

  • Theoritical memory footprint is easier to read in GB units, so I changed it from MB to GB.
  • Change so that the theoretical memory footprint per GPU when CP (Context Parallelism) is enabled can be correctly calculated.
  • Support GQA(Grouped Query Attention).

@okoge-kaz okoge-kaz marked this pull request as ready for review January 3, 2025 12:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant