diff --git a/src/attention.zig b/src/attention.zig
index 48c9e39..d7f22ea 100644
--- a/src/attention.zig
+++ b/src/attention.zig
@@ -3,65 +3,51 @@ const Self = @This();
 const std = @import("std");
 
 const lib = @import("lib.zig");
 const Checkpoint = @import("checkpoint.zig");
-const matrix = @import("matrix.zig");
 
 allocator: std.mem.Allocator,
 checkpoint: Checkpoint,
-seq_len: usize,
+sequence_length: usize,
 input_buffer: []f32,
 output_buffer: []f32,
 scores_buffer: []f32,
 queries_buffer: []f32,
-keys_buffer: []f32,
-values_buffer: []f32,
 key_cache: []f32,
 value_cache: []f32,
 
-pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, seq_len: usize) !Self {
-    const dim = checkpoint.dim;
-    const kv_dim = checkpoint.kv_dim;
-    const kv_cache_dim = checkpoint.n_layers * seq_len * kv_dim;
-
-    const input_buffer = try allocator.alloc(f32, dim);
+pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, sequence_length: usize) !Self {
+    const embedding_size = checkpoint.embedding_size;
+    const input_buffer = try allocator.alloc(f32, embedding_size);
 
     errdefer allocator.free(input_buffer);
 
-    const output_buffer = try allocator.alloc(f32, dim);
+    const output_buffer = try allocator.alloc(f32, embedding_size);
 
     errdefer allocator.free(output_buffer);
 
-    const scores_buffer = try allocator.alloc(f32, checkpoint.n_heads * seq_len);
+    const scores_buffer = try allocator.alloc(f32, checkpoint.n_query_heads * sequence_length);
 
     errdefer allocator.free(scores_buffer);
 
-    const queries_buffer = try allocator.alloc(f32, dim);
+    const queries_buffer = try allocator.alloc(f32, embedding_size);
 
     errdefer allocator.free(queries_buffer);
 
-    const keys_buffer = try allocator.alloc(f32, kv_dim);
-
-    errdefer allocator.free(keys_buffer);
-
-    const values_buffer = try allocator.alloc(f32, kv_dim);
-
-    errdefer allocator.free(values_buffer);
-
-    const key_cache = try allocator.alloc(f32, kv_cache_dim);
+    const key_value_size = checkpoint.n_query_head_groups * checkpoint.query_head_size;
+    const key_value_cache_size = checkpoint.n_layers * sequence_length * key_value_size;
+    const key_cache = try allocator.alloc(f32, key_value_cache_size);
 
     errdefer allocator.free(key_cache);
 
-    const value_cache = try allocator.alloc(f32, kv_cache_dim);
+    const value_cache = try allocator.alloc(f32, key_value_cache_size);
 
     return Self{
         .allocator = allocator,
         .checkpoint = checkpoint,
-        .seq_len = seq_len,
+        .sequence_length = sequence_length,
        .input_buffer = input_buffer,
         .output_buffer = output_buffer,
         .scores_buffer = scores_buffer,
         .queries_buffer = queries_buffer,
-        .keys_buffer = keys_buffer,
-        .values_buffer = values_buffer,
        .key_cache = key_cache,
         .value_cache = value_cache,
     };
@@ -72,55 +58,41 @@ pub fn deinit(self: *const Self) void {
     self.allocator.free(self.output_buffer);
     self.allocator.free(self.scores_buffer);
     self.allocator.free(self.queries_buffer);
-    self.allocator.free(self.keys_buffer);
-    self.allocator.free(self.values_buffer);
     self.allocator.free(self.key_cache);
     self.allocator.free(self.value_cache);
 }
 
 pub fn forward(self: *const Self, pos: usize, layer: usize) !void {
     const checkpoint = self.checkpoint;
-    const kv_dim = checkpoint.kv_dim;
     const weights = checkpoint.weights;
 
-    try weights.attention_queries_matrix.multiplyVector(
+    try weights.attention_query_matrices.multiplyVector(
         layer,
         self.input_buffer,
         self.queries_buffer,
     );
 
-    try weights.attention_keys_matrix.multiplyVector(
-        layer,
-        self.input_buffer,
-        self.keys_buffer,
-    );
+    const query_head_size = checkpoint.query_head_size;
+    const key_value_size = checkpoint.n_query_head_groups * query_head_size;
+    const key_value_cache_offset = layer * (self.sequence_length * key_value_size);
 
-    try weights.attention_values_matrix.multiplyVector(
-        layer,
-        self.input_buffer,
-        self.values_buffer,
-    );
+    const key_cache = self.key_cache[key_value_cache_offset..];
+    const keys_buffer = key_cache[(pos * key_value_size)..][0..key_value_size];
 
-    lib.rope(pos, checkpoint.head_size, self.queries_buffer, self.keys_buffer);
+    const value_cache = self.value_cache[key_value_cache_offset..];
+    const values_buffer = value_cache[(pos * key_value_size)..][0..key_value_size];
 
-    const kv_cache_dim = self.seq_len * kv_dim;
-    const kv_cache_layer_offset = layer * kv_cache_dim;
+    try weights.attention_key_matrices.multiplyVector(layer, self.input_buffer, keys_buffer);
 
-    @memcpy(
-        self.key_cache[(kv_cache_layer_offset + pos * kv_dim)..][0..self.keys_buffer.len],
-        self.keys_buffer,
-    );
+    lib.rope(pos, query_head_size, self.queries_buffer, keys_buffer);
 
-    @memcpy(
-        self.value_cache[(kv_cache_layer_offset + pos * kv_dim)..][0..self.values_buffer.len],
-        self.values_buffer,
-    );
+    try weights.attention_value_matrices.multiplyVector(layer, self.input_buffer, values_buffer);
 
-    for (0..checkpoint.n_heads) |head| {
-        self.compute_weighted_values(pos, head, kv_cache_layer_offset);
+    for (0..checkpoint.n_query_heads) |query_head| {
+        self.compute_weighted_values(pos, query_head, key_cache, value_cache);
     }
 
-    try weights.attention_output_matrix.multiplyVector(
+    try weights.attention_output_matrices.multiplyVector(
         layer,
         self.input_buffer,
         self.output_buffer,
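Note: the restructured forward pass above computes the key and value projections directly into per-position rows of the layer's cache slice, which is what makes keys_buffer, values_buffer, and the two @memcpy calls removable. A minimal standalone sketch of the indexing, mirroring the layout set up in init (cacheSlice is a hypothetical helper, not part of the patch):

    const std = @import("std");

    /// Returns the contiguous key_value_size-row of `cache` belonging to
    /// (layer, pos); key_value_size = n_query_head_groups * query_head_size.
    fn cacheSlice(
        cache: []f32,
        sequence_length: usize,
        key_value_size: usize,
        layer: usize,
        pos: usize,
    ) []f32 {
        const layer_offset = layer * (sequence_length * key_value_size);

        return cache[(layer_offset + pos * key_value_size)..][0..key_value_size];
    }

    test "each (layer, pos) pair maps to a disjoint cache row" {
        // n_layers = 2, sequence_length = 4, key_value_size = 3:
        var cache = [_]f32{0} ** (2 * 4 * 3);
        const row = cacheSlice(&cache, 4, 3, 1, 2);

        try std.testing.expectEqual(@as(usize, 3), row.len);

        row[0] = 1; // writes land directly in the cache, no @memcpy needed

        try std.testing.expectEqual(@as(f32, 1), cache[(1 * 4 * 3) + (2 * 3)]);
    }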
@@ -130,40 +102,41 @@ pub fn forward(self: *const Self, pos: usize, layer: usize) !void {
 fn compute_weighted_values(
     self: *const Self,
     pos: usize,
-    head: usize,
-    kv_cache_layer_offset: usize,
+    query_head: usize,
+    key_cache: []const f32,
+    value_cache: []const f32,
 ) void {
     @setFloatMode(.Optimized);
 
     const checkpoint = self.checkpoint;
-    const kv_dim = checkpoint.kv_dim;
-    const head_size = checkpoint.head_size;
-
-    const group = head / checkpoint.n_groups;
-    const kv_head_offset = group * head_size;
-    const head_offset = head * head_size;
-    const query = self.queries_buffer[head_offset..][0..head_size];
-    const scores = self.scores_buffer[(head * self.seq_len)..];
+    const n_query_head_groups = checkpoint.n_query_head_groups;
+    const query_head_group = query_head / (checkpoint.n_query_heads / n_query_head_groups);
+    const query_head_size = checkpoint.query_head_size;
+    const query_head_offset = query_head * query_head_size;
+    const query = self.queries_buffer[query_head_offset..][0..query_head_size];
+    const key_value_size = n_query_head_groups * query_head_size;
+    const key_value_head_offset = query_head_group * query_head_size;
+    const scores = self.scores_buffer[(query_head * self.sequence_length)..];
 
     for (0..(pos + 1)) |prev_pos| {
-        const kv_cache_head_offset = kv_cache_layer_offset + prev_pos * kv_dim + kv_head_offset;
-        const key = self.key_cache[kv_cache_head_offset..][0..head_size];
+        const key_value_cache_offset = prev_pos * key_value_size + key_value_head_offset;
+        const key = key_cache[key_value_cache_offset..][0..query_head_size];
 
-        scores[prev_pos] = lib.dot(query, key) / checkpoint.head_size_sqrt;
+        scores[prev_pos] = lib.dot(query, key) / checkpoint.query_head_size_sqrt;
     }
 
     lib.softmax(scores[0..(pos + 1)]);
 
-    const weighted_values = self.input_buffer[head_offset..][0..head_size];
+    const weighted_values = self.input_buffer[query_head_offset..][0..query_head_size];
 
     @memset(weighted_values, 0);
 
     for (0..(pos + 1)) |prev_pos| {
-        const kv_cache_head_offset = kv_cache_layer_offset + prev_pos * kv_dim + kv_head_offset;
-        const value = self.value_cache[kv_cache_head_offset..];
+        const key_value_cache_offset = prev_pos * key_value_size + key_value_head_offset;
+        const value = value_cache[key_value_cache_offset..];
 
         const weight = scores[prev_pos];
 
-        for (0..head_size) |index| {
+        for (0..query_head_size) |index| {
             weighted_values[index] += weight * value[index];
         }
     }
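Note: the grouped-query attention bookkeeping in compute_weighted_values above reduces to a single division: consecutive query heads share one key/value group. A standalone sketch with hypothetical example values (8 query heads, 4 groups), not part of the patch:

    const std = @import("std");

    /// Maps a query head to the key/value group it shares (grouped-query attention).
    fn queryHeadGroup(query_head: usize, n_query_heads: usize, n_query_head_groups: usize) usize {
        // Consecutive query heads share one key/value group:
        return query_head / (n_query_heads / n_query_head_groups);
    }

    test "8 query heads, 4 groups: heads 0-1 share group 0, head 7 uses group 3" {
        try std.testing.expectEqual(@as(usize, 0), queryHeadGroup(0, 8, 4));
        try std.testing.expectEqual(@as(usize, 0), queryHeadGroup(1, 8, 4));
        try std.testing.expectEqual(@as(usize, 3), queryHeadGroup(7, 8, 4));
    }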
diff --git a/src/chat.zig b/src/chat.zig
index d2e7fa0..b89a1fc 100644
--- a/src/chat.zig
+++ b/src/chat.zig
@@ -11,7 +11,6 @@ allocator: std.mem.Allocator,
 transformer: Transformer,
 tokenizer: Tokenizer,
 sampler: Sampler,
-n_steps: usize,
 user_prompt: []const u8,
 system_prompt: []const u8,
 
@@ -34,7 +33,6 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
         .transformer = transformer,
         .tokenizer = tokenizer,
         .sampler = sampler,
-        .n_steps = cli.n_steps,
         .user_prompt = cli.prompt,
         .system_prompt = cli.system_prompt,
     };
@@ -69,7 +67,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
         allocator.free(prompt_tokens);
     };
 
-    for (0..self.n_steps) |pos| {
+    for (0..self.transformer.sequence_length) |pos| {
         try self.transformer.forward(token, pos);
 
         if (token == bos_token and user_turn) {
@@ -131,7 +129,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
             user_prompt_tokens_index += 1;
 
             if (next_token == 0) {
-                next_token = self.sampler.sample(self.transformer.logits_vector);
+                next_token = self.sampler.sample(self.transformer.logits);
             }
 
             if (next_token == eos_token) {
diff --git a/src/checkpoint.zig b/src/checkpoint.zig
index 9441a5d..9a9b1b4 100644
--- a/src/checkpoint.zig
+++ b/src/checkpoint.zig
@@ -2,38 +2,36 @@ const Self = @This();
 
 const std = @import("std");
 const Cli = @import("./cli.zig");
-const Matrix = @import("./matrix.zig");
-const Vector = @import("./vector.zig");
+const MatrixArray = @import("./matrix_array.zig");
+const VectorArray = @import("./vector_array.zig");
 
 allocator: std.mem.Allocator,
 mmap: bool,
-dim: usize,
-hidden_dim: usize,
+
+embedding_size: usize,
+intermediate_size: usize,
 n_layers: usize,
-n_heads: usize,
-n_kv_heads: usize,
+n_query_heads: usize,
+n_query_head_groups: usize,
 vocab_size: usize,
-kv_dim: usize,
-head_size: usize,
-head_size_sqrt: f32,
-n_groups: usize,
-
-weights: struct {
-    token_embedding_vector: Vector,
-
-    attention_norm_vector: Vector,
-    attention_queries_matrix: Matrix,
-    attention_keys_matrix: Matrix,
-    attention_values_matrix: Matrix,
-    attention_output_matrix: Matrix,
+max_sequence_length: usize,
 
-    feed_forward_norm_vector: Vector,
-    feed_forward_hidden_matrix: Matrix,
-    feed_forward_output_matrix: Matrix,
-    feed_forward_residual_matrix: Matrix,
+query_head_size: usize,
+query_head_size_sqrt: f32,
 
-    final_norm_vector: Vector,
-    classifier_matrix: Matrix,
+weights: struct {
+    embedding_vectors: VectorArray,
+    attention_norm_vectors: VectorArray,
+    attention_query_matrices: MatrixArray,
+    attention_key_matrices: MatrixArray,
+    attention_value_matrices: MatrixArray,
+    attention_output_matrices: MatrixArray,
+    feed_forward_norm_vectors: VectorArray,
+    feed_forward_hidden_matrices: MatrixArray,
+    feed_forward_output_matrices: MatrixArray,
+    feed_forward_scaling_matrices: MatrixArray,
+    final_norm_vector: []const f32,
+    classifier_matrices: MatrixArray,
 },
 
 data: []align(std.mem.page_size) const u8,
@@ -52,153 +50,152 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
     const config_data: [*]const i32 = @alignCast(@ptrCast(data[0..28]));
 
-    const signed_vocab_size: i32 = config_data[5];
-    const dim: usize = @intCast(config_data[0]);
-    const hidden_dim: usize = @intCast(config_data[1]);
+    const embedding_size: usize = @intCast(config_data[0]);
+    const intermediate_size: usize = @intCast(config_data[1]);
     const n_layers: usize = @intCast(config_data[2]);
-    const n_heads: usize = @intCast(config_data[3]);
-    const n_kv_heads: usize = @intCast(config_data[4]);
+    const n_query_heads: usize = @intCast(config_data[3]);
+    const n_query_head_groups: usize = @intCast(config_data[4]);
+    const signed_vocab_size: i32 = config_data[5];
     const vocab_size: usize = std.math.absCast(signed_vocab_size);
-    const kv_dim: usize = (dim * n_kv_heads) / n_heads;
-    const head_size: usize = dim / n_heads;
-    const head_size_sqrt: f32 = std.math.sqrt(@as(f32, @floatFromInt(head_size)));
-    const n_groups: usize = n_heads / n_kv_heads;
+    const max_sequence_length: usize = @intCast(config_data[6]);
+
+    const query_head_size: usize = embedding_size / n_query_heads;
+    const query_head_size_sqrt: f32 = std.math.sqrt(@as(f32, @floatFromInt(query_head_size)));
 
     var weights_data: [*]const f32 = @alignCast(@ptrCast(data[28..]));
 
-    const token_embedding_vector = Vector.init(
-        dim,
-        readFloatSlice(&weights_data, vocab_size * dim),
+    const embedding_vectors = VectorArray.init(
+        embedding_size,
+        readFloatSlice(&weights_data, vocab_size * embedding_size),
     );
 
-    const attention_norm_vector = Vector.init(
-        dim,
-        readFloatSlice(&weights_data, n_layers * dim),
+    const attention_norm_vectors = VectorArray.init(
+        embedding_size,
+        readFloatSlice(&weights_data, n_layers * embedding_size),
     );
 
-    const attention_queries_matrix = try Matrix.init(
+    const attention_query_matrices = try MatrixArray.init(
         allocator,
-        dim,
-        dim,
-        readFloatSlice(&weights_data, n_layers * (dim * dim)),
+        embedding_size,
+        embedding_size,
+        readFloatSlice(&weights_data, n_layers * (embedding_size * embedding_size)),
         cli.multithreading,
     );
 
-    errdefer attention_queries_matrix.deinit();
+    errdefer attention_query_matrices.deinit();
+
+    const key_value_size: usize = query_head_size * n_query_head_groups;
 
-    const attention_keys_matrix = try Matrix.init(
+    const attention_key_matrices = try MatrixArray.init(
         allocator,
-        kv_dim,
-        dim,
-        readFloatSlice(&weights_data, n_layers * (kv_dim * dim)),
+        key_value_size,
+        embedding_size,
+        readFloatSlice(&weights_data, n_layers * (key_value_size * embedding_size)),
         cli.multithreading,
     );
 
-    errdefer attention_keys_matrix.deinit();
+    errdefer attention_key_matrices.deinit();
 
-    const attention_values_matrix = try Matrix.init(
+    const attention_value_matrices = try MatrixArray.init(
         allocator,
-        kv_dim,
-        dim,
-        readFloatSlice(&weights_data, n_layers * (kv_dim * dim)),
+        key_value_size,
+        embedding_size,
+        readFloatSlice(&weights_data, n_layers * (key_value_size * embedding_size)),
         cli.multithreading,
     );
 
-    errdefer attention_values_matrix.deinit();
+    errdefer attention_value_matrices.deinit();
 
-    const attention_output_matrix = try Matrix.init(
+    const attention_output_matrices = try MatrixArray.init(
         allocator,
-        dim,
-        dim,
-        readFloatSlice(&weights_data, n_layers * (dim * dim)),
+        embedding_size,
+        embedding_size,
+        readFloatSlice(&weights_data, n_layers * (embedding_size * embedding_size)),
         cli.multithreading,
     );
 
-    errdefer attention_output_matrix.deinit();
+    errdefer attention_output_matrices.deinit();
 
-    const feed_forward_norm_vector = Vector.init(
-        dim,
-        readFloatSlice(&weights_data, n_layers * dim),
+    const feed_forward_norm_vectors = VectorArray.init(
+        embedding_size,
+        readFloatSlice(&weights_data, n_layers * embedding_size),
     );
 
-    const feed_forward_hidden_matrix = try Matrix.init(
+    const feed_forward_hidden_matrices = try MatrixArray.init(
         allocator,
-        hidden_dim,
-        dim,
-        readFloatSlice(&weights_data, n_layers * (hidden_dim * dim)),
+        intermediate_size,
+        embedding_size,
+        readFloatSlice(&weights_data, n_layers * (intermediate_size * embedding_size)),
         cli.multithreading,
     );
 
-    errdefer feed_forward_hidden_matrix.deinit();
+    errdefer feed_forward_hidden_matrices.deinit();
 
-    const feed_forward_output_matrix = try Matrix.init(
+    const feed_forward_output_matrices = try MatrixArray.init(
         allocator,
-        dim,
-        hidden_dim,
-        readFloatSlice(&weights_data, n_layers * (dim * hidden_dim)),
+        embedding_size,
+        intermediate_size,
+        readFloatSlice(&weights_data, n_layers * (embedding_size * intermediate_size)),
         cli.multithreading,
     );
 
-    errdefer feed_forward_output_matrix.deinit();
+    errdefer feed_forward_output_matrices.deinit();
 
-    const feed_forward_residual_matrix = try Matrix.init(
+    const feed_forward_scaling_matrices = try MatrixArray.init(
         allocator,
-        hidden_dim,
-        dim,
-        readFloatSlice(&weights_data, n_layers * (hidden_dim * dim)),
+        intermediate_size,
+        embedding_size,
+        readFloatSlice(&weights_data, n_layers * (intermediate_size * embedding_size)),
         cli.multithreading,
     );
 
-    errdefer feed_forward_residual_matrix.deinit();
+    errdefer feed_forward_scaling_matrices.deinit();
 
-    const final_norm_vector = Vector.init(dim, readFloatSlice(&weights_data, dim));
-    const seq_len: usize = @intCast(config_data[6]);
+    const final_norm_vector = readFloatSlice(&weights_data, embedding_size);
 
-    _ = readFloatSlice(&weights_data, seq_len * head_size / 2);
-    _ = readFloatSlice(&weights_data, seq_len * head_size / 2);
+    _ = readFloatSlice(&weights_data, max_sequence_length * query_head_size / 2);
+    _ = readFloatSlice(&weights_data, max_sequence_length * query_head_size / 2);
 
     // https://github.com/karpathy/llama2.c/commit/c3e0d73bd294e1f5e4d17425fac09aaec536400d
-    const classifier_matrix = try Matrix.init(
+    const classifier_matrices = try MatrixArray.init(
         allocator,
         vocab_size,
-        dim,
+        embedding_size,
         if (signed_vocab_size > 0)
-            token_embedding_vector.data
+            embedding_vectors.data
         else
-            readFloatSlice(&weights_data, vocab_size * dim),
+            readFloatSlice(&weights_data, vocab_size * embedding_size),
         cli.multithreading,
     );
 
     return Self{
         .allocator = allocator,
         .mmap = cli.mmap,
-        .dim = dim,
-        .hidden_dim = hidden_dim,
+
+        .embedding_size = embedding_size,
+        .intermediate_size = intermediate_size,
         .n_layers = n_layers,
-        .n_heads = n_heads,
-        .n_kv_heads = n_kv_heads,
+        .n_query_heads = n_query_heads,
+        .n_query_head_groups = n_query_head_groups,
         .vocab_size = vocab_size,
-        .kv_dim = kv_dim,
-        .head_size = head_size,
-        .head_size_sqrt = head_size_sqrt,
-        .n_groups = n_groups,
+        .max_sequence_length = max_sequence_length,
 
-        .weights = .{
-            .token_embedding_vector = token_embedding_vector,
-
-            .attention_norm_vector = attention_norm_vector,
-            .attention_queries_matrix = attention_queries_matrix,
-            .attention_keys_matrix = attention_keys_matrix,
-            .attention_values_matrix = attention_values_matrix,
-            .attention_output_matrix = attention_output_matrix,
-
-            .feed_forward_norm_vector = feed_forward_norm_vector,
-            .feed_forward_hidden_matrix = feed_forward_hidden_matrix,
-            .feed_forward_output_matrix = feed_forward_output_matrix,
-            .feed_forward_residual_matrix = feed_forward_residual_matrix,
+        .query_head_size = query_head_size,
+        .query_head_size_sqrt = query_head_size_sqrt,
 
+        .weights = .{
+            .embedding_vectors = embedding_vectors,
+            .attention_norm_vectors = attention_norm_vectors,
+            .attention_query_matrices = attention_query_matrices,
+            .attention_key_matrices = attention_key_matrices,
+            .attention_value_matrices = attention_value_matrices,
+            .attention_output_matrices = attention_output_matrices,
+            .feed_forward_norm_vectors = feed_forward_norm_vectors,
+            .feed_forward_hidden_matrices = feed_forward_hidden_matrices,
+            .feed_forward_output_matrices = feed_forward_output_matrices,
+            .feed_forward_scaling_matrices = feed_forward_scaling_matrices,
             .final_norm_vector = final_norm_vector,
-            .classifier_matrix = classifier_matrix,
+            .classifier_matrices = classifier_matrices,
         },
 
         .data = data,
@@ -206,14 +203,14 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
 }
 
 pub fn deinit(self: *const Self) void {
-    self.weights.attention_queries_matrix.deinit();
-    self.weights.attention_keys_matrix.deinit();
-    self.weights.attention_values_matrix.deinit();
-    self.weights.attention_output_matrix.deinit();
-    self.weights.feed_forward_hidden_matrix.deinit();
-    self.weights.feed_forward_output_matrix.deinit();
-    self.weights.feed_forward_residual_matrix.deinit();
-    self.weights.classifier_matrix.deinit();
+    self.weights.attention_query_matrices.deinit();
+    self.weights.attention_key_matrices.deinit();
+    self.weights.attention_value_matrices.deinit();
+    self.weights.attention_output_matrices.deinit();
+    self.weights.feed_forward_hidden_matrices.deinit();
+    self.weights.feed_forward_output_matrices.deinit();
+    self.weights.feed_forward_scaling_matrices.deinit();
+    self.weights.classifier_matrices.deinit();
 
     if (self.mmap) {
         std.os.munmap(self.data);
diff --git a/src/cli.zig b/src/cli.zig
index 68a032f..8f6d3c7 100644
--- a/src/cli.zig
+++ b/src/cli.zig
@@ -112,7 +112,7 @@ pub fn init(allocator: std.mem.Allocator) !Self {
         .temperature = @max(@min(temperature orelse 1, 1), 0),
         .top_p = @max(@min(top_p orelse 0.9, 1), 0),
         .random_seed = random_seed orelse @intCast(std.time.milliTimestamp()),
-        .n_steps = @max(n_steps orelse 256, 1),
+        .n_steps = n_steps orelse 0,
         .prompt = prompt orelse "",
         .tokenizer_path = tokenizer_path orelse "tokenizer.bin",
         .chat = if (mode) |arg| std.mem.eql(u8, arg, "chat") else false,
@@ -137,7 +137,7 @@ fn exit() !noreturn {
     try stderr.print(" -t temperature = 1\n", .{});
     try stderr.print(" -p top_p = 0.9; 1 == off\n", .{});
     try stderr.print(" -s random_seed = milli_timestamp\n", .{});
-    try stderr.print(" -n n_steps = 256\n", .{});
+    try stderr.print(" -n n_steps = max_sequence_length\n", .{});
     try stderr.print(" -i prompt = \"\"\n", .{});
     try stderr.print(" -z tokenizer_path = \"tokenizer.bin\"\n", .{});
     try stderr.print(" -m mode = \"generate\"; (alt. \"chat\")\n", .{});
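Note: for reference, the 28-byte header that checkpoint.zig decodes above consists of seven little-endian i32 values, in the order established by llama2.c. A standalone sketch of the same layout using the new field names (the Config struct and readConfig are illustrative, not part of the patch):

    const std = @import("std");

    const Config = struct {
        embedding_size: i32,
        intermediate_size: i32,
        n_layers: i32,
        n_query_heads: i32,
        n_query_head_groups: i32,
        vocab_size: i32, // negative sign flags shared classifier weights
        max_sequence_length: i32,
    };

    /// Decodes the seven little-endian i32 fields of a llama2.c header.
    fn readConfig(header: *const [28]u8) Config {
        var fields: [7]i32 = undefined;

        for (&fields, 0..) |*field, index| {
            field.* = std.mem.readIntLittle(i32, header[index * 4 ..][0..4]);
        }

        return .{
            .embedding_size = fields[0],
            .intermediate_size = fields[1],
            .n_layers = fields[2],
            .n_query_heads = fields[3],
            .n_query_head_groups = fields[4],
            .vocab_size = fields[5],
            .max_sequence_length = fields[6],
        };
    }

    test "vocab_size sign flags shared classifier weights" {
        var header = [_]u8{0} ** 28;

        std.mem.writeIntLittle(i32, header[20..24], -32000);

        try std.testing.expectEqual(@as(i32, -32000), readConfig(&header).vocab_size);
    }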
\"chat\")\n", .{}); diff --git a/src/feed_forward.zig b/src/feed_forward.zig index 27ec554..a4821f1 100644 --- a/src/feed_forward.zig +++ b/src/feed_forward.zig @@ -2,39 +2,37 @@ const Self = @This(); const std = @import("std"); const Checkpoint = @import("checkpoint.zig"); -const matrix = @import("matrix.zig"); allocator: std.mem.Allocator, checkpoint: Checkpoint, input_buffer: []f32, hidden_buffer: []f32, -residual_buffer: []f32, +scaling_buffer: []f32, output_buffer: []f32, pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint) !Self { - const dim = checkpoint.dim; - const hidden_dim = checkpoint.hidden_dim; - - const input_buffer = try allocator.alloc(f32, dim); + const embedding_size = checkpoint.embedding_size; + const input_buffer = try allocator.alloc(f32, embedding_size); errdefer allocator.free(input_buffer); - const hidden_buffer = try allocator.alloc(f32, hidden_dim); + const intermediate_size = checkpoint.intermediate_size; + const hidden_buffer = try allocator.alloc(f32, intermediate_size); errdefer allocator.free(hidden_buffer); - const residual_buffer = try allocator.alloc(f32, hidden_dim); + const scaling_buffer = try allocator.alloc(f32, intermediate_size); - errdefer allocator.free(residual_buffer); + errdefer allocator.free(scaling_buffer); - const output_buffer = try allocator.alloc(f32, dim); + const output_buffer = try allocator.alloc(f32, embedding_size); return Self{ .allocator = allocator, .checkpoint = checkpoint, .input_buffer = input_buffer, .hidden_buffer = hidden_buffer, - .residual_buffer = residual_buffer, + .scaling_buffer = scaling_buffer, .output_buffer = output_buffer, }; } @@ -42,7 +40,7 @@ pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint) !Self { pub fn deinit(self: *const Self) void { self.allocator.free(self.input_buffer); self.allocator.free(self.hidden_buffer); - self.allocator.free(self.residual_buffer); + self.allocator.free(self.scaling_buffer); self.allocator.free(self.output_buffer); } @@ -50,26 +48,25 @@ pub fn forward(self: *const Self, layer: usize) !void { @setFloatMode(.Optimized); const checkpoint = self.checkpoint; - const hidden_dim = checkpoint.hidden_dim; const weights = checkpoint.weights; - try weights.feed_forward_hidden_matrix.multiplyVector( + try weights.feed_forward_hidden_matrices.multiplyVector( layer, self.input_buffer, self.hidden_buffer, ); - try weights.feed_forward_residual_matrix.multiplyVector( + try weights.feed_forward_scaling_matrices.multiplyVector( layer, self.input_buffer, - self.residual_buffer, + self.scaling_buffer, ); - for (0..hidden_dim) |index| { - self.hidden_buffer[index] = silu(self.hidden_buffer[index]) * self.residual_buffer[index]; + for (0..checkpoint.intermediate_size) |index| { + self.hidden_buffer[index] = silu(self.hidden_buffer[index]) * self.scaling_buffer[index]; } - try weights.feed_forward_output_matrix.multiplyVector( + try weights.feed_forward_output_matrices.multiplyVector( layer, self.hidden_buffer, self.output_buffer, diff --git a/src/generator.zig b/src/generator.zig index 1c39e14..049b29d 100644 --- a/src/generator.zig +++ b/src/generator.zig @@ -12,7 +12,6 @@ transformer: Transformer, tokenizer: Tokenizer, sampler: Sampler, prompt_tokens: []usize, -n_steps: usize, timer: bool, pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self { @@ -37,7 +36,6 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self { .tokenizer = tokenizer, .sampler = sampler, .prompt_tokens = prompt_tokens, - .n_steps = cli.n_steps, .timer = cli.timer, }; } @@ 
diff --git a/src/generator.zig b/src/generator.zig
index 1c39e14..049b29d 100644
--- a/src/generator.zig
+++ b/src/generator.zig
@@ -12,7 +12,6 @@ transformer: Transformer,
 tokenizer: Tokenizer,
 sampler: Sampler,
 prompt_tokens: []usize,
-n_steps: usize,
 timer: bool,
 
 pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
@@ -37,7 +36,6 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
         .tokenizer = tokenizer,
         .sampler = sampler,
         .prompt_tokens = prompt_tokens,
-        .n_steps = cli.n_steps,
         .timer = cli.timer,
     };
 }
@@ -60,7 +58,7 @@ pub fn generate(self: *Self, writer: anytype) !void {
     var start_time: i64 = 0;
     var total_time: i64 = 0;
 
-    for (0..self.n_steps) |pos| {
+    for (0..self.transformer.sequence_length) |pos| {
         if (pos > 0) {
             n_timed_steps += 1;
             start_time = std.time.milliTimestamp();
@@ -76,7 +74,7 @@ pub fn generate(self: *Self, writer: anytype) !void {
             next_token = self.prompt_tokens[prompt_tokens_index];
             prompt_tokens_index += 1;
         } else {
-            next_token = self.sampler.sample(self.transformer.logits_vector);
+            next_token = self.sampler.sample(self.transformer.logits);
         }
 
         if (next_token == bos_token or next_token == eos_token) {
diff --git a/src/lib/rmsnorm.zig b/src/lib/rmsnorm.zig
index fe6185a..75866ec 100644
--- a/src/lib/rmsnorm.zig
+++ b/src/lib/rmsnorm.zig
@@ -1,23 +1,23 @@
 const std = @import("std");
 
 // Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467)
-pub fn rmsnorm(input_vector: []const f32, weight_vector: []const f32, output_vector: []f32) void {
+pub fn rmsnorm(input: []const f32, weight: []const f32, output: []f32) void {
     @setFloatMode(.Optimized);
 
-    std.debug.assert(output_vector.len == input_vector.len);
-    std.debug.assert(output_vector.len == weight_vector.len);
+    std.debug.assert(output.len == input.len);
+    std.debug.assert(output.len == weight.len);
 
     var rms_scaling_factor: f32 = 0;
 
-    for (input_vector) |element| {
+    for (input) |element| {
         rms_scaling_factor += element * element;
     }
 
-    rms_scaling_factor /= @floatFromInt(input_vector.len);
+    rms_scaling_factor /= @floatFromInt(input.len);
     rms_scaling_factor += 1e-5;
     rms_scaling_factor = 1 / std.math.sqrt(rms_scaling_factor);
 
-    for (output_vector, 0..) |*element, index| {
-        element.* = weight_vector[index] * rms_scaling_factor * input_vector[index];
+    for (output, 0..) |*element, index| {
+        element.* = weight[index] * rms_scaling_factor * input[index];
     }
 }
diff --git a/src/lib/rope.zig b/src/lib/rope.zig
index 11acb66..7ee0194 100644
--- a/src/lib/rope.zig
+++ b/src/lib/rope.zig
@@ -3,7 +3,7 @@ const std = @import("std");
 // RoFormer: Enhanced Transformer with Rotary Position Embedding (https://arxiv.org/abs/2104.09864)
 pub fn rope(
     pos: usize,
-    head_size: usize,
+    query_head_size: usize,
     queries_buffer: []f32,
     keys_buffer: []f32,
 ) void {
@@ -14,10 +14,10 @@ pub fn rope(
     var index: usize = 0;
 
     while (index < queries_buffer.len) : (index += 2) {
-        const head_index: f32 = @floatFromInt(index % head_size);
+        const query_head: f32 = @floatFromInt(index % query_head_size);
 
         const frequency: f32 =
-            1 / std.math.pow(f32, 10000, head_index / @as(f32, @floatFromInt(head_size)));
+            1 / std.math.pow(f32, 10000, query_head / @as(f32, @floatFromInt(query_head_size)));
 
         const rotation_scaling_factor: f32 = @as(f32, @floatFromInt(pos)) * frequency;
         const real_rotation_value: f32 = std.math.cos(rotation_scaling_factor);
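Note: the rope hunk above ends just before the rotation itself is applied. For orientation, the standard RoFormer update treats each (even, odd) pair as a complex number and rotates it by pos * frequency radians; the following is a sketch of the textbook formula, not code taken from this repo:

    const std = @import("std");

    /// Rotates one (even, odd) pair in place, mirroring a single
    /// iteration of the loop in rope() above.
    fn rotatePair(pair: *[2]f32, pos: usize, frequency: f32) void {
        const angle = @as(f32, @floatFromInt(pos)) * frequency;
        const real = std.math.cos(angle);
        const imag = std.math.sin(angle);

        const q0 = pair[0];
        const q1 = pair[1];

        // Complex multiplication (q0 + i*q1) * (cos + i*sin):
        pair[0] = q0 * real - q1 * imag;
        pair[1] = q0 * imag + q1 * real;
    }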
diff --git a/src/matrix.zig b/src/matrix_array.zig
similarity index 95%
rename from src/matrix.zig
rename to src/matrix_array.zig
index d8070e2..1f1f585 100644
--- a/src/matrix.zig
+++ b/src/matrix_array.zig
@@ -43,9 +43,9 @@ pub fn init(
     row_major_data: []const f32,
     multithreading: bool,
 ) !Self {
-    const matrix_dim = m_rows * n_cols;
+    const matrix_size = m_rows * n_cols;
 
-    std.debug.assert(row_major_data.len % matrix_dim == 0);
+    std.debug.assert(row_major_data.len % matrix_size == 0);
 
     const n_worker_threads = if (!multithreading or build_options.accelerate or build_options.metal)
         0
@@ -82,8 +82,8 @@ pub fn multiplyVector(
     std.debug.assert(input_vector.len == n_cols);
     std.debug.assert(output_vector.len == m_rows);
 
-    const matrix_dim = m_rows * n_cols;
-    const row_major_data = self.row_major_data[(matrix_index * matrix_dim)..][0..matrix_dim];
+    const matrix_size = m_rows * n_cols;
+    const row_major_data = self.row_major_data[(matrix_index * matrix_size)..][0..matrix_size];
 
     if (build_options.accelerate) {
         matvecmulAccelerate(
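Note: multiplyVector above selects the matrix_index-th matrix out of one flat row-major array, which is what motivates the MatrixArray name. A single-threaded reference sketch of that slicing and the matrix-vector product it feeds (the real file dispatches to Accelerate, Metal, or worker threads instead):

    const std = @import("std");

    /// output = M[matrix_index] * input, with all matrices stored back to back
    /// in one row-major []const f32, exactly as sliced in multiplyVector above.
    fn matvec(
        row_major_data: []const f32,
        matrix_index: usize,
        m_rows: usize,
        n_cols: usize,
        input: []const f32,
        output: []f32,
    ) void {
        const matrix_size = m_rows * n_cols;
        const matrix = row_major_data[(matrix_index * matrix_size)..][0..matrix_size];

        for (output, 0..) |*element, row| {
            var sum: f32 = 0;

            // Dot product of the row-th matrix row with the input vector:
            for (matrix[(row * n_cols)..][0..n_cols], 0..) |weight, col| {
                sum += weight * input[col];
            }

            element.* = sum;
        }
    }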
diff --git a/src/sampler.zig b/src/sampler.zig
index 30b7a38..988dda7 100644
--- a/src/sampler.zig
+++ b/src/sampler.zig
@@ -2,7 +2,6 @@ const Self = @This();
 
 const std = @import("std");
 const lib = @import("lib.zig");
-const Checkpoint = @import("checkpoint.zig");
 const Cli = @import("cli.zig");
 
 allocator: std.mem.Allocator,
diff --git a/src/transformer.zig b/src/transformer.zig
index 2f12de2..18ea36e 100644
--- a/src/transformer.zig
+++ b/src/transformer.zig
@@ -9,17 +9,19 @@ const FeedForward = @import("feed_forward.zig");
 
 allocator: std.mem.Allocator,
 checkpoint: Checkpoint,
+sequence_length: usize,
 attention: Attention,
 feed_forward: FeedForward,
-hidden_state_vector: []f32,
-logits_vector: []f32,
+hidden_state: []f32,
+logits: []f32,
 
 pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
     const checkpoint = try Checkpoint.init(allocator, cli);
 
     errdefer checkpoint.deinit();
 
-    const attention = try Attention.init(allocator, checkpoint, cli.n_steps);
+    const sequence_length = if (cli.n_steps == 0) checkpoint.max_sequence_length else cli.n_steps;
+    const attention = try Attention.init(allocator, checkpoint, sequence_length);
 
     errdefer attention.deinit();
 
@@ -27,19 +29,20 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
 
     errdefer feed_forward.deinit();
 
-    const hidden_state_vector = try allocator.alloc(f32, checkpoint.dim);
+    const hidden_state = try allocator.alloc(f32, checkpoint.embedding_size);
 
-    errdefer allocator.free(hidden_state_vector);
+    errdefer allocator.free(hidden_state);
 
-    const logits_vector = try allocator.alloc(f32, checkpoint.vocab_size);
+    const logits = try allocator.alloc(f32, checkpoint.vocab_size);
 
     return Self{
         .allocator = allocator,
         .checkpoint = checkpoint,
+        .sequence_length = sequence_length,
         .attention = attention,
         .feed_forward = feed_forward,
-        .hidden_state_vector = hidden_state_vector,
-        .logits_vector = logits_vector,
+        .hidden_state = hidden_state,
+        .logits = logits,
     };
 }
 
@@ -47,43 +50,43 @@ pub fn deinit(self: *const Self) void {
     self.checkpoint.deinit();
     self.attention.deinit();
     self.feed_forward.deinit();
-    self.allocator.free(self.hidden_state_vector);
-    self.allocator.free(self.logits_vector);
+    self.allocator.free(self.hidden_state);
+    self.allocator.free(self.logits);
 }
 
 pub fn forward(self: *const Self, token: usize, pos: usize) !void {
     const checkpoint = self.checkpoint;
     const weights = checkpoint.weights;
 
-    @memcpy(self.hidden_state_vector, weights.token_embedding_vector.at(token));
+    @memcpy(self.hidden_state, weights.embedding_vectors.at(token));
 
     for (0..checkpoint.n_layers) |layer| {
         lib.rmsnorm(
-            self.hidden_state_vector,
-            weights.attention_norm_vector.at(layer),
+            self.hidden_state,
+            weights.attention_norm_vectors.at(layer),
             self.attention.input_buffer,
         );
 
         try self.attention.forward(pos, layer);
 
-        lib.add(self.hidden_state_vector, self.attention.output_buffer);
+        lib.add(self.hidden_state, self.attention.output_buffer);
 
         lib.rmsnorm(
-            self.hidden_state_vector,
-            weights.feed_forward_norm_vector.at(layer),
+            self.hidden_state,
+            weights.feed_forward_norm_vectors.at(layer),
             self.feed_forward.input_buffer,
         );
 
         try self.feed_forward.forward(layer);
 
-        lib.add(self.hidden_state_vector, self.feed_forward.output_buffer);
+        lib.add(self.hidden_state, self.feed_forward.output_buffer);
     }
 
     lib.rmsnorm(
-        self.hidden_state_vector,
-        weights.final_norm_vector.at(0),
-        self.hidden_state_vector,
+        self.hidden_state,
+        weights.final_norm_vector,
+        self.hidden_state,
     );
 
-    try weights.classifier_matrix.multiplyVector(0, self.hidden_state_vector, self.logits_vector);
+    try weights.classifier_matrices.multiplyVector(0, self.hidden_state, self.logits);
 }
diff --git a/src/vector.zig b/src/vector.zig
deleted file mode 100644
index 8ed1484..0000000
--- a/src/vector.zig
+++ /dev/null
@@ -1,16 +0,0 @@
-const Self = @This();
-
-const std = @import("std");
-
-dim: usize,
-data: []const f32,
-
-pub fn init(dim: usize, data: []const f32) Self {
-    std.debug.assert(data.len % dim == 0);
-
-    return Self{ .dim = dim, .data = data };
-}
-
-pub fn at(self: *const Self, index: usize) []const f32 {
-    return self.data[(index * self.dim)..][0..self.dim];
-}
diff --git a/src/vector_array.zig b/src/vector_array.zig
new file mode 100644
index 0000000..d876fd3
--- /dev/null
+++ b/src/vector_array.zig
@@ -0,0 +1,16 @@
+const Self = @This();
+
+const std = @import("std");
+
+size: usize,
+data: []const f32,
+
+pub fn init(size: usize, data: []const f32) Self {
+    std.debug.assert(data.len % size == 0);
+
+    return Self{ .size = size, .data = data };
+}
+
+pub fn at(self: *const Self, index: usize) []const f32 {
+    return self.data[(index * self.size)..][0..self.size];
+}
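Note: the renamed VectorArray keeps the old Vector semantics: one flat []const f32 backing many equally sized vectors, with at(index) returning the index-th slice. A usage sketch, assuming a test file placed alongside src/vector_array.zig (not part of the patch):

    const std = @import("std");
    const VectorArray = @import("vector_array.zig");

    test "VectorArray.at returns the index-th size-sized slice" {
        // Two vectors of size 3 in one flat backing array:
        const data = [_]f32{ 1, 2, 3, 4, 5, 6 };
        const vectors = VectorArray.init(3, &data);

        try std.testing.expectEqualSlices(f32, &[_]f32{ 4, 5, 6 }, vectors.at(1));
    }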