diff --git a/convert_hf_model.py b/convert_hf_model.py
index 12fc0d8..2980935 100644
--- a/convert_hf_model.py
+++ b/convert_hf_model.py
@@ -12,27 +12,37 @@ def serialize_f32(file, tensor):
     file.write(struct.pack(f"{len(tensor_f32)}f", *tensor_f32))
 
 
+# https://github.com/huggingface/transformers/blob/5c081e29930466ecf9a478727039d980131076d9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122C28-L122C35
+def unpermute(tensor, n_heads, dim1, dim2):
+    return (
+        tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
+        .transpose(1, 2)
+        .reshape(dim1, dim2)
+    )
+
+
 def write_checkpoint_file():
-    hf_model = AutoModelForCausalLM.from_pretrained(args.input_model_path)
+    model = AutoModelForCausalLM.from_pretrained(args.input_model_path)
 
-    if hf_model.config.model_type != "llama":
+    if model.config.model_type != "llama":
         parser.error("Expected llama model")
 
-    if hf_model.config.rope_theta != 10000:
+    if model.config.rope_theta != 10000:
         parser.error("Expected a RoPE frequency base of 10000")
 
-    hf_state_dict = hf_model.state_dict()
-    token_embedding_vectors = hf_state_dict["model.embed_tokens.weight"]
-    output_matrix = hf_state_dict[f"lm_head.weight"]
+    state = model.state_dict()
+    embedding_weights = state["model.embed_tokens.weight"]
+    output_norm_weight = state["model.norm.weight"]
+    output_weight = state[f"lm_head.weight"]
 
-    embedding_size = hf_model.config.hidden_size
-    ffn_hidden_size = hf_model.config.intermediate_size
-    n_layers = hf_model.config.num_hidden_layers
-    n_attention_heads = hf_model.config.num_attention_heads
-    n_attention_query_groups = hf_model.config.num_key_value_heads
-    vocab_size = hf_model.config.vocab_size
-    max_sequence_length = hf_model.config.max_position_embeddings
-    shared_output_matrix = torch.equal(token_embedding_vectors, output_matrix)
+    embedding_size = model.config.hidden_size
+    ffn_hidden_size = model.config.intermediate_size
+    n_layers = model.config.num_hidden_layers
+    n_attention_heads = model.config.num_attention_heads
+    n_attention_query_groups = model.config.num_key_value_heads
+    vocab_size = model.config.vocab_size
+    max_sequence_length = model.config.max_position_embeddings
+    shared_output_weight = torch.equal(embedding_weights, output_weight)
 
     os.makedirs(args.output_model_path, exist_ok=True)
 
@@ -54,107 +64,86 @@ def write_checkpoint_file():
         )
     )
 
-    output_file.write(struct.pack("B", int(shared_output_matrix)))
+    output_file.write(struct.pack("B", int(shared_output_weight)))
     output_file.write(b"\0" * (256 - output_file.tell()))
 
-    # attention_norm_vectors
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.input_layernorm.weight"]
-        )
+        attention_norm_weight = state[f"model.layers.{layer}.input_layernorm.weight"]
+
+        serialize_f32(output_file, attention_norm_weight)
 
-    # ffn_norm_vectors
     for layer in range(n_layers):
-        serialize_f32(
-            output_file,
-            hf_state_dict[f"model.layers.{layer}.post_attention_layernorm.weight"],
-        )
+        ffn_norm_weight = state[f"model.layers.{layer}.post_attention_layernorm.weight"]
 
-    # output_norm_vector
-    serialize_f32(output_file, hf_state_dict["model.norm.weight"])
+        serialize_f32(output_file, ffn_norm_weight)
 
-    serialize_f32(output_file, token_embedding_vectors)
+    serialize_f32(output_file, output_norm_weight)
+    serialize_f32(output_file, embedding_weights)
 
-    # https://github.com/huggingface/transformers/blob/5c081e29930466ecf9a478727039d980131076d9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122C28-L122C35
-    def unpermute(tensor):
-        return (
-            tensor.view(
+    for layer in range(n_layers):
+        attention_query_weight = state[f"model.layers.{layer}.self_attn.q_proj.weight"]
+
+        serialize_f32(
+            output_file,
+            unpermute(
+                attention_query_weight,
                 n_attention_heads,
-                2,
-                embedding_size // n_attention_heads // 2,
                 embedding_size,
-            )
-            .transpose(1, 2)
-            .reshape(embedding_size, embedding_size)
+                embedding_size,
+            ),
        )
 
-    def unpermute_attention_key_matrices(tensor):
+    for layer in range(n_layers):
+        attention_key_weight = state[f"model.layers.{layer}.self_attn.k_proj.weight"]
+
         if n_attention_heads == n_attention_query_groups:
-            return unpermute(tensor)
-        else:
-            key_value_size = (
-                embedding_size // n_attention_heads * n_attention_query_groups
+            serialize_f32(
+                output_file,
+                unpermute(
+                    attention_key_weight,
+                    n_attention_heads,
+                    embedding_size,
+                    embedding_size,
+                ),
             )
-
-            return (
-                tensor.view(
+        else:
+            serialize_f32(
+                output_file,
+                unpermute(
+                    attention_key_weight,
                     n_attention_query_groups,
-                    2,
-                    key_value_size // n_attention_query_groups // 2,
+                    embedding_size // n_attention_heads * n_attention_query_groups,
                     embedding_size,
-                )
-                .transpose(1, 2)
-                .reshape(key_value_size, embedding_size)
+                ),
             )
 
-    # attention_query_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file,
-            unpermute(hf_state_dict[f"model.layers.{layer}.self_attn.q_proj.weight"]),
-        )
+        attention_value_weight = state[f"model.layers.{layer}.self_attn.v_proj.weight"]
 
-    # attention_key_matrices
-    for layer in range(n_layers):
-        serialize_f32(
-            output_file,
-            unpermute_attention_key_matrices(
-                hf_state_dict[f"model.layers.{layer}.self_attn.k_proj.weight"]
-            ),
-        )
+        serialize_f32(output_file, attention_value_weight)
 
-    # attention_value_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.self_attn.v_proj.weight"]
-        )
+        attention_output_weight = state[f"model.layers.{layer}.self_attn.o_proj.weight"]
 
-    # attention_output_matrices
-    for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.self_attn.o_proj.weight"]
-        )
+        serialize_f32(output_file, attention_output_weight)
 
-    # ffn_gate_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.mlp.gate_proj.weight"]
-        )
+        ffn_gate_weight = state[f"model.layers.{layer}.mlp.gate_proj.weight"]
+
+        serialize_f32(output_file, ffn_gate_weight)
 
-    # ffn_down_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.mlp.down_proj.weight"]
-        )
+        ffn_down_weight = state[f"model.layers.{layer}.mlp.down_proj.weight"]
+
+        serialize_f32(output_file, ffn_down_weight)
 
-    # ffn_up_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.mlp.up_proj.weight"]
-        )
+        ffn_up_weight = state[f"model.layers.{layer}.mlp.up_proj.weight"]
+
+        serialize_f32(output_file, ffn_up_weight)
 
-    if not shared_output_matrix:
-        serialize_f32(output_file, output_matrix)
+    if not shared_output_weight:
+        serialize_f32(output_file, output_weight)
 
     output_file.close()
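Note on the hoisted unpermute: the Hugging Face conversion script permutes the rows of q_proj and k_proj to suit its rotate_half RoPE layout, so this converter undoes that permutation before serializing. For grouped-query attention checkpoints (n_attention_query_groups < n_attention_heads), k_proj has fewer rows than q_proj, which is why the k_proj branch passes the smaller first dimension. The following standalone sketch checks that unpermute inverts the HF-side permute; permute mirrors the helper in the conversion script linked above, and the dimensions are illustrative rather than taken from a real model:

# Standalone round-trip check (not part of the patch).
import torch


def permute(tensor, n_heads, dim1, dim2):
    # Mirrors the permutation applied by convert_llama_weights_to_hf.py.
    return (
        tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )


def unpermute(tensor, n_heads, dim1, dim2):
    # Same as the converter's unpermute above.
    return (
        tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )


n_heads, dim1, dim2 = 4, 16, 8
weight = torch.arange(dim1 * dim2, dtype=torch.float32).reshape(dim1, dim2)
permuted = permute(weight, n_heads, dim1, dim2)

assert not torch.equal(permuted, weight)  # permute really reorders rows
assert torch.equal(unpermute(permuted, n_heads, dim1, dim2), weight)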
diff --git a/src/checkpoint.zig b/src/checkpoint.zig
index b7e6993..b296146 100644
--- a/src/checkpoint.zig
+++ b/src/checkpoint.zig
@@ -12,7 +12,7 @@
 n_attention_query_groups: usize,
 vocab_size: usize,
 max_sequence_length: usize,
-token_embedding_weights: []const Vector,
+embedding_weights: []const Vector,
 attention_norm_weights: []const Vector,
 attention_query_weights: []const Matrix,
 attention_key_weights: []const Matrix,
@@ -52,7 +52,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
     const n_attention_query_groups: usize = @intCast(try file.reader().readIntLittle(i32));
     const vocab_size: usize = @intCast(try file.reader().readIntLittle(i32));
     const max_sequence_length: usize = @intCast(try file.reader().readIntLittle(i32));
-    const shared_output_matrix = try file.reader().readIntLittle(u8) == 1;
+    const shared_output_weight = try file.reader().readIntLittle(u8) == 1;
 
     try file.seekTo(256);
 
@@ -72,7 +72,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
     const output_norm_weight = try Vector.readLeaky(allocator, file, embedding_size);
 
-    const token_embedding_weights = try Vector.readMultipleLeaky(
+    const embedding_weights = try Vector.readMultipleLeaky(
         allocator,
         file,
         vocab_size,
@@ -137,8 +137,8 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         embedding_size,
     );
 
-    const output_weight = if (shared_output_matrix)
-        Matrix{ .rows = token_embedding_weights }
+    const output_weight = if (shared_output_weight)
+        Matrix{ .rows = embedding_weights }
     else
         try Matrix.readLeaky(allocator, file, vocab_size, embedding_size);
 
@@ -151,7 +151,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         .vocab_size = vocab_size,
         .max_sequence_length = max_sequence_length,
 
-        .token_embedding_weights = token_embedding_weights,
+        .embedding_weights = embedding_weights,
         .attention_norm_weights = attention_norm_weights,
         .attention_query_weights = attention_query_weights,
         .attention_key_weights = attention_key_weights,
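The reader-side rename is mechanical, but it is worth restating the contract between the two sides: the converter packs the config integers and the shared_output_weight flag byte, zero-pads to offset 256, and readLeaky consumes exactly that before seeking past the header. A minimal Python read-back sketch of the header follows; the fields ahead of n_attention_query_groups are elided from the hunks above, so treating the layout as seven little-endian i32 values in the converter's packing order, followed by the flag byte, is an assumption:

# Header read-back sketch (not part of the patch). Assumes seven
# little-endian i32 config fields followed by the u8 flag that
# readLeaky() checks before seeking to offset 256.
import struct


def read_header(path):
    with open(path, "rb") as file:
        header = file.read(256)

    (
        embedding_size,
        ffn_hidden_size,
        n_layers,
        n_attention_heads,
        n_attention_query_groups,
        vocab_size,
        max_sequence_length,
    ) = struct.unpack_from("<7i", header, 0)

    # Mirrors `try file.reader().readIntLittle(u8) == 1` in readLeaky().
    shared_output_weight = header[28] == 1

    return {
        "embedding_size": embedding_size,
        "ffn_hidden_size": ffn_hidden_size,
        "n_layers": n_layers,
        "n_attention_heads": n_attention_heads,
        "n_attention_query_groups": n_attention_query_groups,
        "vocab_size": vocab_size,
        "max_sequence_length": max_sequence_length,
        "shared_output_weight": shared_output_weight,
    }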
diff --git a/src/transformer.zig b/src/transformer.zig
index 2d2d64b..280bd3b 100644
--- a/src/transformer.zig
+++ b/src/transformer.zig
@@ -36,9 +36,9 @@ pub fn createLeaky(
 }
 
 pub fn forward(self: Self, token: usize, position: usize) !void {
-    const token_embedding_weight = self.checkpoint.token_embedding_weights[token];
+    const embedding_weight = self.checkpoint.embedding_weights[token];
 
-    @memcpy(self.hidden.values, token_embedding_weight.values);
+    @memcpy(self.hidden.values, embedding_weight.values);
 
     for (0..self.checkpoint.n_layers) |layer| {
         const attention_norm_weight = self.checkpoint.attention_norm_weights[layer];
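The transformer.zig hunk is the tail end of the rename: forward() still seeds the hidden state with the embedding row of the current token before the per-layer loop runs. As a Python stand-in (a sketch only; `checkpoint` is a hypothetical container mirroring the Zig struct fields):

# Stand-in sketch for the forward() hunk above (not part of the patch).
import numpy as np


def forward(checkpoint, hidden, token):
    # Equivalent of @memcpy(self.hidden.values, embedding_weight.values):
    # the hidden state starts as the embedding row for the current token.
    np.copyto(hidden, checkpoint.embedding_weights[token])

    for layer in range(checkpoint.n_layers):
        attention_norm_weight = checkpoint.attention_norm_weights[layer]
        ...  # per-layer attention and FFN, elided as in the diff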