Support converting GQA HF models like TinyLlama-1.1B
clebert committed Oct 19, 2023
1 parent c747bf5 commit fa0e8a5
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion convert_hf_model.py
@@ -88,6 +88,25 @@ def unpermute(tensor):
         .reshape(embedding_size, embedding_size)
     )
 
+def unpermute_attention_key_matrices(tensor):
+    if n_attention_heads == n_attention_query_groups:
+        return unpermute(tensor)
+    else:
+        key_value_size = (
+            embedding_size // n_attention_heads * n_attention_query_groups
+        )
+
+        return (
+            tensor.view(
+                n_attention_query_groups,
+                2,
+                key_value_size // n_attention_query_groups // 2,
+                embedding_size,
+            )
+            .transpose(1, 2)
+            .reshape(key_value_size, embedding_size)
+        )
+
 # attention_query_matrices
 for layer in range(n_layers):
     serialize_f32(
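
Aside (not part of the commit): in grouped-query attention, k_proj.weight has only key_value_size = embedding_size // n_attention_heads * n_attention_query_groups rows instead of embedding_size, so the square reshape in unpermute() above no longer applies. The sketch below is a minimal round-trip check of the new reshape logic, assuming TinyLlama-1.1B-like dimensions (2048-dim embeddings, 32 attention heads, 4 query groups); the permute() helper mirrors the rotary-embedding interleaving that Hugging Face's Llama conversion script applies.

import torch

embedding_size = 2048          # assumption: TinyLlama-1.1B hidden size
n_attention_heads = 32         # assumption: TinyLlama-1.1B query heads
n_attention_query_groups = 4   # assumption: TinyLlama-1.1B key/value groups

key_value_size = embedding_size // n_attention_heads * n_attention_query_groups  # 256

def permute(tensor):
    # Interleaves the two rotary halves of each head, as the HF
    # conversion script does when producing k_proj.weight.
    return (
        tensor.view(
            n_attention_query_groups,
            key_value_size // n_attention_query_groups // 2,
            2,
            embedding_size,
        )
        .transpose(1, 2)
        .reshape(key_value_size, embedding_size)
    )

def unpermute_attention_key_matrices(tensor):
    # Same reshape logic as the GQA branch added in this commit.
    return (
        tensor.view(
            n_attention_query_groups,
            2,
            key_value_size // n_attention_query_groups // 2,
            embedding_size,
        )
        .transpose(1, 2)
        .reshape(key_value_size, embedding_size)
    )

original = torch.randn(key_value_size, embedding_size)
assert torch.equal(unpermute_attention_key_matrices(permute(original)), original)

With these numbers, key_value_size // n_attention_query_groups // 2 is 32, i.e. half of the 64-dim head size, matching the two rotary halves being separated and re-interleaved.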
@@ -99,7 +118,9 @@ def unpermute(tensor):
 for layer in range(n_layers):
     serialize_f32(
         output_file,
-        unpermute(hf_state_dict[f"model.layers.{layer}.self_attn.k_proj.weight"]),
+        unpermute_attention_key_matrices(
+            hf_state_dict[f"model.layers.{layer}.self_attn.k_proj.weight"]
+        ),
     )
 
 # attention_value_matrices
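
The query path is untouched because q_proj.weight stays square (embedding_size x embedding_size); only the key matrices need the GQA-aware reshape, and plain multi-head checkpoints fall through to the old unpermute() since n_attention_heads == n_attention_query_groups for them. A hedged usage sketch for inspecting a checkpoint before conversion (the model id is illustrative, not taken from this repository):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v0.3")
hf_state_dict = model.state_dict()

q = hf_state_dict["model.layers.0.self_attn.q_proj.weight"]
k = hf_state_dict["model.layers.0.self_attn.k_proj.weight"]
print(q.shape)  # torch.Size([2048, 2048]) -- square, plain unpermute() works
print(k.shape)  # torch.Size([256, 2048])  -- non-square, needs the new path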