Skip to content

Commit

Permalink
Add converter executable
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 18, 2023
1 parent fba127b commit 0112f2e
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 7 deletions.
16 changes: 16 additions & 0 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,47 @@ pub fn build(b: *std.Build) void {
.optimize = optimize,
});

const converter_exe = b.addExecutable(.{
.name = "llama2-converter",
.root_source_file = .{ .path = "src/converter_main.zig" },
.target = target,
.optimize = optimize,
});

const build_options = b.addOptions();

chat_exe.addOptions("build_options", build_options);
generator_exe.addOptions("build_options", build_options);
converter_exe.addOptions("build_options", build_options);

// This declares intent for the executable to be installed into the
// standard location when the user invokes the "install" step (the default
// step when running `zig build`).
b.installArtifact(chat_exe);
b.installArtifact(generator_exe);
b.installArtifact(converter_exe);

// This *creates* a Run step in the build graph, to be executed when another
// step is evaluated that depends on it. The next line below will establish
// such a dependency.
const run_chat_cmd = b.addRunArtifact(chat_exe);
const run_generator_cmd = b.addRunArtifact(generator_exe);
const run_converter_cmd = b.addRunArtifact(converter_exe);

// By making the run step depend on the install step, it will be run from the
// installation directory rather than directly from within the cache directory.
// This is not necessary, however, if the application depends on other installed
// files, this ensures they will be present and in the expected location.
run_chat_cmd.step.dependOn(b.getInstallStep());
run_generator_cmd.step.dependOn(b.getInstallStep());
run_converter_cmd.step.dependOn(b.getInstallStep());

// This allows the user to pass arguments to the application in the build
// command itself, like this: `zig build run -- arg1 arg2 etc`
if (b.args) |args| {
run_chat_cmd.addArgs(args);
run_generator_cmd.addArgs(args);
run_converter_cmd.addArgs(args);
}

