Minor refactoring
clebert committed Oct 15, 2023
1 parent 59d9564 commit 3136d74
Showing 3 changed files with 62 additions and 58 deletions.
86 changes: 43 additions & 43 deletions src/checkpoint.zig
@@ -13,7 +13,7 @@ n_heads: usize,
n_query_groups: usize,
vocab_size: usize,
max_sequence_length: usize,
shared_final_classifier_matrix: bool,
shared_classifier_matrix: bool,

weights: struct {
token_embedding_vectors: Tensor(2),
@@ -22,12 +22,12 @@ weights: struct {
attention_key_matrices: Tensor(3),
attention_value_matrices: Tensor(3),
attention_output_matrices: Tensor(3),
ffn_pre_norm_vectors: Tensor(2),
ffn_pre_activation_matrices: Tensor(3),
ffn_output_matrices: Tensor(3),
ffn_gate_matrices: Tensor(3),
final_norm_vector: Tensor(1),
final_classifier_matrix: Tensor(2),
feed_forward_pre_norm_vectors: Tensor(2),
feed_forward_pre_activation_matrices: Tensor(3),
feed_forward_output_matrices: Tensor(3),
feed_forward_gate_matrices: Tensor(3),
classifier_pre_norm_vector: Tensor(1),
classifier_matrix: Tensor(2),
},

pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
@@ -43,7 +43,7 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {

// https://github.com/karpathy/llama2.c/blob/35deb5e0fa55f0a257040bcf1624ed8386e63dc7/run.c#L153
const signed_vocab_size = try file.reader().readIntLittle(i32);
const shared_final_classifier_matrix = signed_vocab_size > 0;
const shared_classifier_matrix = signed_vocab_size > 0;

const vocab_size: usize = std.math.absCast(signed_vocab_size);
const max_sequence_length: usize = @intCast(try file.reader().readIntLittle(i32));
@@ -98,56 +98,56 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
errdefer attention_output_matrices.deinit();
try attention_output_matrices.read(file);

const ffn_pre_norm_vectors = try Tensor(2).init(
const feed_forward_pre_norm_vectors = try Tensor(2).init(
allocator,
[_]usize{ n_layers, embedding_size },
);

errdefer ffn_pre_norm_vectors.deinit();
try ffn_pre_norm_vectors.read(file);
errdefer feed_forward_pre_norm_vectors.deinit();
try feed_forward_pre_norm_vectors.read(file);

const ffn_pre_activation_matrices = try Tensor(3).init(
const feed_forward_pre_activation_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, hidden_size, embedding_size },
);

errdefer ffn_pre_activation_matrices.deinit();
try ffn_pre_activation_matrices.read(file);
errdefer feed_forward_pre_activation_matrices.deinit();
try feed_forward_pre_activation_matrices.read(file);

const ffn_output_matrices = try Tensor(3).init(
const feed_forward_output_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, embedding_size, hidden_size },
);

errdefer ffn_output_matrices.deinit();
try ffn_output_matrices.read(file);
errdefer feed_forward_output_matrices.deinit();
try feed_forward_output_matrices.read(file);

const ffn_gate_matrices = try Tensor(3).init(
const feed_forward_gate_matrices = try Tensor(3).init(
allocator,
[_]usize{ n_layers, hidden_size, embedding_size },
);

errdefer ffn_gate_matrices.deinit();
try ffn_gate_matrices.read(file);
errdefer feed_forward_gate_matrices.deinit();
try feed_forward_gate_matrices.read(file);

const final_norm_vector = try Tensor(1).init(allocator, [_]usize{embedding_size});
const classifier_pre_norm_vector = try Tensor(1).init(allocator, [_]usize{embedding_size});

errdefer final_norm_vector.deinit();
try final_norm_vector.read(file);
errdefer classifier_pre_norm_vector.deinit();
try classifier_pre_norm_vector.read(file);

try file.seekBy(@intCast(max_sequence_length * head_size * @sizeOf(f32)));

const final_classifier_matrix = if (shared_final_classifier_matrix)
const classifier_matrix = if (shared_classifier_matrix)
token_embedding_vectors
else
try Tensor(2).init(allocator, [_]usize{ vocab_size, embedding_size });

errdefer if (!shared_final_classifier_matrix) {
final_classifier_matrix.deinit();
errdefer if (!shared_classifier_matrix) {
classifier_matrix.deinit();
};

if (!shared_final_classifier_matrix) {
try final_classifier_matrix.read(file);
if (!shared_classifier_matrix) {
try classifier_matrix.read(file);
}

return Self{
@@ -159,7 +159,7 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
.n_query_groups = n_query_groups,
.vocab_size = vocab_size,
.max_sequence_length = max_sequence_length,
.shared_final_classifier_matrix = shared_final_classifier_matrix,
.shared_classifier_matrix = shared_classifier_matrix,

.weights = .{
.token_embedding_vectors = token_embedding_vectors,
@@ -168,12 +168,12 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
.attention_key_matrices = attention_key_matrices,
.attention_value_matrices = attention_value_matrices,
.attention_output_matrices = attention_output_matrices,
.ffn_pre_norm_vectors = ffn_pre_norm_vectors,
.ffn_pre_activation_matrices = ffn_pre_activation_matrices,
.ffn_output_matrices = ffn_output_matrices,
.ffn_gate_matrices = ffn_gate_matrices,
.final_norm_vector = final_norm_vector,
.final_classifier_matrix = final_classifier_matrix,
.feed_forward_pre_norm_vectors = feed_forward_pre_norm_vectors,
.feed_forward_pre_activation_matrices = feed_forward_pre_activation_matrices,
.feed_forward_output_matrices = feed_forward_output_matrices,
.feed_forward_gate_matrices = feed_forward_gate_matrices,
.classifier_pre_norm_vector = classifier_pre_norm_vector,
.classifier_matrix = classifier_matrix,
},
};
}
@@ -185,13 +185,13 @@ pub fn deinit(self: *const Self) void {
self.weights.attention_key_matrices.deinit();
self.weights.attention_value_matrices.deinit();
self.weights.attention_output_matrices.deinit();
self.weights.ffn_pre_norm_vectors.deinit();
self.weights.ffn_pre_activation_matrices.deinit();
self.weights.ffn_output_matrices.deinit();
self.weights.ffn_gate_matrices.deinit();
self.weights.final_norm_vector.deinit();

if (!self.shared_final_classifier_matrix) {
self.weights.final_classifier_matrix.deinit();
self.weights.feed_forward_pre_norm_vectors.deinit();
self.weights.feed_forward_pre_activation_matrices.deinit();
self.weights.feed_forward_output_matrices.deinit();
self.weights.feed_forward_gate_matrices.deinit();
self.weights.classifier_pre_norm_vector.deinit();

if (!self.shared_classifier_matrix) {
self.weights.classifier_matrix.deinit();
}
}
6 changes: 3 additions & 3 deletions src/ffn.zig → src/feed_forward.zig
@@ -50,9 +50,9 @@ pub fn forward(self: *const Self, layer: usize) !void {
@setFloatMode(.Optimized);

const weights = self.checkpoint.weights;
const pre_activation_matrix = weights.ffn_pre_activation_matrices.slice(layer);
const gate_matrix = weights.ffn_gate_matrices.slice(layer);
const output_matrix = weights.ffn_output_matrices.slice(layer);
const pre_activation_matrix = weights.feed_forward_pre_activation_matrices.slice(layer);
const gate_matrix = weights.feed_forward_gate_matrices.slice(layer);
const output_matrix = weights.feed_forward_output_matrices.slice(layer);

pre_activation_matrix.multiplyVector(self.input_buffer, self.hidden_buffer);
gate_matrix.multiplyVector(self.input_buffer, self.gate_buffer);
28 changes: 16 additions & 12 deletions src/transformer.zig
@@ -4,15 +4,15 @@ const std = @import("std");
const Attention = @import("attention.zig");
const Checkpoint = @import("checkpoint.zig");
const Cli = @import("cli.zig");
const Ffn = @import("ffn.zig");
const FeedForward = @import("feed_forward.zig");
const Tensor = @import("./tensor.zig").Tensor;
const vector = @import("vector.zig");

allocator: std.mem.Allocator,
checkpoint: Checkpoint,
sequence_length: usize,
attention: Attention,
ffn: Ffn,
feed_forward: FeedForward,
hidden_buffer: Tensor(1),
logits_buffer: Tensor(1),

@@ -26,9 +26,9 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {

errdefer attention.deinit();

const ffn = try Ffn.init(allocator, checkpoint);
const feed_forward = try FeedForward.init(allocator, checkpoint);

errdefer ffn.deinit();
errdefer feed_forward.deinit();

const hidden_buffer = try Tensor(1).init(allocator, [_]usize{checkpoint.embedding_size});

@@ -43,7 +43,7 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
.checkpoint = checkpoint,
.sequence_length = sequence_length,
.attention = attention,
.ffn = ffn,
.feed_forward = feed_forward,
.hidden_buffer = hidden_buffer,
.logits_buffer = logits_buffer,
};
@@ -52,7 +52,7 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
pub fn deinit(self: *const Self) void {
self.checkpoint.deinit();
self.attention.deinit();
self.ffn.deinit();
self.feed_forward.deinit();
self.hidden_buffer.deinit();
self.logits_buffer.deinit();
}
@@ -64,7 +64,7 @@ pub fn forward(self: *const Self, token: usize, position: usize) !void {

for (0..self.checkpoint.n_layers) |layer| {
const attention_pre_norm_vector = weights.attention_pre_norm_vectors.slice(layer);
const ffn_pre_norm_vector = weights.ffn_pre_norm_vectors.slice(layer);
const feed_forward_pre_norm_vector = weights.feed_forward_pre_norm_vectors.slice(layer);

vector.rmsnorm(
self.hidden_buffer.data,
Expand All @@ -76,18 +76,22 @@ pub fn forward(self: *const Self, token: usize, position: usize) !void {

vector.add(self.hidden_buffer.data, self.attention.output_buffer.data);

vector.rmsnorm(self.hidden_buffer.data, ffn_pre_norm_vector.data, self.ffn.input_buffer.data);
vector.rmsnorm(
self.hidden_buffer.data,
feed_forward_pre_norm_vector.data,
self.feed_forward.input_buffer.data,
);

try self.ffn.forward(layer);
try self.feed_forward.forward(layer);

vector.add(self.hidden_buffer.data, self.ffn.output_buffer.data);
vector.add(self.hidden_buffer.data, self.feed_forward.output_buffer.data);
}

vector.rmsnorm(
self.hidden_buffer.data,
weights.final_norm_vector.data,
weights.classifier_pre_norm_vector.data,
self.hidden_buffer.data,
);

weights.final_classifier_matrix.multiplyVector(self.hidden_buffer, self.logits_buffer);
weights.classifier_matrix.multiplyVector(self.hidden_buffer, self.logits_buffer);
}
