Skip to content

Commit

Permalink
Introduce generator
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Aug 26, 2023
1 parent f00461e commit e7542cd
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 111 deletions.
133 changes: 133 additions & 0 deletions src/generator.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
const Self = @This();

const std = @import("std");
const lib = @import("lib.zig");
const Checkpoint = @import("checkpoint.zig");
const Cli = @import("cli.zig");
const Sampler = @import("sampler.zig");
const Tokenizer = @import("tokenizer.zig");
const Transformer = @import("transformer.zig");

// Allocator that owns prompt_tokens; used again in deinit to free it.
allocator: std.mem.Allocator,
// Model state; forward() fills transformer.logits for sampling.
transformer: Transformer,
// Encodes the prompt and decodes generated token pairs into words.
tokenizer: Tokenizer,
// Picks the next token from the transformer's logits.
sampler: Sampler,
// Encoded prompt; owned by this struct (freed in deinit).
prompt_tokens: []usize,
// Upper bound on generation steps (loop bound in generate).
n_steps: usize,
// When true, generate prints an achieved-tok/s report at the end.
timer: bool,

/// Builds a ready-to-run generator from command-line options.
/// On any failure, everything acquired so far is released via errdefer.
pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
    const transformer = try Transformer.init(allocator, cli);
    errdefer transformer.deinit();

    const tokenizer = try Tokenizer.init(
        allocator,
        cli.tokenizer_path,
        transformer.checkpoint.vocab_size,
    );
    errdefer tokenizer.deinit();

    const sampler = try Sampler.init(allocator, cli, transformer.checkpoint.vocab_size);
    errdefer sampler.deinit();

    // Tokenize the prompt up front; the resulting slice is owned by Self.
    // NOTE(review): the two bool flags presumably toggle BOS/EOS markers —
    // confirm against Tokenizer.encode.
    const prompt_tokens = try tokenizer.encode(allocator, cli.prompt, true, false);

    return .{
        .allocator = allocator,
        .transformer = transformer,
        .tokenizer = tokenizer,
        .sampler = sampler,
        .prompt_tokens = prompt_tokens,
        .n_steps = cli.n_steps,
        .timer = cli.timer,
    };
}

/// Releases everything init acquired, in reverse acquisition order.
pub fn deinit(self: *const Self) void {
    self.allocator.free(self.prompt_tokens);
    self.sampler.deinit();
    self.tokenizer.deinit();
    self.transformer.deinit();
}

/// Runs up to n_steps forward passes, force-feeding prompt tokens first and
/// sampling afterwards, streaming decoded words to `writer`. Generation stops
/// early at a BOS (=1) token. If `timer` is set, prints an achieved-tok/s line.
pub fn generate(self: *Self, writer: anytype) !void {
    std.debug.assert(self.prompt_tokens.len > 0);

    var prompt_tokens_offset: usize = 0;
    var current_token: usize = self.prompt_tokens[prompt_tokens_offset];

    prompt_tokens_offset += 1;

    var start_time: i64 = 0;
    var total_time: i64 = 0;
    var next_token: usize = 1;
    var n_timed_steps: usize = 0;

    for (0..self.n_steps) |pos| {
        // The first step is deliberately left untimed (warm-up).
        if (pos > 0) {
            start_time = std.time.milliTimestamp();
        }

        try self.transformer.forward(current_token, pos);

        if (start_time > 0) {
            total_time += std.time.milliTimestamp() - start_time;
            n_timed_steps += 1;
        }

        if (prompt_tokens_offset < self.prompt_tokens.len) {
            // Still consuming the prompt: feed its next token verbatim.
            next_token = self.prompt_tokens[prompt_tokens_offset];
            prompt_tokens_offset += 1;
        } else {
            next_token = self.sampler.sample(self.transformer.logits);
        }

        if (next_token == 1) {
            break; // the BOS (=1) token delimits sequences
        }

        const word = self.tokenizer.decode(current_token, next_token);

        try lib.print(word, writer);

        current_token = next_token;
    }

    if (total_time > 0 and self.timer) {
        // BUG FIX: average over the *timed* steps only. The first forward
        // pass never contributes to total_time, so dividing by the total
        // step count (as before) understated the per-step time and
        // overstated tok/s. total_time > 0 guarantees n_timed_steps >= 1.
        const average_time =
            @as(f32, @floatFromInt(total_time)) / @as(f32, @floatFromInt(n_timed_steps));

        try writer.print("\n\nachieved: {d:.3} tok/s\n", .{@as(f32, 1000 / average_time)});
    } else {
        try writer.print("\n", .{});
    }
}

