Commit 286c9dd
Minor refactoring
clebert committed Oct 22, 2023
1 parent af6c25c commit 286c9dd
Showing 3 changed files with 82 additions and 93 deletions.
159 changes: 74 additions & 85 deletions convert_hf_model.py
@@ -12,27 +12,37 @@ def serialize_f32(file, tensor):
     file.write(struct.pack(f"{len(tensor_f32)}f", *tensor_f32))
 
 
+# https://github.com/huggingface/transformers/blob/5c081e29930466ecf9a478727039d980131076d9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122C28-L122C35
+def unpermute(tensor, n_heads, dim1, dim2):
+    return (
+        tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
+        .transpose(1, 2)
+        .reshape(dim1, dim2)
+    )
+
+
 def write_checkpoint_file():
-    hf_model = AutoModelForCausalLM.from_pretrained(args.input_model_path)
+    model = AutoModelForCausalLM.from_pretrained(args.input_model_path)
 
-    if hf_model.config.model_type != "llama":
+    if model.config.model_type != "llama":
         parser.error("Expected llama model")
 
-    if hf_model.config.rope_theta != 10000:
+    if model.config.rope_theta != 10000:
         parser.error("Expected a RoPE frequency base of 10000")
 
-    hf_state_dict = hf_model.state_dict()
-    token_embedding_vectors = hf_state_dict["model.embed_tokens.weight"]
-    output_matrix = hf_state_dict[f"lm_head.weight"]
+    state = model.state_dict()
+    embedding_weights = state["model.embed_tokens.weight"]
+    output_norm_weight = state["model.norm.weight"]
+    output_weight = state[f"lm_head.weight"]
 
-    embedding_size = hf_model.config.hidden_size
-    ffn_hidden_size = hf_model.config.intermediate_size
-    n_layers = hf_model.config.num_hidden_layers
-    n_attention_heads = hf_model.config.num_attention_heads
-    n_attention_query_groups = hf_model.config.num_key_value_heads
-    vocab_size = hf_model.config.vocab_size
-    max_sequence_length = hf_model.config.max_position_embeddings
-    shared_output_matrix = torch.equal(token_embedding_vectors, output_matrix)
+    embedding_size = model.config.hidden_size
+    ffn_hidden_size = model.config.intermediate_size
+    n_layers = model.config.num_hidden_layers
+    n_attention_heads = model.config.num_attention_heads
+    n_attention_query_groups = model.config.num_key_value_heads
+    vocab_size = model.config.vocab_size
+    max_sequence_length = model.config.max_position_embeddings
+    shared_output_weight = torch.equal(embedding_weights, output_weight)
 
     os.makedirs(args.output_model_path, exist_ok=True)
 
@@ -54,107 +64,86 @@ def write_checkpoint_file():
         )
     )
 
-    output_file.write(struct.pack("B", int(shared_output_matrix)))
+    output_file.write(struct.pack("B", int(shared_output_weight)))
     output_file.write(b"\0" * (256 - output_file.tell()))
 
-    # attention_norm_vectors
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.input_layernorm.weight"]
-        )
+        attention_norm_weight = state[f"model.layers.{layer}.input_layernorm.weight"]
+
+        serialize_f32(output_file, attention_norm_weight)
 
-    # ffn_norm_vectors
     for layer in range(n_layers):
-        serialize_f32(
-            output_file,
-            hf_state_dict[f"model.layers.{layer}.post_attention_layernorm.weight"],
-        )
+        ffn_norm_weight = state[f"model.layers.{layer}.post_attention_layernorm.weight"]
 
-    # output_norm_vector
-    serialize_f32(output_file, hf_state_dict["model.norm.weight"])
+        serialize_f32(output_file, ffn_norm_weight)
 
-    serialize_f32(output_file, token_embedding_vectors)
+    serialize_f32(output_file, output_norm_weight)
+    serialize_f32(output_file, embedding_weights)
 
-    # https://github.com/huggingface/transformers/blob/5c081e29930466ecf9a478727039d980131076d9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122C28-L122C35
-    def unpermute(tensor):
-        return (
-            tensor.view(
+    for layer in range(n_layers):
+        attention_query_weight = state[f"model.layers.{layer}.self_attn.q_proj.weight"]
+
+        serialize_f32(
+            output_file,
+            unpermute(
+                attention_query_weight,
                 n_attention_heads,
-                2,
-                embedding_size // n_attention_heads // 2,
                 embedding_size,
-            )
-            .transpose(1, 2)
-            .reshape(embedding_size, embedding_size)
+                embedding_size,
+            ),
         )
 
-    def unpermute_attention_key_matrices(tensor):
+    for layer in range(n_layers):
+        attention_key_weight = state[f"model.layers.{layer}.self_attn.k_proj.weight"]
+
         if n_attention_heads == n_attention_query_groups:
-            return unpermute(tensor)
-        else:
-            key_value_size = (
-                embedding_size // n_attention_heads * n_attention_query_groups
+            serialize_f32(
+                output_file,
+                unpermute(
+                    attention_key_weight,
+                    n_attention_heads,
+                    embedding_size,
+                    embedding_size,
+                ),
             )
-
-            return (
-                tensor.view(
+        else:
+            serialize_f32(
+                output_file,
+                unpermute(
+                    attention_key_weight,
                     n_attention_query_groups,
-                    2,
-                    key_value_size // n_attention_query_groups // 2,
+                    embedding_size // n_attention_heads * n_attention_query_groups,
                     embedding_size,
-                )
-                .transpose(1, 2)
-                .reshape(key_value_size, embedding_size)
+                ),
             )
 
-    # attention_query_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file,
-            unpermute(hf_state_dict[f"model.layers.{layer}.self_attn.q_proj.weight"]),
-        )
+        attention_value_weight = state[f"model.layers.{layer}.self_attn.v_proj.weight"]
 
-    # attention_key_matrices
-    for layer in range(n_layers):
-        serialize_f32(
-            output_file,
-            unpermute_attention_key_matrices(
-                hf_state_dict[f"model.layers.{layer}.self_attn.k_proj.weight"]
-            ),
-        )
+        serialize_f32(output_file, attention_value_weight)
 
-    # attention_value_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.self_attn.v_proj.weight"]
-        )
+        attention_output_weight = state[f"model.layers.{layer}.self_attn.o_proj.weight"]
 
-    # attention_output_matrices
-    for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.self_attn.o_proj.weight"]
-        )
+        serialize_f32(output_file, attention_output_weight)
 
-    # ffn_gate_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.mlp.gate_proj.weight"]
-        )
+        ffn_gate_weight = state[f"model.layers.{layer}.mlp.gate_proj.weight"]
+
+        serialize_f32(output_file, ffn_gate_weight)
 
-    # ffn_down_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.mlp.down_proj.weight"]
-        )
+        ffn_down_weight = state[f"model.layers.{layer}.mlp.down_proj.weight"]
+
+        serialize_f32(output_file, ffn_down_weight)
 
-    # ffn_up_matrices
     for layer in range(n_layers):
-        serialize_f32(
-            output_file, hf_state_dict[f"model.layers.{layer}.mlp.up_proj.weight"]
-        )
+        ffn_up_weight = state[f"model.layers.{layer}.mlp.up_proj.weight"]
+
+        serialize_f32(output_file, ffn_up_weight)
 
-    if not shared_output_matrix:
-        serialize_f32(output_file, output_matrix)
+    if not shared_output_weight:
+        serialize_f32(output_file, output_weight)
 
     output_file.close()
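
Aside, not part of the commit: the refactoring folds the two nested helpers into the single module-level unpermute, which now receives the head count and both matrix dimensions explicitly, so one helper serves both the plain query projection and the grouped-query key projection. The sketch below is a minimal round-trip check under two assumptions: the permute is reconstructed from the Hugging Face conversion script linked in the diff and may not match its exact signature there, and the dimensions are toy values rather than real model sizes.

import torch


# Copied from the new top-level helper in convert_hf_model.py above.
def unpermute(tensor, n_heads, dim1, dim2):
    return (
        tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )


# Assumed inverse, modeled on the permute in the linked conversion script.
def permute(tensor, n_heads, dim1, dim2):
    return (
        tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )


embedding_size, n_heads, n_groups = 16, 4, 2  # toy values
kv_size = embedding_size // n_heads * n_groups

# Query projection: square matrix, one group per head.
w_q = torch.randn(embedding_size, embedding_size)
assert torch.equal(
    unpermute(
        permute(w_q, n_heads, embedding_size, embedding_size),
        n_heads,
        embedding_size,
        embedding_size,
    ),
    w_q,
)

# Grouped-query key projection: dim1 shrinks to kv_size and n_heads becomes
# n_groups, mirroring the else branch of the key loop in the diff above.
w_k = torch.randn(kv_size, embedding_size)
assert torch.equal(
    unpermute(
        permute(w_k, n_groups, kv_size, embedding_size),
        n_groups,
        kv_size,
        embedding_size,
    ),
    w_k,
)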

12 changes: 6 additions & 6 deletions src/checkpoint.zig
@@ -12,7 +12,7 @@ n_attention_query_groups: usize,
 vocab_size: usize,
 max_sequence_length: usize,
 
-token_embedding_weights: []const Vector,
+embedding_weights: []const Vector,
 attention_norm_weights: []const Vector,
 attention_query_weights: []const Matrix,
 attention_key_weights: []const Matrix,
@@ -52,7 +52,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
     const n_attention_query_groups: usize = @intCast(try file.reader().readIntLittle(i32));
     const vocab_size: usize = @intCast(try file.reader().readIntLittle(i32));
     const max_sequence_length: usize = @intCast(try file.reader().readIntLittle(i32));
-    const shared_output_matrix = try file.reader().readIntLittle(u8) == 1;
+    const shared_output_weight = try file.reader().readIntLittle(u8) == 1;
 
     try file.seekTo(256);
 
@@ -72,7 +72,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
 
     const output_norm_weight = try Vector.readLeaky(allocator, file, embedding_size);
 
-    const token_embedding_weights = try Vector.readMultipleLeaky(
+    const embedding_weights = try Vector.readMultipleLeaky(
         allocator,
         file,
         vocab_size,
@@ -137,8 +137,8 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         embedding_size,
     );
 
-    const output_weight = if (shared_output_matrix)
-        Matrix{ .rows = token_embedding_weights }
+    const output_weight = if (shared_output_weight)
+        Matrix{ .rows = embedding_weights }
     else
         try Matrix.readLeaky(allocator, file, vocab_size, embedding_size);
 
@@ -151,7 +151,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         .vocab_size = vocab_size,
         .max_sequence_length = max_sequence_length,
 
-        .token_embedding_weights = token_embedding_weights,
+        .embedding_weights = embedding_weights,
         .attention_norm_weights = attention_norm_weights,
         .attention_query_weights = attention_query_weights,
         .attention_key_weights = attention_key_weights,
4 changes: 2 additions & 2 deletions src/transformer.zig
@@ -36,9 +36,9 @@ pub fn createLeaky(
 }
 
 pub fn forward(self: Self, token: usize, position: usize) !void {
-    const token_embedding_weight = self.checkpoint.token_embedding_weights[token];
+    const embedding_weight = self.checkpoint.embedding_weights[token];
 
-    @memcpy(self.hidden.values, token_embedding_weight.values);
+    @memcpy(self.hidden.values, embedding_weight.values);
 
     for (0..self.checkpoint.n_layers) |layer| {
         const attention_norm_weight = self.checkpoint.attention_norm_weights[layer];
