
Commit 67115a5
remove layer norm flops
eliebak authored Sep 13, 2024
1 parent 73da086 commit 67115a5
Showing 1 changed file with 1 addition and 2 deletions.
src/nanotron/models/llama.py: 3 changes (1 addition & 2 deletions)
@@ -920,7 +920,6 @@ def get_block_compute_costs(self):
             LlamaDecoderLayer: 2 * model_config.num_attention_heads * d_qkv * model_config.hidden_size  # Q output projection
             + 2 * model_config.num_key_value_heads * d_qkv * model_config.hidden_size  # KV
             + 3 * d_ff * model_config.hidden_size  # for the MLP (3 because of the gated mechanism)
-            + 2 * model_config.hidden_size,  # for the layernorm
             # This is the last lm_head
             TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
         }
@@ -1172,4 +1171,4 @@ def get_flops(
 
     hardware_flops = model_flops  # TODO: This is a placeholder for now
 
-    return model_flops, hardware_flops
+    return model_flops, hardware_flops
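
For scale, here is a minimal sketch (not nanotron code) that plugs hypothetical Llama-7B-like hyperparameters into the per-layer cost expression from get_block_compute_costs. The numeric values and the definition of d_qkv as the per-head dimension are assumptions, not taken from this diff; the point is to show how small the dropped layernorm term is next to the matmul terms.

# Minimal sketch with assumed Llama-7B-like hyperparameters (not from this diff).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 32
d_ff = 11008
d_qkv = hidden_size // num_attention_heads  # assumed per-head dimension

q_out_proj = 2 * num_attention_heads * d_qkv * hidden_size  # Q + output projections
kv_proj = 2 * num_key_value_heads * d_qkv * hidden_size     # K and V projections
mlp = 3 * d_ff * hidden_size                                # gated MLP (3 matmuls)
layernorm = 2 * hidden_size                                 # the term removed by this commit

matmul_terms = q_out_proj + kv_proj + mlp
print(f"matmul terms per layer: {matmul_terms:,}")                # 202,375,168
print(f"layernorm term:         {layernorm:,}")                   # 8,192
print(f"relative size:          {layernorm / matmul_terms:.1e}")  # ~4.0e-05

The layernorm contribution scales with hidden_size while every other term scales with hidden_size squared, so dropping it leaves the relative per-block costs used for load balancing essentially unchanged.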
