From 018ece5af2d89a11a4a235f81f94496c78b4f990 Mon Sep 17 00:00:00 2001
From: Xia Weiwen
Date: Thu, 16 Jan 2025 10:11:07 -0800
Subject: [PATCH] Add extra_repr to Linear classes for debugging purpose (#6954)

**Summary**
This PR adds an `extra_repr` method to some Linear classes so that additional info is shown when printing such modules. It is useful for debugging.

Affected modules:
- LinearLayer
- LinearAllreduce
- LmHeadLinearAllreduce

The `extra_repr` method gives the following info:
- in_features
- out_features
- bias (true or false)
- dtype

**Example**
Print the llama-2-7b model on rank 0 after `init_inference` with world size = 2.
Previously we only got the class names of these modules:
```
InferenceEngine(
  (module): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): LinearLayer()
            (k_proj): LinearLayer()
            (v_proj): LinearLayer()
            (o_proj): LinearAllreduce()
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): LinearLayer()
            (up_proj): LinearLayer()
            (down_proj): LinearAllreduce()
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): LmHeadLinearAllreduce()
  )
)
```

Now we get more useful info:
```
InferenceEngine(
  (module): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (k_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (v_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (o_proj): LinearAllreduce(in_features=2048, out_features=4096, bias=False, dtype=torch.bfloat16)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16)
            (up_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16)
            (down_proj): LinearAllreduce(in_features=5504, out_features=4096, bias=False, dtype=torch.bfloat16)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): LmHeadLinearAllreduce(in_features=2048, out_features=32000, bias=False, dtype=torch.bfloat16)
  )
)
```
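For reference, a minimal, self-contained sketch of the mechanism these methods rely on: `nn.Module.__repr__` inserts the string returned by `extra_repr()` into the printed module line. The `TinyLinear` module below is a hypothetical stand-in for illustration, not a DeepSpeed class:
```python
import torch
import torch.nn as nn


class TinyLinear(nn.Module):
    """Hypothetical module showing how extra_repr feeds into repr(module)."""

    def __init__(self, in_features, out_features, dtype=torch.bfloat16):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
        self.bias = None

    def extra_repr(self):
        # Report the same fields as the patch: shape, bias presence, and dtype.
        out_features, in_features = self.weight.shape
        return "in_features={}, out_features={}, bias={}, dtype={}".format(
            in_features, out_features, self.bias is not None, self.weight.dtype)


print(TinyLinear(4096, 2048))
# TinyLinear(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
```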
---
 deepspeed/module_inject/layers.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 722ba413a671..2f884ba4fb09 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -91,6 +91,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
+        dtype = self.weight.dtype if self.weight is not None else None
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class LmHeadLinearAllreduce(nn.Module):
 
@@ -120,6 +127,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
+        dtype = self.weight.dtype if self.weight is not None else None
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class LinearLayer(nn.Module):
 
@@ -144,6 +158,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape
+        dtype = self.weight.dtype
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class Normalize(nn.Module):
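A rough usage sketch of how the richer repr above can be reproduced; the script name, model id, and `init_inference` arguments are assumptions (any AutoTP setup launched with world size = 2 should behave similarly):
```python
# Sketch, assuming a 2-rank launch, e.g. `deepspeed --num_gpus 2 show_repr.py`.
# Kernel injection is left off so AutoTP inserts LinearLayer / LinearAllreduce modules.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype=torch.bfloat16)
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": 2},
                                  dtype=torch.bfloat16,
                                  replace_with_kernel_inject=False)
if torch.distributed.get_rank() == 0:
    # Linear modules now print in_features, out_features, bias, and dtype.
    print(engine)
```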