Add support for V1 checkpoint file format
clebert committed Oct 18, 2023
1 parent 5c2cd57 commit 5aee6ce
Showing 1 changed file with 169 additions and 0 deletions.
169 changes: 169 additions & 0 deletions src/checkpoint.zig
@@ -30,6 +30,19 @@ weights: struct {
},

pub fn init(allocator: std.mem.Allocator, model_path: []const u8) !Self {
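// Prefer the V1 checkpoint file; fall back to the legacy format below if it is absent.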
const v1_path = try std.fs.path.join(
allocator,
&[_][]const u8{ model_path, "checkpoint_v1.bin" },
);

defer allocator.free(v1_path);

const v1_file = std.fs.cwd().openFile(v1_path, .{}) catch null;

defer if (v1_file) |file| file.close();

if (v1_file) |file| return try readV1(allocator, file);

const legacy_path = try std.fs.path.join(
allocator,
&[_][]const u8{ model_path, "checkpoint_legacy.bin" },
@@ -46,6 +59,162 @@ pub fn init(allocator: std.mem.Allocator, model_path: []const u8) !Self {
return error.CheckpointFileNotFound;
}

// https://github.com/karpathy/llama2.c/blob/d9862069e7ef665fe6309e3c17398ded2f121bf5/export.py#L132
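// Reads the V1 export format: a 256-byte little-endian header followed by the weight tensors.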
pub fn readV1(allocator: std.mem.Allocator, file: std.fs.File) !Self {
const magic = try file.reader().readIntLittle(u32);

// 0x616b3432 is "ak42" packed as a little-endian u32 by llama2.c's export.py.
if (magic != 0x616b3432) {
return error.InvalidMagic;
}

const version = try file.reader().readIntLittle(i32);

if (version != 1) {
return error.InvalidVersion;
}

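// Model configuration: seven little-endian i32 fields followed by a u8 flag.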
const embedding_size: usize = @intCast(try file.reader().readIntLittle(i32));
const ffn_hidden_size: usize = @intCast(try file.reader().readIntLittle(i32));
const n_layers: usize = @intCast(try file.reader().readIntLittle(i32));
const n_attention_heads: usize = @intCast(try file.reader().readIntLittle(i32));
const n_attention_query_groups: usize = @intCast(try file.reader().readIntLittle(i32));
const vocab_size: usize = @intCast(try file.reader().readIntLittle(i32));
const max_sequence_length: usize = @intCast(try file.reader().readIntLittle(i32));
const shared_output_matrix = try file.reader().readIntLittle(u8) == 1;

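// The rest of the 256-byte header is padding; the weight data starts right after it.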
try file.seekTo(256);

const attention_norm_vectors = try Tensor(2).init(
allocator,
[_]usize{ n_layers, embedding_size },
);

errdefer attention_norm_vectors.deinit();
try attention_norm_vectors.read(file);

const ffn_norm_vectors = try Tensor(2).init(
allocator,
[_]usize{ n_layers, embedding_size },
);

errdefer ffn_norm_vectors.deinit();
try ffn_norm_vectors.read(file);

const output_norm_vector = try Tensor(1).init(
allocator,
[_]usize{embedding_size},
);

errdefer output_norm_vector.deinit();
try output_norm_vector.read(file);

const token_embedding_vectors = try Tensor(2).init(
allocator,
[_]usize{ vocab_size, embedding_size },
);

errdefer token_embedding_vectors.deinit();
try token_embedding_vectors.read(file);

const attention_query_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, embedding_size, embedding_size },
);

errdefer attention_query_matrices.deinit();
try attention_query_matrices.read(file);

const attention_head_size: usize = embedding_size / n_attention_heads;

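// With grouped-query attention, the key and value projections have
// n_attention_query_groups * attention_head_size rows rather than embedding_size.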
const attention_key_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, n_attention_query_groups * attention_head_size, embedding_size },
);

errdefer attention_key_matrices.deinit();
try attention_key_matrices.read(file);

const attention_value_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, n_attention_query_groups * attention_head_size, embedding_size },
);

errdefer attention_value_matrices.deinit();
try attention_value_matrices.read(file);

const attention_output_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, embedding_size, embedding_size },
);

errdefer attention_output_matrices.deinit();
try attention_output_matrices.read(file);

const ffn_gate_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, ffn_hidden_size, embedding_size },
);

errdefer ffn_gate_matrices.deinit();
try ffn_gate_matrices.read(file);

const ffn_down_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, embedding_size, ffn_hidden_size },
);

errdefer ffn_down_matrices.deinit();
try ffn_down_matrices.read(file);

const ffn_up_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, ffn_hidden_size, embedding_size },
);

errdefer ffn_up_matrices.deinit();
try ffn_up_matrices.read(file);

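// When the output classifier is shared, reuse the token embedding tensor instead of reading a separate matrix.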
const output_matrix = if (shared_output_matrix)
token_embedding_vectors
else
try Tensor(2).init(allocator, [_]usize{ vocab_size, embedding_size });

errdefer if (!shared_output_matrix) {
output_matrix.deinit();
};

if (!shared_output_matrix) {
try output_matrix.read(file);
}

return Self{
.allocator = allocator,
.embedding_size = embedding_size,
.ffn_hidden_size = ffn_hidden_size,
.n_layers = n_layers,
.n_attention_heads = n_attention_heads,
.n_attention_query_groups = n_attention_query_groups,
.vocab_size = vocab_size,
.max_sequence_length = max_sequence_length,
.shared_output_matrix = shared_output_matrix,

.weights = .{
.token_embedding_vectors = token_embedding_vectors,
.attention_norm_vectors = attention_norm_vectors,
.attention_query_matrices = attention_query_matrices,
.attention_key_matrices = attention_key_matrices,
.attention_value_matrices = attention_value_matrices,
.attention_output_matrices = attention_output_matrices,
.ffn_norm_vectors = ffn_norm_vectors,
.ffn_gate_matrices = ffn_gate_matrices,
.ffn_down_matrices = ffn_down_matrices,
.ffn_up_matrices = ffn_up_matrices,
.output_norm_vector = output_norm_vector,
.output_matrix = output_matrix,
},
};
}

pub fn readLegacy(allocator: std.mem.Allocator, file: std.fs.File) !Self {
const embedding_size: usize = @intCast(try file.reader().readIntLittle(i32));
const ffn_hidden_size: usize = @intCast(try file.reader().readIntLittle(i32));