From e8789a39c77a0357665f05259bffbff9d29af55f Mon Sep 17 00:00:00 2001 From: Clemens Akens Date: Mon, 21 Aug 2023 12:57:48 +0200 Subject: [PATCH] Attention (WIP refactoring) --- README.md | 10 ++ src/attention.zig | 158 ++++++++++++++++++++++++++++++ src/checkpoint.zig | 2 + src/feed_forward.zig | 40 ++------ src/lib.zig | 7 ++ src/lib/linear_algebra.zig | 82 ++++++++++++++++ src/lib/rope.zig | 40 ++++++++ src/main.zig | 9 +- src/transformer.zig | 191 ++----------------------------------- src/utils.zig | 48 ---------- 10 files changed, 324 insertions(+), 263 deletions(-) create mode 100644 src/attention.zig create mode 100644 src/lib.zig create mode 100644 src/lib/linear_algebra.zig create mode 100644 src/lib/rope.zig diff --git a/README.md b/README.md index 328484e..d360998 100644 --- a/README.md +++ b/README.md @@ -19,3 +19,13 @@ Some deviations from the original include: - Utilization of slices instead of many-item pointers - For models of 4096+ dimensions, thread pools are utilized to parallelize independent matrix multiplications + +## Papers + +- Standard transformer architecture: [Attention Is All You Need](https://arxiv.org/abs/1706.03762) +- Llama 1: [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) +- Llama 2: [Llama 2: Open Foundation and Fine-Tuned Chat Models](https://arxiv.org/abs/2307.09288) +- Pre-normalization using RMSNorm: [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467) +- SwiGLU activation function: [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202) +- Rotary positional embeddings: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) +- Grouped-query attention: [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245v1) diff --git a/src/attention.zig b/src/attention.zig new file mode 100644 index 0000000..ba20a94 --- /dev/null +++ b/src/attention.zig @@ -0,0 +1,158 @@ +const std = @import("std"); + +const checkpoint = @import("checkpoint.zig"); +const lib = @import("lib.zig"); +const utils = @import("utils.zig"); + +pub const Attention = struct { + const Self = @This(); + + input_buffer: []f32, + output_buffer: []f32, + scores_buffer: []f32, + queries_buffer: []f32, + keys_buffer: []f32, + values_buffer: []f32, + key_cache: []f32, + value_cache: []f32, + + pub fn init(self: *Self, allocator: std.mem.Allocator, config: *const checkpoint.Config) !void { + const kv_dim = (config.dim * config.n_kv_heads) / config.n_heads; + + self.input_buffer = try allocator.alloc(f32, config.dim); + self.output_buffer = try allocator.alloc(f32, config.dim); + self.scores_buffer = try allocator.alloc(f32, config.n_heads * config.seq_len); + self.queries_buffer = try allocator.alloc(f32, config.dim); + self.keys_buffer = try allocator.alloc(f32, kv_dim); + self.values_buffer = try allocator.alloc(f32, kv_dim); + self.key_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim); + self.value_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim); + } + + pub fn deinit(self: *const Self, allocator: std.mem.Allocator) void { + allocator.free(self.input_buffer); + allocator.free(self.output_buffer); + allocator.free(self.scores_buffer); + allocator.free(self.queries_buffer); + allocator.free(self.keys_buffer); + allocator.free(self.values_buffer); + allocator.free(self.key_cache); + allocator.free(self.value_cache); + } + + pub fn forward( + self: *const Self, + 
config: *const checkpoint.Config, + weights: *const checkpoint.Weights, + pos: usize, + layer: usize, + ) !void { + @setFloatMode(.Optimized); + + const dim = config.dim; + const n_heads = config.n_heads; + const seq_len = config.seq_len; + const kv_dim = self.keys_buffer.len; + const query_weights_dim = dim * dim; + const kv_weights_dim = dim * kv_dim; + + try lib.matmul3( + .{ + self.queries_buffer, + self.input_buffer, + weights.query[(layer * query_weights_dim)..][0..query_weights_dim], + }, + .{ + self.keys_buffer, + self.input_buffer, + weights.key[(layer * kv_weights_dim)..][0..kv_weights_dim], + }, + .{ + self.values_buffer, + self.input_buffer, + weights.value[(layer * kv_weights_dim)..][0..kv_weights_dim], + }, + dim >= 4096, + ); + + const head_size = dim / n_heads; + + lib.rope(pos, head_size, self.queries_buffer, self.keys_buffer); + + const kv_cache_dim = seq_len * kv_dim; + const kv_cache_offset = layer * kv_cache_dim; + + @memcpy( + self.key_cache[(kv_cache_offset + pos * kv_dim)..][0..self.keys_buffer.len], + self.keys_buffer, + ); + + @memcpy( + self.value_cache[(kv_cache_offset + pos * kv_dim)..][0..self.values_buffer.len], + self.values_buffer, + ); + + for (0..n_heads) |query_head| { + self.compute_attention(query_head, head_size, config, pos, kv_cache_offset, kv_dim); + } + + lib.matmul( + self.output_buffer, + self.input_buffer, + weights.attention_output[(layer * dim * dim)..][0..(dim * dim)], + ); + } + + fn compute_attention( + self: *const Self, + query_head: usize, + head_size: usize, + config: *const checkpoint.Config, + current_position: usize, + kv_cache_offset: usize, + kv_dim: usize, + ) void { + const n_groups = config.n_heads / config.n_kv_heads; + const head_size_sqrt = std.math.sqrt(@as(f32, @floatFromInt(head_size))); + const query_head_offset = query_head * head_size; + const query_head_group = query_head / n_groups; + const key_value_head_offset = query_head_group * head_size; + + // get the query vector for this head + const query = self.queries_buffer[query_head_offset..][0..head_size]; + + // attention scores for this head + const attention_weights = self.scores_buffer[(query_head * config.seq_len)..]; + + // iterate over all timesteps, including the current one + for (0..(current_position + 1)) |position| { + // get the key vector for this head and at this timestep + const key = self.key_cache[(kv_cache_offset + position * kv_dim + key_value_head_offset)..][0..head_size]; + + // calculate the attention score as the dot product of q and k + // save the score to the attention buffer + attention_weights[position] = lib.dotProduct(query, key) / head_size_sqrt; + } + + // softmax the scores to get attention weights, from 0..pos inclusively + utils.softmax(attention_weights[0..(current_position + 1)]); + + // weighted sum of the values, store back into intermediate_buffer + const intermediate_buffer = self.input_buffer[query_head_offset..][0..head_size]; + + @memset(intermediate_buffer, 0); + + for (0..(current_position + 1)) |position| { + // get the value vector for this head and at this timestep + const value = self.value_cache[(kv_cache_offset + position * kv_dim + key_value_head_offset)..]; + + // get the attention weight for this timestep + const attention_weight = attention_weights[position]; + + // accumulate the weighted value into intermediate_buffer + for (0..head_size) |i| { + intermediate_buffer[i] += attention_weight * value[i]; + } + } + } +}; diff --git a/src/checkpoint.zig b/src/checkpoint.zig index 20e9d0b..ea4f02b 100644 --- 
a/src/checkpoint.zig +++ b/src/checkpoint.zig @@ -84,10 +84,12 @@ pub fn readFile( weights.* = Weights{ .token_embedding = token_embedding, .rms_attention_input = readFloatSlice(&weights_data, config.n_layers * config.dim), + .query = readFloatSlice(&weights_data, config.n_layers * config.dim * (config.n_heads * head_size)), .key = readFloatSlice(&weights_data, config.n_layers * config.dim * (config.n_kv_heads * head_size)), .value = readFloatSlice(&weights_data, config.n_layers * config.dim * (config.n_kv_heads * head_size)), .attention_output = readFloatSlice(&weights_data, config.n_layers * (config.n_heads * head_size) * config.dim), + .rms_ffn_input = readFloatSlice(&weights_data, config.n_layers * config.dim), .ffn_input_to_hidden = readFloatSlice(&weights_data, config.n_layers * config.dim * config.hidden_dim), .ffn_hidden_to_output = readFloatSlice(&weights_data, config.n_layers * config.hidden_dim * config.dim), diff --git a/src/feed_forward.zig b/src/feed_forward.zig index d3f4e3d..9bc190c 100644 --- a/src/feed_forward.zig +++ b/src/feed_forward.zig @@ -1,6 +1,7 @@ const std = @import("std"); const checkpoint = @import("checkpoint.zig"); +const lib = @import("lib.zig"); const utils = @import("utils.zig"); pub const FeedForward = struct { @@ -32,37 +33,31 @@ pub const FeedForward = struct { ) !void { @setFloatMode(.Optimized); - const input_buffer = self.input_buffer; - const hidden_buffer = self.hidden_buffer; - const residual_buffer = self.residual_buffer; - const output_buffer = self.output_buffer; + const dim = self.input_buffer.len; + const hidden_dim = self.hidden_buffer.len; - std.debug.assert(input_buffer.len == output_buffer.len); - std.debug.assert(hidden_buffer.len == residual_buffer.len); - - const dim = input_buffer.len; - const hidden_dim = hidden_buffer.len; const weights_size = dim * hidden_dim; const weights_offset = layer * weights_size; + const input_to_hidden = weights.ffn_input_to_hidden[weights_offset..][0..weights_size]; const input_to_residual = weights.ffn_input_to_residual[weights_offset..][0..weights_size]; const hidden_to_output = weights.ffn_hidden_to_output[weights_offset..][0..weights_size]; - try matmul2( - .{ hidden_buffer, input_buffer, input_to_hidden }, - .{ residual_buffer, input_buffer, input_to_residual }, + try lib.matmul2( + .{ self.hidden_buffer, self.input_buffer, input_to_hidden }, + .{ self.residual_buffer, self.input_buffer, input_to_residual }, dim >= 4096, ); for (0..hidden_dim) |i| { - hidden_buffer[i] = silu(hidden_buffer[i]) * residual_buffer[i]; + self.hidden_buffer[i] = silu(self.hidden_buffer[i]) * self.residual_buffer[i]; } - utils.matmul(output_buffer, hidden_buffer, hidden_to_output); + lib.matmul(self.output_buffer, self.hidden_buffer, hidden_to_output); } }; -// https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html +// GLU Variants Improve Transformer (https://arxiv.org/abs/2002.05202) inline fn silu(x: f32) f32 { return x * sigmoid(x); } @@ -70,18 +65,3 @@ inline fn silu(x: f32) f32 { inline fn sigmoid(x: f32) f32 { return 1 / (1 + @exp(-x)); } - -fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void { - const cpu_count = std.Thread.getCpuCount() catch 1; - - if (multi_threaded and cpu_count > 2) { - const thread_1 = try std.Thread.spawn(.{}, utils.matmul, args_1); - const thread_2 = try std.Thread.spawn(.{}, utils.matmul, args_2); - - thread_1.join(); - thread_2.join(); - } else { - @call(.auto, utils.matmul, args_1); - @call(.auto, utils.matmul, args_2); - } -} diff --git a/src/lib.zig 
b/src/lib.zig new file mode 100644 index 0000000..f927977 --- /dev/null +++ b/src/lib.zig @@ -0,0 +1,7 @@ +const linear_algebra = @import("lib/linear_algebra.zig"); + +pub const dotProduct = linear_algebra.dotProduct; +pub const matmul = linear_algebra.matmul; +pub const matmul2 = linear_algebra.matmul2; +pub const matmul3 = linear_algebra.matmul3; +pub const rope = @import("lib/rope.zig").rope; diff --git a/src/lib/linear_algebra.zig b/src/lib/linear_algebra.zig new file mode 100644 index 0000000..1830db5 --- /dev/null +++ b/src/lib/linear_algebra.zig @@ -0,0 +1,82 @@ +const std = @import("std"); + +const max_vector_len: comptime_int = 16; +const min_vector_len: comptime_int = 4; + +pub fn dotProduct(a: []const f32, b: []const f32) f32 { + @setFloatMode(.Optimized); + + std.debug.assert(a.len == b.len); + + const rest_len = a.len % max_vector_len; + + std.debug.assert(rest_len % min_vector_len == 0); + + var buffer_1: @Vector(max_vector_len, f32) = @splat(0.0); + var index: usize = 0; + + while (index < a.len - rest_len) : (index += max_vector_len) { + buffer_1 += + @as(@Vector(max_vector_len, f32), a[index..][0..max_vector_len].*) * + @as(@Vector(max_vector_len, f32), b[index..][0..max_vector_len].*); + } + + var result = @reduce(.Add, buffer_1); + + if (rest_len > 0) { + var buffer_2: @Vector(min_vector_len, f32) = @splat(0.0); + + index = a.len - rest_len; + + while (index < a.len) : (index += min_vector_len) { + buffer_2 += + @as(@Vector(min_vector_len, f32), a[index..][0..min_vector_len].*) * + @as(@Vector(min_vector_len, f32), b[index..][0..min_vector_len].*); + } + + result += @reduce(.Add, buffer_2); + } + + return result; +} + +pub fn matmul(result: []f32, a: []const f32, b: []const f32) void { + std.debug.assert(b.len >= result.len * a.len); // TODO: enforce == instead of >= + + for (result, 0..) 
|*entry, i| { + entry.* = dotProduct(a, b[(i * a.len)..][0..a.len]); + } +} + +pub fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void { + const cpu_count = std.Thread.getCpuCount() catch 1; + + if (multi_threaded and cpu_count > 2) { + const thread_1 = try std.Thread.spawn(.{}, matmul, args_1); + const thread_2 = try std.Thread.spawn(.{}, matmul, args_2); + + thread_1.join(); + thread_2.join(); + } else { + @call(.auto, matmul, args_1); + @call(.auto, matmul, args_2); + } +} + +pub fn matmul3(args_1: anytype, args_2: anytype, args_3: anytype, multi_threaded: bool) !void { + const cpu_count = std.Thread.getCpuCount() catch 1; + + if (multi_threaded and cpu_count > 3) { + const thread_1 = try std.Thread.spawn(.{}, matmul, args_1); + const thread_2 = try std.Thread.spawn(.{}, matmul, args_2); + const thread_3 = try std.Thread.spawn(.{}, matmul, args_3); + + thread_1.join(); + thread_2.join(); + thread_3.join(); + } else { + @call(.auto, matmul, args_1); + @call(.auto, matmul, args_2); + @call(.auto, matmul, args_3); + } +} diff --git a/src/lib/rope.zig b/src/lib/rope.zig new file mode 100644 index 0000000..11acb66 --- /dev/null +++ b/src/lib/rope.zig @@ -0,0 +1,40 @@ +const std = @import("std"); + +// RoFormer: Enhanced Transformer with Rotary Position Embedding (https://arxiv.org/abs/2104.09864) +pub fn rope( + pos: usize, + head_size: usize, + queries_buffer: []f32, + keys_buffer: []f32, +) void { + @setFloatMode(.Optimized); + + std.debug.assert(keys_buffer.len <= queries_buffer.len); + + var index: usize = 0; + + while (index < queries_buffer.len) : (index += 2) { + const head_index: f32 = @floatFromInt(index % head_size); + + const frequency: f32 = + 1 / std.math.pow(f32, 10000, head_index / @as(f32, @floatFromInt(head_size))); + + const rotation_scaling_factor: f32 = @as(f32, @floatFromInt(pos)) * frequency; + const real_rotation_value: f32 = std.math.cos(rotation_scaling_factor); + const imag_rotation_value: f32 = std.math.sin(rotation_scaling_factor); + + const query_0 = queries_buffer[index]; + const query_1 = queries_buffer[index + 1]; + + queries_buffer[index] = query_0 * real_rotation_value - query_1 * imag_rotation_value; + queries_buffer[index + 1] = query_0 * imag_rotation_value + query_1 * real_rotation_value; + + if (index < keys_buffer.len) { + const key_0 = keys_buffer[index]; + const key_1 = keys_buffer[index + 1]; + + keys_buffer[index] = key_0 * real_rotation_value - key_1 * imag_rotation_value; + keys_buffer[index + 1] = key_0 * imag_rotation_value + key_1 * real_rotation_value; + } + } +} diff --git a/src/main.zig b/src/main.zig index 2a0721a..7de889c 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,5 +1,6 @@ const std = @import("std"); +const Attention = @import("attention.zig").Attention; const checkpoint = @import("checkpoint.zig"); const cli = @import("cli.zig"); const FeedForward = @import("feed_forward.zig").FeedForward; @@ -67,16 +68,22 @@ pub fn main() !void { var total_decoding_time: i64 = 0; var total_sampling_time: i64 = 0; + var attention: Attention = undefined; + + try attention.init(allocator, &config); + defer attention.deinit(allocator); + var feed_forward: FeedForward = undefined; try feed_forward.init(allocator, &config); + defer feed_forward.deinit(allocator); // advance the state state machine for (0..args.n_steps) |pos| { start_time = std.time.milliTimestamp(); // forward the transformer to get logits for the next token - try transformer.decode(allocator, token, pos, config, &run_state, &weights, &feed_forward); + try 
transformer.decode(token, pos, config, &run_state, &weights, &attention, &feed_forward); if (pos == 0) { first_decoding_time = std.time.milliTimestamp() - start_time; diff --git a/src/transformer.zig b/src/transformer.zig index c24be92..01d80a0 100644 --- a/src/transformer.zig +++ b/src/transformer.zig @@ -1,19 +1,13 @@ const std = @import("std"); +const Attention = @import("attention.zig").Attention; const checkpoint = @import("checkpoint.zig"); const FeedForward = @import("feed_forward.zig").FeedForward; +const lib = @import("lib.zig"); const utils = @import("utils.zig"); pub const RunState = struct { hidden_state: []f32, - attention_input_buffer: []f32, - attention_output_buffer: []f32, - attention_scores: []f32, - query_buffer: []f32, - key_buffer: []f32, - value_buffer: []f32, - key_cache: []f32, - value_cache: []f32, logits: []f32, }; @@ -22,32 +16,19 @@ pub fn allocRunState( config: checkpoint.Config, run_state: *RunState, ) !void { - const kv_dim = (config.dim * config.n_kv_heads) / config.n_heads; - run_state.* = RunState{ .hidden_state = try allocator.alloc(f32, config.dim), - - .attention_input_buffer = try allocator.alloc(f32, config.dim), - .attention_output_buffer = try allocator.alloc(f32, config.dim), - - .attention_scores = try allocator.alloc(f32, config.n_heads * config.seq_len), - .query_buffer = try allocator.alloc(f32, config.dim), - .key_buffer = try allocator.alloc(f32, kv_dim), - .value_buffer = try allocator.alloc(f32, kv_dim), - .key_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim), - .value_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim), - .logits = try allocator.alloc(f32, config.vocab_size), }; } pub fn decode( - allocator: std.mem.Allocator, token: usize, pos: usize, config: checkpoint.Config, run_state: *RunState, weights: *const checkpoint.Weights, + attention: *Attention, feed_forward: *FeedForward, ) !void { @setFloatMode(.Optimized); @@ -58,138 +39,19 @@ pub fn decode( weights.token_embedding[(token * config.dim)..][0..run_state.hidden_state.len], ); - const kv_dim = (config.dim * config.n_kv_heads) / config.n_heads; - // integer multiplier of the kv sharing in multiquery - const kv_mul = config.n_heads / config.n_kv_heads; - const head_size = config.dim / config.n_heads; - const head_size_sqrt = std.math.sqrt(@as(f32, @floatFromInt(head_size))); - // forward all the layers for (0..config.n_layers) |layer| { // attention rmsnorm utils.rmsnorm( - run_state.attention_input_buffer, + attention.input_buffer, run_state.hidden_state, weights.rms_attention_input[(layer * config.dim)..], ); - const dim_multithreading_threshold = 4096; - - var pool: std.Thread.Pool = undefined; - - // qkv matmuls for this position - if (config.dim >= dim_multithreading_threshold) { - try pool.init(std.Thread.Pool.Options{ - .allocator = allocator, - .n_jobs = @max(1, @min(3, std.Thread.getCpuCount() catch 1)), - }); - - defer pool.deinit(); - - try pool.spawn(utils.matmul, .{ - run_state.query_buffer, - run_state.attention_input_buffer, - weights.query[(layer * config.dim * config.dim)..], - }); - - try pool.spawn(utils.matmul, .{ - run_state.key_buffer, - run_state.attention_input_buffer, - weights.key[(layer * config.dim * kv_dim)..], - }); - - try pool.spawn(utils.matmul, .{ - run_state.value_buffer, - run_state.attention_input_buffer, - weights.value[(layer * config.dim * kv_dim)..], - }); - } else { - utils.matmul( - run_state.query_buffer, - run_state.attention_input_buffer, - weights.query[(layer * config.dim * 
config.dim)..], - ); - - utils.matmul( - run_state.key_buffer, - run_state.attention_input_buffer, - weights.key[(layer * config.dim * kv_dim)..], - ); - - utils.matmul( - run_state.value_buffer, - run_state.attention_input_buffer, - weights.value[(layer * config.dim * kv_dim)..], - ); - } - - rope(pos, head_size, kv_dim, &config, run_state); - - // save key,value at this time step (pos) to our kv cache - const loff = layer * config.seq_len * kv_dim; // kv cache layer offset for convenience - const key_cache_row = run_state.key_cache[(loff + pos * kv_dim)..]; - const value_cache_row = run_state.value_cache[(loff + pos * kv_dim)..]; - - @memcpy(key_cache_row[0..run_state.key_buffer.len], run_state.key_buffer); - @memcpy(value_cache_row[0..run_state.value_buffer.len], run_state.value_buffer); - - // multihead attention. iterate over all heads - for (0..config.n_heads) |head| { - // get the query vector for this head - const q = run_state.query_buffer[(head * head_size)..]; - // attention scores for this head - const att = run_state.attention_scores[(head * config.seq_len)..]; - - // iterate over all timesteps, including the current one - for (0..(pos + 1)) |t| { - // get the key vector for this head and at this timestep - const k = run_state.key_cache[(loff + t * kv_dim + (head / kv_mul) * head_size)..]; - - // calculate the attention score as the dot product of q and k - var score: f32 = 0; - - for (0..head_size) |i| { - score += q[i] * k[i]; - } - - score /= head_size_sqrt; - - // save the score to the attention buffer - att[t] = score; - } - - // softmax the scores to get attention weights, from 0..pos inclusively - utils.softmax(att[0..(pos + 1)]); - - // weighted sum of the values, store back into intermediate_buffer - const intermediate_buffer = - run_state.attention_input_buffer[(head * head_size)..][0..head_size]; - - @memset(intermediate_buffer, 0); - - for (0..(pos + 1)) |t| { - // get the value vector for this head and at this timestep - const v = run_state.value_cache[(loff + t * kv_dim + (head / kv_mul) * head_size)..]; - - // get the attention weight for this timestep - const a = att[t]; - - // accumulate the weighted value into intermediate_buffer - for (0..head_size) |i| { - intermediate_buffer[i] += a * v[i]; - } - } - } - - // final matmul to get the output of the attention - utils.matmul( - run_state.attention_output_buffer, - run_state.attention_input_buffer, - weights.attention_output[(layer * config.dim * config.dim)..], - ); + try attention.forward(&config, weights, pos, layer); // residual connection back into hidden_state - utils.accum(run_state.hidden_state, run_state.attention_output_buffer); + utils.accum(run_state.hidden_state, attention.output_buffer); // ffn rmsnorm utils.rmsnorm( @@ -208,44 +70,5 @@ pub fn decode( utils.rmsnorm(run_state.hidden_state, run_state.hidden_state, weights.rms_final); // classifier into logits - utils.matmul(run_state.logits, run_state.hidden_state, weights.classifier); -} - -fn rope( - pos: usize, - head_size: usize, - kv_dim: usize, - config: *const checkpoint.Config, - run_state: *RunState, -) void { - @setFloatMode(.Optimized); - - var i: usize = 0; - - // RoPE relative positional encoding: complex-valued rotate q and k in each head - // https://github.com/karpathy/llama2.c/issues/302#issue-1851956882 - // https://github.com/karpathy/llama2.c/commit/bd182289c596fa6059eb7b3b7c8ccd04b5c90fc3 - while (i < config.dim) : (i += 2) { - const head_dim: f32 = @floatFromInt(i % head_size); - const freq: f32 = 1 / std.math.pow(f32, 10000, 
head_dim / @as(f32, @floatFromInt(head_size))); - const value: f32 = @as(f32, @floatFromInt(pos)) * freq; - const fcr: f32 = std.math.cos(value); - const fci: f32 = std.math.sin(value); - - // rotate q - const q0 = run_state.query_buffer[i]; - const q1 = run_state.query_buffer[i + 1]; - - run_state.query_buffer[i] = q0 * fcr - q1 * fci; - run_state.query_buffer[i + 1] = q0 * fci + q1 * fcr; - - // rotate k - if (i < kv_dim) { - const k0 = run_state.key_buffer[i]; - const k1 = run_state.key_buffer[i + 1]; - - run_state.key_buffer[i] = k0 * fcr - k1 * fci; - run_state.key_buffer[i + 1] = k0 * fci + k1 * fcr; - } - } + lib.matmul(run_state.logits, run_state.hidden_state, weights.classifier); } diff --git a/src/utils.zig b/src/utils.zig index fef78c1..4d589de 100644 --- a/src/utils.zig +++ b/src/utils.zig @@ -106,54 +106,6 @@ fn lessThan(context: void, lhs: ProbIndex, rhs: ProbIndex) bool { return rhs.prob < lhs.prob; } -pub fn matmul(result: []f32, a: []const f32, b: []const f32) void { - std.debug.assert(b.len >= result.len * a.len); - - for (result, 0..) |*scalar, i| { - scalar.* = scalarProduct(a, b[(i * a.len)..][0..a.len]); - } -} - -inline fn scalarProduct(a: []const f32, b: []const f32) f32 { - @setFloatMode(.Optimized); - - const big_vector_len: comptime_int = 16; - const small_vector_len: comptime_int = 4; - - std.debug.assert(a.len == b.len); - - const rest_len = a.len % big_vector_len; - - std.debug.assert(rest_len % small_vector_len == 0); - - var big_accu: @Vector(big_vector_len, f32) = @splat(0.0); - var i: usize = 0; - - while (i < a.len - rest_len) : (i += big_vector_len) { - big_accu += - @as(@Vector(big_vector_len, f32), a[i..][0..big_vector_len].*) * - @as(@Vector(big_vector_len, f32), b[i..][0..big_vector_len].*); - } - - var scalar_product = @reduce(.Add, big_accu); - - if (rest_len > 0) { - var small_accu: @Vector(small_vector_len, f32) = @splat(0.0); - - i = a.len - rest_len; - - while (i < a.len) : (i += small_vector_len) { - small_accu += - @as(@Vector(small_vector_len, f32), a[i..][0..small_vector_len].*) * - @as(@Vector(small_vector_len, f32), b[i..][0..small_vector_len].*); - } - - scalar_product += @reduce(.Add, small_accu); - } - - return scalar_product; -} - pub fn softmax(x: []f32) void { @setFloatMode(.Optimized);
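
// --- Illustrative sketch (not part of this patch) ---------------------------
// A minimal, self-contained reference for the matmul convention assumed by
// lib.matmul in src/lib/linear_algebra.zig: `result[i]` is the dot product of
// `a` with row `i` of the row-major weight slice `b`. The naive helper and the
// test below are hypothetical illustrations, not the SIMD implementation from
// the patch.
const std = @import("std");

fn naiveMatmul(result: []f32, a: []const f32, b: []const f32) void {
    // Same precondition as lib.matmul: b must hold at least result.len rows of a.len.
    std.debug.assert(b.len >= result.len * a.len);

    for (result, 0..) |*entry, i| {
        var sum: f32 = 0;

        // Row i of the row-major weight matrix starts at offset i * a.len.
        for (a, b[(i * a.len)..][0..a.len]) |x, y| {
            sum += x * y;
        }

        entry.* = sum;
    }
}

test "row-major matmul convention" {
    // 2x3 weight matrix stored row-major: rows are [1 2 3] and [4 5 6].
    const weights = [_]f32{ 1, 2, 3, 4, 5, 6 };
    const input = [_]f32{ 1, 1, 1 };
    var output = [_]f32{ 0, 0 };

    naiveMatmul(&output, &input, &weights);

    try std.testing.expectEqual(@as(f32, 6), output[0]);
    try std.testing.expectEqual(@as(f32, 15), output[1]);
}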