From e7542cd0d2d852c0c922a1057c7a104812837b50 Mon Sep 17 00:00:00 2001
From: Clemens Akens
Date: Sat, 26 Aug 2023 13:18:17 +0200
Subject: [PATCH] Introduce generator

---
 src/generator.zig   | 133 ++++++++++++++++++++++++++++++++++++++++++++
 src/main.zig        | 115 ++------------------------------------
 src/sampler.zig     |   2 +-
 src/transformer.zig |   2 +-
 4 files changed, 141 insertions(+), 111 deletions(-)
 create mode 100644 src/generator.zig

diff --git a/src/generator.zig b/src/generator.zig
new file mode 100644
index 0000000..9299f0a
--- /dev/null
+++ b/src/generator.zig
@@ -0,0 +1,133 @@
+const Self = @This();
+
+const std = @import("std");
+const lib = @import("lib.zig");
+const Checkpoint = @import("checkpoint.zig");
+const Cli = @import("cli.zig");
+const Sampler = @import("sampler.zig");
+const Tokenizer = @import("tokenizer.zig");
+const Transformer = @import("transformer.zig");
+
+allocator: std.mem.Allocator,
+transformer: Transformer,
+tokenizer: Tokenizer,
+sampler: Sampler,
+prompt_tokens: []usize,
+n_steps: usize,
+timer: bool,
+
+pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
+    const transformer = try Transformer.init(allocator, cli);
+
+    errdefer transformer.deinit();
+
+    const vocab_size = transformer.checkpoint.vocab_size;
+    const tokenizer = try Tokenizer.init(allocator, cli.tokenizer_path, vocab_size);
+
+    errdefer tokenizer.deinit();
+
+    const sampler = try Sampler.init(allocator, cli, vocab_size);
+
+    errdefer sampler.deinit();
+
+    return Self{
+        .allocator = allocator,
+        .transformer = transformer,
+        .tokenizer = tokenizer,
+        .sampler = sampler,
+        .prompt_tokens = try tokenizer.encode(allocator, cli.prompt, true, false),
+        .n_steps = cli.n_steps,
+        .timer = cli.timer,
+    };
+}
+
+pub fn deinit(self: *const Self) void {
+    self.transformer.deinit();
+    self.tokenizer.deinit();
+    self.sampler.deinit();
+    self.allocator.free(self.prompt_tokens);
+}
+
+pub fn generate(self: *Self, writer: anytype) !void {
+    std.debug.assert(self.prompt_tokens.len > 0);
+
+    var prompt_tokens_offset: usize = 0;
+    var current_token: usize = self.prompt_tokens[prompt_tokens_offset];
+
+    prompt_tokens_offset += 1;
+
+    var start_time: i64 = 0;
+    var total_time: i64 = 0;
+    var next_token: usize = 1;
+    var n_steps: usize = 0;
+
+    for (0..self.n_steps) |pos| {
+        if (pos > 0) {
+            start_time = std.time.milliTimestamp();
+        }
+
+        try self.transformer.forward(current_token, pos);
+
+        if (start_time > 0) {
+            total_time += std.time.milliTimestamp() - start_time;
+        }
+
+        if (prompt_tokens_offset < self.prompt_tokens.len) {
+            next_token = self.prompt_tokens[prompt_tokens_offset];
+            prompt_tokens_offset += 1;
+        } else {
+            next_token = self.sampler.sample(self.transformer.logits);
+        }
+
+        n_steps += 1;
+
+        if (next_token == 1) {
+            break; // the BOS (=1) token delimits sequences
+        }
+
+        const word = self.tokenizer.decode(current_token, next_token);
+
+        try lib.print(word, writer);
+
+        current_token = next_token;
+    }
+
+    if (total_time > 0 and self.timer) {
+        const average_time = @as(f32, @floatFromInt(total_time)) / @as(f32, @floatFromInt(n_steps));
+
+        try writer.print("\n\nachieved: {d:.3} tok/s\n", .{@as(f32, 1000 / average_time)});
+    } else {
+        try writer.print("\n", .{});
+    }
+}
+
+test "generate tiny story" {
+    var output = std.ArrayList(u8).init(std.testing.allocator);
+
+    defer output.deinit();
+
+    var arg_iterator = try std.process.argsWithAllocator(std.testing.allocator);
+
+    defer arg_iterator.deinit();
+
+    const cli = Cli{
+        .checkpoint_path = "stories260K.bin",
+        .temperature = 1,
+        .top_p = 0.9,
+        .random_seed = 42,
+        .n_steps = 10,
+        .prompt = "There was",
+        .tokenizer_path = "tok512.bin",
+        .mmap = false,
+        .timer = false,
+        .arg_iterator = arg_iterator,
+    };
+
+    var generator = try Self.init(std.testing.allocator, &cli);
+
+    defer generator.deinit();
+
+    try generator.generate(output.writer());
+
+    try std.testing.expectEqualStrings("There was a good room\n", output.items);
+}
diff --git a/src/main.zig b/src/main.zig
index d6b0dbc..96d7cdf 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -1,124 +1,21 @@
 const std = @import("std");
