Skip to content

Commit

Permalink
Better names
Browse files: browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 16, 2023
1 parent 53c3c6b commit 5c3137a
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 127 deletions.
24 changes: 16 additions & 8 deletions src/attention.zig
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,34 @@ value_cache: Tensor(4),
scores: []f32,

pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, sequence_length: usize) !Self {
const head_size: usize = checkpoint.embedding_size / checkpoint.n_heads;
const input_buffer = try Tensor(2).init(allocator, [_]usize{ checkpoint.n_heads, head_size });
const embedding_size = checkpoint.embedding_size;
const n_attention_heads = checkpoint.n_attention_heads;
const head_size: usize = embedding_size / n_attention_heads;
const input_buffer = try Tensor(2).init(allocator, [_]usize{ n_attention_heads, head_size });

errdefer input_buffer.deinit();

const output_buffer = try Tensor(1).init(allocator, [_]usize{checkpoint.embedding_size});
const output_buffer = try Tensor(1).init(allocator, [_]usize{embedding_size});

errdefer output_buffer.deinit();

const query_buffer = try Tensor(2).init(allocator, [_]usize{ checkpoint.n_heads, head_size });
const query_buffer = try Tensor(2).init(allocator, [_]usize{ n_attention_heads, head_size });

errdefer query_buffer.deinit();

const n_layers = checkpoint.n_layers;
const n_attention_query_groups = checkpoint.n_attention_query_groups;

const key_cache = try Tensor(4).init(
allocator,
[_]usize{ checkpoint.n_layers, sequence_length, checkpoint.n_query_groups, head_size },
[_]usize{ n_layers, sequence_length, n_attention_query_groups, head_size },
);

errdefer key_cache.deinit();

const value_cache = try Tensor(4).init(
allocator,
[_]usize{ checkpoint.n_layers, sequence_length, checkpoint.n_query_groups, head_size },
[_]usize{ n_layers, sequence_length, n_attention_query_groups, head_size },
);

errdefer value_cache.deinit();
Expand Down Expand Up @@ -86,7 +91,7 @@ pub fn forward(self: *const Self, layer: usize, position: usize) !void {

self.rope(position, key_buffer);

for (0..self.checkpoint.n_heads) |head| {
for (0..self.checkpoint.n_attention_heads) |head| {
self.gqa(layer, position, head);
}

Expand Down Expand Up @@ -132,7 +137,10 @@ fn gqa(self: *const Self, layer: usize, current_position: usize, head: usize) vo
@setFloatMode(.Optimized);

const query_vector = self.query_buffer.slice(head);
const query_group = head / (self.checkpoint.n_heads / self.checkpoint.n_query_groups);

const query_group =
head / (self.checkpoint.n_attention_heads / self.checkpoint.n_attention_query_groups);

const next_position = current_position + 1;

for (0..next_position) |position| {
Expand Down
6 changes: 3 additions & 3 deletions src/chat.zig
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
const Self = @This();

const std = @import("std");
const Cli = @import("cli.zig");
const CLI = @import("cli.zig");
const print = @import("print.zig").print;
const Sampler = @import("sampler.zig");
const Tokenizer = @import("tokenizer.zig");
Expand All @@ -14,7 +14,7 @@ sampler: Sampler,
user_prompt: []const u8,
system_prompt: []const u8,

pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
pub fn init(allocator: std.mem.Allocator, cli: *const CLI) !Self {
const transformer = try Transformer.init(allocator, cli);

errdefer transformer.deinit();
Expand Down Expand Up @@ -129,7 +129,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
user_prompt_tokens_index += 1;

if (next_token == 0) {
next_token = self.sampler.sample(self.transformer.logits_buffer.data);
next_token = self.sampler.sample(self.transformer.output_buffer.data);
}

if (next_token == eos_token) {
Expand Down
137 changes: 70 additions & 67 deletions src/checkpoint.zig
Original file line number Diff line number Diff line change
@@ -1,49 +1,49 @@
const Self = @This();

const std = @import("std");
const Cli = @import("./cli.zig");
const CLI = @import("./cli.zig");
const Tensor = @import("./tensor.zig").Tensor;
const vector = @import("./vector.zig");

allocator: std.mem.Allocator,
embedding_size: usize,
hidden_size: usize,
ffn_hidden_size: usize,
n_layers: usize,
n_heads: usize,
n_query_groups: usize,
n_attention_heads: usize,
n_attention_query_groups: usize,
vocab_size: usize,
max_sequence_length: usize,
shared_classifier_matrix: bool,
shared_output_matrix: bool,

weights: struct {
token_embedding_vectors: Tensor(2),
attention_pre_norm_vectors: Tensor(2),
attention_norm_vectors: Tensor(2),
attention_query_matrices: Tensor(3),
attention_key_matrices: Tensor(3),
attention_value_matrices: Tensor(3),
attention_output_matrices: Tensor(3),
feed_forward_pre_norm_vectors: Tensor(2),
feed_forward_pre_activation_matrices: Tensor(3),
feed_forward_output_matrices: Tensor(3),
feed_forward_gate_matrices: Tensor(3),
classifier_pre_norm_vector: Tensor(1),
classifier_matrix: Tensor(2),
ffn_norm_vectors: Tensor(2),
ffn_gate_matrices: Tensor(3),
ffn_down_matrices: Tensor(3),
ffn_up_matrices: Tensor(3),
output_norm_vector: Tensor(1),
output_matrix: Tensor(2),
},

pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
pub fn init(allocator: std.mem.Allocator, cli: *const CLI) !Self {
const file = try std.fs.cwd().openFile(cli.checkpoint_path, .{});

defer file.close();

const embedding_size: usize = @intCast(try file.reader().readIntLittle(i32));
const hidden_size: usize = @intCast(try file.reader().readIntLittle(i32));
const ffn_hidden_size: usize = @intCast(try file.reader().readIntLittle(i32));
const n_layers: usize = @intCast(try file.reader().readIntLittle(i32));
const n_heads: usize = @intCast(try file.reader().readIntLittle(i32));
const n_query_groups: usize = @intCast(try file.reader().readIntLittle(i32));
const n_attention_heads: usize = @intCast(try file.reader().readIntLittle(i32));
const n_attention_query_groups: usize = @intCast(try file.reader().readIntLittle(i32));

// https://github.com/karpathy/llama2.c/blob/35deb5e0fa55f0a257040bcf1624ed8386e63dc7/run.c#L153
const signed_vocab_size = try file.reader().readIntLittle(i32);
const shared_classifier_matrix = signed_vocab_size > 0;
const shared_output_matrix = signed_vocab_size > 0;

const vocab_size: usize = std.math.absCast(signed_vocab_size);
const max_sequence_length: usize = @intCast(try file.reader().readIntLittle(i32));
Expand All @@ -56,13 +56,13 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
errdefer token_embedding_vectors.deinit();
try token_embedding_vectors.read(file);

const attention_pre_norm_vectors = try Tensor(2).init(
const attention_norm_vectors = try Tensor(2).init(
allocator,
[_]usize{ n_layers, embedding_size },
);

errdefer attention_pre_norm_vectors.deinit();
try attention_pre_norm_vectors.read(file);
errdefer attention_norm_vectors.deinit();
try attention_norm_vectors.read(file);

const attention_query_matrices = try Tensor(3).init(
allocator,
Expand All @@ -72,19 +72,19 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
errdefer attention_query_matrices.deinit();
try attention_query_matrices.read(file);

const head_size: usize = embedding_size / n_heads;
const attention_head_size: usize = embedding_size / n_attention_heads;

const attention_key_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, n_query_groups * head_size, embedding_size },
[_]usize{ n_layers, n_attention_query_groups * attention_head_size, embedding_size },
);

errdefer attention_key_matrices.deinit();
try attention_key_matrices.read(file);

const attention_value_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, n_query_groups * head_size, embedding_size },
[_]usize{ n_layers, n_attention_query_groups * attention_head_size, embedding_size },
);

errdefer attention_value_matrices.deinit();
Expand All @@ -98,100 +98,103 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
errdefer attention_output_matrices.deinit();
try attention_output_matrices.read(file);

const feed_forward_pre_norm_vectors = try Tensor(2).init(
const ffn_norm_vectors = try Tensor(2).init(
allocator,
[_]usize{ n_layers, embedding_size },
);

errdefer feed_forward_pre_norm_vectors.deinit();
try feed_forward_pre_norm_vectors.read(file);
errdefer ffn_norm_vectors.deinit();
try ffn_norm_vectors.read(file);

const feed_forward_pre_activation_matrices = try Tensor(3).init(
const ffn_gate_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, hidden_size, embedding_size },
[_]usize{ n_layers, ffn_hidden_size, embedding_size },
);

errdefer feed_forward_pre_activation_matrices.deinit();
try feed_forward_pre_activation_matrices.read(file);
errdefer ffn_gate_matrices.deinit();
try ffn_gate_matrices.read(file);

const feed_forward_output_matrices = try Tensor(3).init(
const ffn_down_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, embedding_size, hidden_size },
[_]usize{ n_layers, embedding_size, ffn_hidden_size },
);

errdefer feed_forward_output_matrices.deinit();
try feed_forward_output_matrices.read(file);
errdefer ffn_down_matrices.deinit();
try ffn_down_matrices.read(file);

const feed_forward_gate_matrices = try Tensor(3).init(
const ffn_up_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, hidden_size, embedding_size },
[_]usize{ n_layers, ffn_hidden_size, embedding_size },
);

errdefer feed_forward_gate_matrices.deinit();
try feed_forward_gate_matrices.read(file);
errdefer ffn_up_matrices.deinit();
try ffn_up_matrices.read(file);

const classifier_pre_norm_vector = try Tensor(1).init(allocator, [_]usize{embedding_size});
const output_norm_vector = try Tensor(1).init(
allocator,
[_]usize{embedding_size},
);

errdefer classifier_pre_norm_vector.deinit();
try classifier_pre_norm_vector.read(file);
errdefer output_norm_vector.deinit();
try output_norm_vector.read(file);

try file.seekBy(@intCast(max_sequence_length * head_size * @sizeOf(f32)));
try file.seekBy(@intCast(max_sequence_length * attention_head_size * @sizeOf(f32)));

const classifier_matrix = if (shared_classifier_matrix)
const output_matrix = if (shared_output_matrix)
token_embedding_vectors
else
try Tensor(2).init(allocator, [_]usize{ vocab_size, embedding_size });

errdefer if (!shared_classifier_matrix) {
classifier_matrix.deinit();
errdefer if (!shared_output_matrix) {
output_matrix.deinit();
};

if (!shared_classifier_matrix) {
try classifier_matrix.read(file);
if (!shared_output_matrix) {
try output_matrix.read(file);
}

return Self{
.allocator = allocator,
.embedding_size = embedding_size,
.hidden_size = hidden_size,
.ffn_hidden_size = ffn_hidden_size,
.n_layers = n_layers,
.n_heads = n_heads,
.n_query_groups = n_query_groups,
.n_attention_heads = n_attention_heads,
.n_attention_query_groups = n_attention_query_groups,
.vocab_size = vocab_size,
.max_sequence_length = max_sequence_length,
.shared_classifier_matrix = shared_classifier_matrix,
.shared_output_matrix = shared_output_matrix,

.weights = .{
.token_embedding_vectors = token_embedding_vectors,
.attention_pre_norm_vectors = attention_pre_norm_vectors,
.attention_norm_vectors = attention_norm_vectors,
.attention_query_matrices = attention_query_matrices,
.attention_key_matrices = attention_key_matrices,
.attention_value_matrices = attention_value_matrices,
.attention_output_matrices = attention_output_matrices,
.feed_forward_pre_norm_vectors = feed_forward_pre_norm_vectors,
.feed_forward_pre_activation_matrices = feed_forward_pre_activation_matrices,
.feed_forward_output_matrices = feed_forward_output_matrices,
.feed_forward_gate_matrices = feed_forward_gate_matrices,
.classifier_pre_norm_vector = classifier_pre_norm_vector,
.classifier_matrix = classifier_matrix,
.ffn_norm_vectors = ffn_norm_vectors,
.ffn_gate_matrices = ffn_gate_matrices,
.ffn_down_matrices = ffn_down_matrices,
.ffn_up_matrices = ffn_up_matrices,
.output_norm_vector = output_norm_vector,
.output_matrix = output_matrix,
},
};
}

/// Releases the checkpoint's weight tensors by calling `deinit` on each one.
// NOTE(review): this listing is a rendered diff, so the removed (old-name) and
// added (new-name) lines appear together below.
pub fn deinit(self: *const Self) void {
self.weights.token_embedding_vectors.deinit();
self.weights.attention_pre_norm_vectors.deinit();
self.weights.attention_norm_vectors.deinit();
self.weights.attention_query_matrices.deinit();
self.weights.attention_key_matrices.deinit();
self.weights.attention_value_matrices.deinit();
self.weights.attention_output_matrices.deinit();
self.weights.feed_forward_pre_norm_vectors.deinit();
self.weights.feed_forward_pre_activation_matrices.deinit();
self.weights.feed_forward_output_matrices.deinit();
self.weights.feed_forward_gate_matrices.deinit();
self.weights.classifier_pre_norm_vector.deinit();

if (!self.shared_classifier_matrix) {
self.weights.classifier_matrix.deinit();
self.weights.ffn_norm_vectors.deinit();
self.weights.ffn_gate_matrices.deinit();
self.weights.ffn_down_matrices.deinit();
self.weights.ffn_up_matrices.deinit();
self.weights.output_norm_vector.deinit();

// When the output matrix is shared, `init` assigns `token_embedding_vectors`
// to it instead of allocating a new tensor, and that tensor is already
// released at the top of this function. Only a separately allocated output
// matrix is released here (presumably to avoid releasing the same storage twice).
if (!self.shared_output_matrix) {
self.weights.output_matrix.deinit();
}
}
Loading

0 comments on commit 5c3137a

Please sign in to comment.