// End-to-end smoke test: generates at most 10 tokens from the tiny story
// model and pins the exact output string. Requires "stories260K.bin" and
// "tok512.bin" to be present in the working directory at test time.
test "generate tiny story" {
    var output = std.ArrayList(u8).init(std.testing.allocator);

    defer output.deinit();

    // NOTE(review): arg_iterator appears needed only to populate the Cli
    // field; all other Cli fields are set explicitly below — confirm that
    // nothing reads the iterator during generation.
    var arg_iterator = try std.process.argsWithAllocator(std.testing.allocator);

    defer arg_iterator.deinit();

    // Fixed seed + greedy-ish sampling settings make the output deterministic.
    const cli = Cli{
        .checkpoint_path = "stories260K.bin",
        .temperature = 1,
        .top_p = 0.9,
        .random_seed = 42,
        .n_steps = 10,
        .prompt = "There was",
        .tokenizer_path = "tok512.bin",
        .mmap = false,
        .timer = false,
        .arg_iterator = arg_iterator,
    };

    var generator = try Self.init(std.testing.allocator, &cli);

    defer generator.deinit();

    try generator.generate(output.writer());

    try std.testing.expectEqualStrings("There was a good room\n", output.items);
}
115 changes: 6 additions & 109 deletions src/main.zig
Original file line number Diff line number Diff line change
@@ -1,124 +1,21 @@
const std = @import("std");
const lib = @import("lib.zig");
const Checkpoint = @import("checkpoint.zig");
const Cli = @import("cli.zig");
const Sampler = @import("sampler.zig");
const Tokenizer = @import("tokenizer.zig");
const Transformer = @import("transformer.zig");
const Generator = @import("generator.zig");

pub fn main() !void {
var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);

defer arena.deinit();

var cli = try Cli.init(arena.allocator());

defer cli.deinit();

const stdout = std.io.getStdOut().writer();

try generate(arena.allocator(), cli, stdout);
}

fn generate(allocator: std.mem.Allocator, cli: Cli, writer: anytype) !void {
const transformer = try Transformer.init(allocator, cli);

defer transformer.deinit();

const vocab_size = transformer.checkpoint.vocab_size;

var sampler = try Sampler.init(allocator, cli, vocab_size);

defer sampler.deinit();

const tokenizer = try Tokenizer.init(allocator, cli.tokenizer_path, vocab_size);

defer tokenizer.deinit();

const prompt_tokens = try tokenizer.encode(allocator, cli.prompt, true, false);

defer allocator.free(prompt_tokens);

var prompt_tokens_offset: usize = 0;

std.debug.assert(prompt_tokens.len > 0);

var current_token: usize = prompt_tokens[prompt_tokens_offset];
var cli = try Cli.init(std.heap.page_allocator);

prompt_tokens_offset += 1;

var start_time: i64 = 0;
var total_time: i64 = 0;
var next_token: usize = 1;
var n_steps: usize = 0;

for (0..cli.n_steps) |pos| {
if (pos > 0) {
start_time = std.time.milliTimestamp();
}

try transformer.forward(current_token, pos);

if (start_time > 0) {
total_time += std.time.milliTimestamp() - start_time;
}

if (prompt_tokens_offset < prompt_tokens.len) {
next_token = prompt_tokens[prompt_tokens_offset];
prompt_tokens_offset += 1;
} else {
next_token = sampler.sample(transformer.logits);
}

n_steps += 1;

if (next_token == 1) {
break; // the BOS (=1) token delimits sequences
}

const word = tokenizer.decode(current_token, next_token);

try lib.print(word, writer);
defer cli.deinit();

current_token = next_token;
}
var generator = try Generator.init(std.heap.page_allocator, &cli);

if (total_time > 0 and cli.timer) {
const average_time = @as(f32, @floatFromInt(total_time)) / @as(f32, @floatFromInt(n_steps));
defer generator.deinit();

try writer.print("\n\nachieved: {d:.3} tok/s\n", .{@as(f32, 1000 / average_time)});
} else {
try writer.print("\n", .{});
}
try generator.generate(stdout);
}

test {
std.testing.refAllDecls(@This());
}

test "generate tiny story" {
var output = std.ArrayList(u8).init(std.testing.allocator);

defer output.deinit();

var arg_iterator = try std.process.argsWithAllocator(std.testing.allocator);

defer arg_iterator.deinit();

const cli = Cli{
.checkpoint_path = "stories260K.bin",
.temperature = 1,
.top_p = 0.9,
.random_seed = 42,
.n_steps = 10,
.prompt = "There was",
.tokenizer_path = "tok512.bin",
.mmap = false,
.timer = false,
.arg_iterator = arg_iterator,
};

try generate(std.testing.allocator, cli, output.writer());

try std.testing.expectEqualStrings("There was a good room\n", output.items);
}
2 changes: 1 addition & 1 deletion src/sampler.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ temperature: f32,
top_p: f32,
rng_state: u64,

pub fn init(allocator: std.mem.Allocator, cli: Cli, vocab_size: usize) !Self {
pub fn init(allocator: std.mem.Allocator, cli: *const Cli, vocab_size: usize) !Self {
return Self{
.allocator = allocator,
.probability_index_pairs_buffer = try allocator.alloc(lib.ProbabilityIndexPair, vocab_size),
Expand Down
2 changes: 1 addition & 1 deletion src/transformer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ feed_forward: FeedForward,
hidden_state: []f32,
logits: []f32,

pub fn init(allocator: std.mem.Allocator, cli: Cli) !Self {
pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
const checkpoint = try Checkpoint.init(if (cli.mmap) null else allocator, cli.checkpoint_path);

return Self{
Expand Down

0 comments on commit e7542cd

Please sign in to comment.