Skip to content

Commit

Permalink
Init checkpoint inside transformer
Browse files (browse the repository at this point in the history)
  • Loading branch information
clebert committed Aug 26, 2023
1 parent 3860885 commit f00461e
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/attention.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ const lib = @import("lib.zig");
const Checkpoint = @import("checkpoint.zig");

allocator: std.mem.Allocator,
checkpoint: *const Checkpoint,
checkpoint: Checkpoint,
seq_len: usize,
input_buffer: []f32,
output_buffer: []f32,
Expand All @@ -16,7 +16,7 @@ values_buffer: []f32,
key_cache: []f32,
value_cache: []f32,

pub fn init(allocator: std.mem.Allocator, checkpoint: *const Checkpoint, seq_len: usize) !Self {
pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, seq_len: usize) !Self {
const dim = checkpoint.dim;
const kv_dim = checkpoint.kv_dim;
const kv_cache_dim = checkpoint.n_layers * seq_len * kv_dim;
Expand Down
4 changes: 2 additions & 2 deletions src/feed_forward.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ const lib = @import("lib.zig");
const Checkpoint = @import("checkpoint.zig");

allocator: std.mem.Allocator,
checkpoint: *const Checkpoint,
checkpoint: Checkpoint,
input_buffer: []f32,
hidden_buffer: []f32,
residual_buffer: []f32,
output_buffer: []f32,

pub fn init(allocator: std.mem.Allocator, checkpoint: *const Checkpoint) !Self {
pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint) !Self {
const dim = checkpoint.dim;
const hidden_dim = checkpoint.hidden_dim;

Expand Down
20 changes: 9 additions & 11 deletions src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,28 @@ pub fn main() !void {

const stdout = std.io.getStdOut().writer();

try generate(arena.allocator(), &cli, stdout);
try generate(arena.allocator(), cli, stdout);
}

fn generate(allocator: std.mem.Allocator, cli: *const Cli, writer: anytype) !void {
const checkpoint = try Checkpoint.init(if (cli.mmap) null else allocator, cli.checkpoint_path);
fn generate(allocator: std.mem.Allocator, cli: Cli, writer: anytype) !void {
const transformer = try Transformer.init(allocator, cli);

defer checkpoint.deinit();
defer transformer.deinit();

const vocab_size = transformer.checkpoint.vocab_size;

var sampler = try Sampler.init(allocator, cli, checkpoint.vocab_size);
var sampler = try Sampler.init(allocator, cli, vocab_size);

defer sampler.deinit();

const tokenizer = try Tokenizer.init(allocator, cli.tokenizer_path, checkpoint.vocab_size);
const tokenizer = try Tokenizer.init(allocator, cli.tokenizer_path, vocab_size);

defer tokenizer.deinit();

const prompt_tokens = try tokenizer.encode(allocator, cli.prompt, true, false);

defer allocator.free(prompt_tokens);

const transformer = try Transformer.init(allocator, &checkpoint, cli.n_steps);

defer transformer.deinit();

var prompt_tokens_offset: usize = 0;

std.debug.assert(prompt_tokens.len > 0);
Expand Down Expand Up @@ -120,7 +118,7 @@ test "generate tiny story" {
.arg_iterator = arg_iterator,
};

try generate(std.testing.allocator, &cli, output.writer());
try generate(std.testing.allocator, cli, output.writer());

try std.testing.expectEqualStrings("There was a good room\n", output.items);
}
2 changes: 1 addition & 1 deletion src/sampler.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ temperature: f32,
top_p: f32,
rng_state: u64,

pub fn init(allocator: std.mem.Allocator, cli: *const Cli, vocab_size: usize) !Self {
pub fn init(allocator: std.mem.Allocator, cli: Cli, vocab_size: usize) !Self {
return Self{
.allocator = allocator,
.probability_index_pairs_buffer = try allocator.alloc(lib.ProbabilityIndexPair, vocab_size),
Expand Down
20 changes: 12 additions & 8 deletions src/transformer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,35 @@ const std = @import("std");
const lib = @import("lib.zig");
const Attention = @import("attention.zig");
const Checkpoint = @import("checkpoint.zig");
const Cli = @import("cli.zig");
const FeedForward = @import("feed_forward.zig");

allocator: std.mem.Allocator,
checkpoint: *const Checkpoint,
hidden_state: []f32,
logits: []f32,
checkpoint: Checkpoint,
attention: Attention,
feed_forward: FeedForward,
hidden_state: []f32,
logits: []f32,

pub fn init(allocator: std.mem.Allocator, cli: Cli) !Self {
const checkpoint = try Checkpoint.init(if (cli.mmap) null else allocator, cli.checkpoint_path);

pub fn init(allocator: std.mem.Allocator, checkpoint: *const Checkpoint, seq_len: usize) !Self {
return Self{
.allocator = allocator,
.checkpoint = checkpoint,
.attention = try Attention.init(allocator, checkpoint, cli.n_steps),
.feed_forward = try FeedForward.init(allocator, checkpoint),
.hidden_state = try allocator.alloc(f32, checkpoint.dim),
.logits = try allocator.alloc(f32, checkpoint.vocab_size),
.attention = try Attention.init(allocator, checkpoint, seq_len),
.feed_forward = try FeedForward.init(allocator, checkpoint),
};
}

pub fn deinit(self: *const Self) void {
self.allocator.free(self.hidden_state);
self.allocator.free(self.logits);
self.checkpoint.deinit();
self.attention.deinit();
self.feed_forward.deinit();
self.allocator.free(self.hidden_state);
self.allocator.free(self.logits);
}

pub fn forward(self: *const Self, token: usize, pos: usize) !void {
Expand Down

0 comments on commit f00461e

Please sign in to comment.