-const lib = @import("lib.zig");
-const Checkpoint = @import("checkpoint.zig");
 const Cli = @import("cli.zig");
-const Sampler = @import("sampler.zig");
-const Tokenizer = @import("tokenizer.zig");
-const Transformer = @import("transformer.zig");
+const Generator = @import("generator.zig");
 
 pub fn main() !void {
-    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
-
-    defer arena.deinit();
-
-    var cli = try Cli.init(arena.allocator());
-
-    defer cli.deinit();
-
     const stdout = std.io.getStdOut().writer();
 
-    try generate(arena.allocator(), cli, stdout);
-}
-
-fn generate(allocator: std.mem.Allocator, cli: Cli, writer: anytype) !void {
-    const transformer = try Transformer.init(allocator, cli);
-
-    defer transformer.deinit();
-
-    const vocab_size = transformer.checkpoint.vocab_size;
-
-    var sampler = try Sampler.init(allocator, cli, vocab_size);
-
-    defer sampler.deinit();
-
-    const tokenizer = try Tokenizer.init(allocator, cli.tokenizer_path, vocab_size);
-
-    defer tokenizer.deinit();
-
-    const prompt_tokens = try tokenizer.encode(allocator, cli.prompt, true, false);
-
-    defer allocator.free(prompt_tokens);
-
-    var prompt_tokens_offset: usize = 0;
-
-    std.debug.assert(prompt_tokens.len > 0);
-
-    var current_token: usize = prompt_tokens[prompt_tokens_offset];
+    var cli = try Cli.init(std.heap.page_allocator);
 
-    prompt_tokens_offset += 1;
-
-    var start_time: i64 = 0;
-    var total_time: i64 = 0;
-    var next_token: usize = 1;
-    var n_steps: usize = 0;
-
-    for (0..cli.n_steps) |pos| {
-        if (pos > 0) {
-            start_time = std.time.milliTimestamp();
-        }
-
-        try transformer.forward(current_token, pos);
-
-        if (start_time > 0) {
-            total_time += std.time.milliTimestamp() - start_time;
-        }
-
-        if (prompt_tokens_offset < prompt_tokens.len) {
-            next_token = prompt_tokens[prompt_tokens_offset];
-            prompt_tokens_offset += 1;
-        } else {
-            next_token = sampler.sample(transformer.logits);
-        }
-
-        n_steps += 1;
-
-        if (next_token == 1) {
-            break; // the BOS (=1) token delimits sequences
-        }
-
-        const word = tokenizer.decode(current_token, next_token);
-
-        try lib.print(word, writer);
+    defer cli.deinit();
 
-        current_token = next_token;
-    }
+    var generator = try Generator.init(std.heap.page_allocator, &cli);
 
-    if (total_time > 0 and cli.timer) {
-        const average_time = @as(f32, @floatFromInt(total_time)) / @as(f32, @floatFromInt(n_steps));
+    defer generator.deinit();
 
-        try writer.print("\n\nachieved: {d:.3} tok/s\n", .{@as(f32, 1000 / average_time)});
-    } else {
-        try writer.print("\n", .{});
-    }
+    try generator.generate(stdout);
 }
 
 test {
     std.testing.refAllDecls(@This());
 }
-
-test "generate tiny story" {
-    var output = std.ArrayList(u8).init(std.testing.allocator);
-
-    defer output.deinit();
-
-    var arg_iterator = try std.process.argsWithAllocator(std.testing.allocator);
-
-    defer arg_iterator.deinit();
-
-    const cli = Cli{
-        .checkpoint_path = "stories260K.bin",
-        .temperature = 1,
-        .top_p = 0.9,
-        .random_seed = 42,
-        .n_steps = 10,
-        .prompt = "There was",
-        .tokenizer_path = "tok512.bin",
-        .mmap = false,
-        .timer = false,
-        .arg_iterator = arg_iterator,
-    };
-
-    try generate(std.testing.allocator, cli, output.writer());
-
-    try std.testing.expectEqualStrings("There was a good room\n", output.items);
-}
diff --git a/src/sampler.zig b/src/sampler.zig
index 5bac60b..5ceea5d 100644
--- a/src/sampler.zig
+++ b/src/sampler.zig
@@ -11,7 +11,7 @@ temperature: f32,
 top_p: f32,
 rng_state: u64,
 
-pub fn init(allocator: std.mem.Allocator, cli: Cli, vocab_size: usize) !Self {
+pub fn init(allocator: std.mem.Allocator, cli: *const Cli, vocab_size: usize) !Self {
     return Self{
         .allocator = allocator,
         .probability_index_pairs_buffer = try allocator.alloc(lib.ProbabilityIndexPair, vocab_size),
diff --git a/src/transformer.zig b/src/transformer.zig
index 88a35d5..5757f83 100644
--- a/src/transformer.zig
+++ b/src/transformer.zig
@@ -14,7 +14,7 @@ feed_forward: FeedForward,
 hidden_state: []f32,
 logits: []f32,
 
-pub fn init(allocator: std.mem.Allocator, cli: Cli) !Self {
+pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
     const checkpoint = try Checkpoint.init(if (cli.mmap) null else allocator, cli.checkpoint_path);
 
     return Self{