From fa0e8a58cface4f563faaa0f9e1001781f89a09d Mon Sep 17 00:00:00 2001
From: Clemens Akens
Date: Thu, 19 Oct 2023 11:49:26 +0200
Subject: [PATCH] Support converting GQA HF models like TinyLlama-1.1B

See: https://github.com/karpathy/llama2.c/issues/431
---
 convert_hf_model.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/convert_hf_model.py b/convert_hf_model.py
index e45240e..12fc0d8 100644
--- a/convert_hf_model.py
+++ b/convert_hf_model.py
@@ -88,6 +88,25 @@ def unpermute(tensor):
             .reshape(embedding_size, embedding_size)
         )
 
+    def unpermute_attention_key_matrices(tensor):
+        if n_attention_heads == n_attention_query_groups:
+            return unpermute(tensor)
+        else:
+            key_value_size = (
+                embedding_size // n_attention_heads * n_attention_query_groups
+            )
+
+            return (
+                tensor.view(
+                    n_attention_query_groups,
+                    2,
+                    key_value_size // n_attention_query_groups // 2,
+                    embedding_size,
+                )
+                .transpose(1, 2)
+                .reshape(key_value_size, embedding_size)
+            )
+
     # attention_query_matrices
     for layer in range(n_layers):
         serialize_f32(
@@ -99,7 +118,9 @@ def unpermute(tensor):
     for layer in range(n_layers):
         serialize_f32(
             output_file,
-            unpermute(hf_state_dict[f"model.layers.{layer}.self_attn.k_proj.weight"]),
+            unpermute_attention_key_matrices(
+                hf_state_dict[f"model.layers.{layer}.self_attn.k_proj.weight"]
+            ),
         )
 
     # attention_value_matrices
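
As a sanity check for reviewers (not part of the patch): a minimal round-trip sketch
showing that the new unpermute inverts the rotary permutation that HF-style conversion
applies to k_proj under GQA. The dimensions below are assumed TinyLlama-1.1B-like
values, and the permute helper is a hypothetical stand-in mirroring that permutation,
not a quote from any particular script.

    import torch

    # Assumed TinyLlama-1.1B-like dimensions (illustration only).
    embedding_size = 2048
    n_attention_heads = 32
    n_attention_query_groups = 4  # num_key_value_heads in the HF config
    key_value_size = embedding_size // n_attention_heads * n_attention_query_groups  # 256

    def permute(tensor, n_heads, rows):
        # Hypothetical helper: HF-style rotary permutation of a projection
        # matrix, regrouping interleaved (real, imag) pairs into two
        # half-blocks per head.
        return (
            tensor.view(n_heads, rows // n_heads // 2, 2, embedding_size)
            .transpose(1, 2)
            .reshape(rows, embedding_size)
        )

    def unpermute_attention_key_matrices(tensor):
        # The patched GQA branch: uses the query-group count because
        # k_proj has only key_value_size rows, not embedding_size.
        return (
            tensor.view(
                n_attention_query_groups,
                2,
                key_value_size // n_attention_query_groups // 2,
                embedding_size,
            )
            .transpose(1, 2)
            .reshape(key_value_size, embedding_size)
        )

    original = torch.randn(key_value_size, embedding_size)
    hf_style = permute(original, n_attention_query_groups, key_value_size)
    assert torch.equal(unpermute_attention_key_matrices(hf_style), original)

When n_attention_query_groups equals n_attention_heads, key_value_size reduces to
embedding_size and the reshape matches the existing unpermute, which is why the
non-GQA branch simply delegates to it.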