// This creates a build step. It will be visible in the `zig build --help` menu,
Expand All @@ -60,6 +72,10 @@ pub fn build(b: *std.Build) void {

run_generator_step.dependOn(&run_generator_cmd.step);

const run_converter_step = b.step("run-converter", "Run the converter");

run_converter_step.dependOn(&run_converter_cmd.step);

const test_step = b.step("test", "Run unit tests");

const generator_tests = b.addTest(.{
Expand Down
Binary file added models/tinystories_15m/checkpoint_v1.bin
Binary file not shown.
Binary file added models/tinystories_260k/checkpoint_v1.bin
Binary file not shown.
8 changes: 5 additions & 3 deletions src/chat_main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ const Chat = @import("chat.zig");
const ChatArgs = @import("chat_args.zig");

/// Chat entry point: parses command-line arguments, builds the chat
/// session, and runs it. One page allocator is shared by all three steps.
pub fn main() !void {
    const allocator = std.heap.page_allocator;

    var chat_args = try ChatArgs.init(allocator);
    defer chat_args.deinit();

    var chat = try Chat.init(allocator, chat_args);
    defer chat.deinit();

    try chat.start(allocator);
}
45 changes: 43 additions & 2 deletions src/checkpoint.zig
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,48 @@ pub fn init(allocator: std.mem.Allocator, model_path: []const u8) !Self {
}

// https://github.com/karpathy/llama2.c/blob/d9862069e7ef665fe6309e3c17398ded2f121bf5/export.py#L132
pub fn readV1(allocator: std.mem.Allocator, file: std.fs.File) !Self {
// https://github.com/karpathy/llama2.c/blob/d9862069e7ef665fe6309e3c17398ded2f121bf5/export.py#L132
/// Serializes this checkpoint as `<model_path>/checkpoint_v1.bin` in the
/// llama2.c "v1" export format: a 256-byte header followed by the raw
/// f32 weight tensors.
pub fn writeV1(self: *const Self, allocator: std.mem.Allocator, model_path: []const u8) !void {
    const path = try std.fs.path.join(
        allocator,
        &[_][]const u8{ model_path, "checkpoint_v1.bin" },
    );
    defer allocator.free(path);

    const file = try std.fs.cwd().createFile(path, .{ .truncate = true });
    defer file.close();

    const writer = file.writer();

    // Header: magic, format version, then the hyperparameters as
    // little-endian i32 and the shared-matrix flag as a single byte.
    try writer.writeAll("ak42");
    try writer.writeIntLittle(i32, 1);
    try writer.writeIntLittle(i32, @as(i32, @intCast(self.embedding_size)));
    try writer.writeIntLittle(i32, @as(i32, @intCast(self.ffn_hidden_size)));
    try writer.writeIntLittle(i32, @as(i32, @intCast(self.n_layers)));
    try writer.writeIntLittle(i32, @as(i32, @intCast(self.n_attention_heads)));
    try writer.writeIntLittle(i32, @as(i32, @intCast(self.n_attention_query_groups)));
    try writer.writeIntLittle(i32, @as(i32, @intCast(self.vocab_size)));
    try writer.writeIntLittle(i32, @as(i32, @intCast(self.max_sequence_length)));
    try writer.writeIntLittle(u8, @as(u8, @intFromBool(self.shared_output_matrix)));

    // Zero-pad the header out to exactly 256 bytes.
    try writer.writeByteNTimes(0, 256 - try file.getPos());

    // Weight tensors, in the fixed order of the v1 format.
    try self.weights.attention_norm_vectors.write(file);
    try self.weights.ffn_norm_vectors.write(file);
    try self.weights.output_norm_vector.write(file);
    try self.weights.token_embedding_vectors.write(file);
    try self.weights.attention_query_matrices.write(file);
    try self.weights.attention_key_matrices.write(file);
    try self.weights.attention_value_matrices.write(file);
    try self.weights.attention_output_matrices.write(file);
    try self.weights.ffn_gate_matrices.write(file);
    try self.weights.ffn_down_matrices.write(file);
    try self.weights.ffn_up_matrices.write(file);

    // The output matrix is only stored when it is not shared with the
    // token embedding.
    if (!self.shared_output_matrix) {
        try self.weights.output_matrix.write(file);
    }
}

// https://github.com/karpathy/llama2.c/blob/d9862069e7ef665fe6309e3c17398ded2f121bf5/export.py#L132
fn readV1(allocator: std.mem.Allocator, file: std.fs.File) !Self {
const magic: [*]const u8 = @ptrCast(&try file.reader().readIntLittle(u32));

if (!std.mem.eql(u8, magic[0..4], "ak42")) {
Expand Down Expand Up @@ -216,7 +257,7 @@ pub fn readV1(allocator: std.mem.Allocator, file: std.fs.File) !Self {
}

// https://github.com/karpathy/llama2.c/blob/d9862069e7ef665fe6309e3c17398ded2f121bf5/export.py#L75
pub fn readLegacy(allocator: std.mem.Allocator, file: std.fs.File) !Self {
fn readLegacy(allocator: std.mem.Allocator, file: std.fs.File) !Self {
const embedding_size: usize = @intCast(try file.reader().readIntLittle(i32));
const ffn_hidden_size: usize = @intCast(try file.reader().readIntLittle(i32));
const n_layers: usize = @intCast(try file.reader().readIntLittle(i32));
Expand Down
42 changes: 42 additions & 0 deletions src/converter_args.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//! Command-line arguments for the llama2 converter executable.

const Self = @This();

const std = @import("std");

arg_iterator: std.process.ArgIterator,
model_path: []const u8,

/// Parses the process arguments. Expects exactly one positional argument,
/// the model path; `--help` prints usage and exits with status 0, and any
/// other extra argument prints usage and exits with status 1. The returned
/// struct owns the argument iterator; call `deinit` when done.
pub fn init(allocator: std.mem.Allocator) !Self {
    var arg_iterator = try std.process.argsWithAllocator(allocator);

    errdefer arg_iterator.deinit();

    // Skip the executable name.
    _ = arg_iterator.next().?;

    const model_path = arg_iterator.next() orelse try help(1);

    // Anything beyond the model path is either --help or an error.
    while (arg_iterator.next()) |arg| {
        try help(if (std.mem.eql(u8, arg, "--help")) 0 else 1);
    }

    return Self{ .arg_iterator = arg_iterator, .model_path = model_path };
}

pub fn deinit(self: *Self) void {
    self.arg_iterator.deinit();
}

/// Prints usage to stdout (exit status 0) or stderr (non-zero) and
/// terminates the process with `exit_status`.
fn help(exit_status: u8) !noreturn {
    const console = if (exit_status == 0)
        std.io.getStdOut().writer()
    else
        std.io.getStdErr().writer();

    try console.print("Usage: llama2-converter <model_path> [options]\n\n", .{});

    try console.print("Options:\n", .{});
    try console.print(" --help\n", .{});

    std.process.exit(exit_status);
}
17 changes: 17 additions & 0 deletions src/converter_main.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
const std = @import("std");
const Checkpoint = @import("checkpoint.zig");
const ConverterArgs = @import("converter_args.zig");

/// Converter entry point: loads the checkpoint found at the model path
/// given on the command line and re-exports it in the v1 format.
pub fn main() !void {
    const allocator = std.heap.page_allocator;

    var converter_args = try ConverterArgs.init(allocator);
    defer converter_args.deinit();

    const checkpoint = try Checkpoint.init(allocator, converter_args.model_path);
    defer checkpoint.deinit();

    try checkpoint.writeV1(allocator, converter_args.model_path);
}
6 changes: 4 additions & 2 deletions src/generator_main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ const Generator = @import("generator.zig");
const GeneratorArgs = @import("generator_args.zig");

pub fn main() !void {
var args = try GeneratorArgs.init(std.heap.page_allocator);
const allocator = std.heap.page_allocator;

var args = try GeneratorArgs.init(allocator);

defer args.deinit();

var generator = try Generator.init(std.heap.page_allocator, args);
var generator = try Generator.init(allocator, args);

defer generator.deinit();

Expand Down
6 changes: 6 additions & 0 deletions src/tensor.zig
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
try file.reader().readNoEof(buffer[0 .. self.data.len * @sizeOf(f32)]);
}

/// Writes the tensor's `data` to `file` as raw bytes (native f32 layout),
/// the counterpart of `read` above.
pub fn write(self: *const Self, file: std.fs.File) !void {
    // std.mem.sliceAsBytes yields a correctly-sized byte view of the f32
    // slice, replacing the manual [*]u8 cast and length arithmetic.
    try file.writer().writeAll(std.mem.sliceAsBytes(self.data));
}

pub fn slice(self: *const Self, index: usize) Tensor(n_dims - 1) {
comptime if (n_dims < 2) @compileError("n_dims < 2");

Expand Down

0 comments on commit 0112f2e

Please sign in to comment.