diff --git a/python/flexflow/serve/models/falcon.py b/python/flexflow/serve/models/falcon.py
index 7a55da26ef..f88fca6407 100644
--- a/python/flexflow/serve/models/falcon.py
+++ b/python/flexflow/serve/models/falcon.py
@@ -25,25 +25,29 @@ def __init__(self, hf_config):
         self.max_beam_depth = 8
         self.max_spec_tree_token_num = 20
         self.bias = hf_config.bias
+        self.hidden_dropout = hf_config.hidden_dropout
         self.hidden_size = hf_config.hidden_size
         self.layer_norm_epsilon = hf_config.layer_norm_epsilon
-        self.multi_query = hf_config.multi_query
-        self.n_head = (
-            hf_config.n_head
-            if "n_head" in hf_config.__dict__
-            else hf_config.num_attention_heads
+        self.multi_query = (
+            hf_config.multi_query if "multi_query" in hf_config.__dict__ else True
         )
-        self.n_head_kv = hf_config.n_head_kv if "n_head_kv" in hf_config.__dict__ else 1
-        self.n_layer = (
-            hf_config.n_layer
-            if "n_layer" in hf_config.__dict__
-            else hf_config.num_hidden_layers
+        self.new_decoder_architecture = hf_config.new_decoder_architecture
+
+        self.num_attention_heads = hf_config.num_attention_heads
+        self.num_kv_heads = (
+            hf_config.num_kv_heads
+            if (self.new_decoder_architecture or not self.multi_query)
+            else 1
         )
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+        self.n_layer = hf_config.num_hidden_layers
         self.parallel_attn = hf_config.parallel_attn
         self.vocab_size = hf_config.vocab_size
+
         # Standardized FlexFlow num heads fields below
-        self.num_attention_heads = self.n_head
-        self.num_key_value_heads = self.n_head_kv
+        # self.num_attention_heads = self.n_head
+        self.num_key_value_heads = self.num_kv_heads


 class FlexFlowFalcon(FlexFlowModel):
@@ -76,16 +80,19 @@ def __init__(
         )

         # Sanity checks
-        if self.falcon_config.hidden_size % self.falcon_config.n_head != 0:
+        if self.falcon_config.hidden_size % self.falcon_config.num_attention_heads != 0:
             raise ValueError(
-                f"Hidden size ({self.falcon_config.hidden_size}) is not divisible by n_head ({self.falcon_config.n_head})"
+                f"Hidden size ({self.falcon_config.hidden_size}) is not divisible by num_attention_heads ({self.falcon_config.num_attention_heads})"
             )
         if (
-            self.falcon_config.n_head < self.ffconfig.tensor_parallelism_degree
-            or self.falcon_config.n_head % self.ffconfig.tensor_parallelism_degree != 0
+            self.falcon_config.num_attention_heads
+            < self.ffconfig.tensor_parallelism_degree
+            or self.falcon_config.num_attention_heads
+            % self.ffconfig.tensor_parallelism_degree
+            != 0
         ):
             raise ValueError(
-                f"Number of q attention heads ({self.falcon_config.n_head}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
+                f"Number of q attention heads ({self.falcon_config.num_attention_heads}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
             )

         self.build_model(
@@ -124,7 +131,9 @@ def build_model(self, max_tokens_per_batch):
                     axes,
                     True,
                     self.falcon_config.layer_norm_epsilon,
-                    name=f"layers_{i}_input_layernorm",
+                    name=f"layers_{i}_input_layernorm"
+                    if not self.falcon_config.new_decoder_architecture
+                    else f"layers_{i}_ln_attn",
                 )
             else:
                 token, att_norm = ffmodel.residual_layer_norm(
@@ -135,17 +144,32 @@ def build_model(self, max_tokens_per_batch):
                     axes,
                     True,
                     self.falcon_config.layer_norm_epsilon,
-                    name=f"layers_{i}_input_layernorm",
+                    name=f"layers_{i}_input_layernorm"
+                    if not self.falcon_config.new_decoder_architecture
+                    else f"layers_{i}_ln_attn",
+                )
+
+            # MLP norm (identical to att norm for old architecture)
+            if not self.falcon_config.new_decoder_architecture:
+                mlp_norm = att_norm
+            else:
+                # Residual has already been computed by attn norm (token = token + mha + mlp_output)
+                mlp_norm = ffmodel.layer_norm(
+                    token,
+                    axes,
+                    True,
+                    self.falcon_config.layer_norm_epsilon,
+                    name=f"layers_{i}_ln_mlp",
                 )

             if self.mode == InferenceMode.BEAM_SEARCH_MODE:
                 mha = ffmodel.spec_inc_multiquery_self_attention(
                     att_norm,
                     self.falcon_config.hidden_size,
-                    self.falcon_config.n_head,
-                    self.falcon_config.n_head_kv,
-                    self.falcon_config.hidden_size // self.falcon_config.n_head,
-                    self.falcon_config.hidden_size // self.falcon_config.n_head,
+                    self.falcon_config.num_attention_heads,
+                    self.falcon_config.num_kv_heads,
+                    self.falcon_config.head_dim,
+                    self.falcon_config.head_dim,
                     0.0,  # dropout
                     False,  # qkv_bias
                     False,  # final_bias
@@ -159,10 +183,10 @@ def build_model(self, max_tokens_per_batch):
                 mha = ffmodel.inc_multiquery_self_attention_verify(
                     att_norm,
                     self.falcon_config.hidden_size,
-                    self.falcon_config.n_head,
-                    self.falcon_config.n_head_kv,
-                    self.falcon_config.hidden_size // self.falcon_config.n_head,
-                    self.falcon_config.hidden_size // self.falcon_config.n_head,
+                    self.falcon_config.num_attention_heads,
+                    self.falcon_config.num_kv_heads,
+                    self.falcon_config.head_dim,
+                    self.falcon_config.head_dim,
                     0.0,  # dropout
                     False,  # qkv_bias
                     False,  # final_bias
@@ -176,10 +200,10 @@ def build_model(self, max_tokens_per_batch):
                 mha = ffmodel.inc_multiquery_self_attention(
                     att_norm,
                     self.falcon_config.hidden_size,
-                    self.falcon_config.n_head,
-                    self.falcon_config.n_head_kv,
-                    self.falcon_config.hidden_size // self.falcon_config.n_head,
-                    self.falcon_config.hidden_size // self.falcon_config.n_head,
+                    self.falcon_config.num_attention_heads,
+                    self.falcon_config.num_kv_heads,
+                    self.falcon_config.head_dim,
+                    self.falcon_config.head_dim,
                     0.0,  # dropout
                     False,  # qkv_bias
                     False,  # final_bias
@@ -193,7 +217,7 @@ def build_model(self, max_tokens_per_batch):
                 assert False

             dense_h_to_4h = ffmodel.dense(
-                att_norm,
+                mlp_norm,
                 self.falcon_config.hidden_size * 4,
                 ActiMode.AC_MODE_NONE,
                 False,
@@ -245,10 +269,10 @@ def build_model(self, max_tokens_per_batch):

     def convert_hf_model(model, dst_folder):
         os.makedirs(dst_folder, exist_ok=True)
-        n_head = (
-            model.config.n_head
-            if "n_head" in model.config.__dict__
-            else model.config.num_attention_heads
+        num_kv_heads = (
+            model.config.num_kv_heads
+            if (model.config.new_decoder_architecture or not model.config.multi_query)
+            else 1
         )
         for name, params in model.named_parameters():
             name = (
@@ -262,12 +286,13 @@ def convert_hf_model(model, dst_folder):
             name_q = name.replace("self_attention_query_key_value", "attention_wq")
             name_k = name.replace("self_attention_query_key_value", "attention_wk")
             name_v = name.replace("self_attention_query_key_value", "attention_wv")
+            # We split the first dim of the tensor, which is the output dimension. The second dimension is the input dimension and is always equal to the hidden size
             q, k, v = torch.split(
                 params,
                 [
-                    model.config.hidden_size,
-                    model.config.hidden_size // n_head,
-                    model.config.hidden_size // n_head,
+                    model.config.head_dim * model.config.num_attention_heads,
+                    model.config.head_dim * num_kv_heads,
+                    model.config.head_dim * num_kv_heads,
                 ],
                 0,
             )
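The KV-head resolution added in `FalconConfig.__init__` can be summarized by the sketch below. It is not part of the patch; the helper name `resolve_num_kv_heads` is hypothetical, and the only assumption is an HF-style Falcon config exposing `multi_query`, `new_decoder_architecture`, and `num_kv_heads`.

```python
def resolve_num_kv_heads(hf_config) -> int:
    # Hypothetical helper mirroring the logic added to FalconConfig.__init__ above.
    multi_query = getattr(hf_config, "multi_query", True)
    if hf_config.new_decoder_architecture or not multi_query:
        # New-architecture (falcon-40b-style) checkpoints carry an explicit
        # grouped-KV head count in the config.
        return hf_config.num_kv_heads
    # Old-architecture multi-query (falcon-7b-style) checkpoints share one K/V head.
    return 1
```

In other words, old-architecture multi-query checkpoints fall back to a single shared K/V head, while new-architecture (grouped-query) checkpoints read `num_kv_heads` directly from the config.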
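The new split sizes in `convert_hf_model` follow from the fused `query_key_value` layout: Q rows come first, then K, then V, all stacked along the output (first) dimension. Below is a minimal shape check, assuming falcon-7b-like dimensions (hidden_size 4544, 71 query heads, head_dim 64, multi-query so a single KV head); the concrete numbers are illustrative, not taken from the patch.

```python
import torch

hidden_size, num_attention_heads, num_kv_heads = 4544, 71, 1
head_dim = hidden_size // num_attention_heads  # 64

# Fused QKV weight: [(num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size]
fused = torch.empty((head_dim * (num_attention_heads + 2 * num_kv_heads), hidden_size))

q, k, v = torch.split(
    fused,
    [
        head_dim * num_attention_heads,  # 4544 rows for Q
        head_dim * num_kv_heads,  # 64 rows for K
        head_dim * num_kv_heads,  # 64 rows for V
    ],
    0,  # split along the output (first) dimension
)
assert q.shape == (4544, 4544)
assert k.shape == (64, 4544) and v.shape == (64, 4544)
```

The second dimension stays equal to `hidden_size` throughout, which is why only the first dimension is split.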