Showing 10 changed files with 324 additions and 263 deletions.
@@ -0,0 +1,158 @@
const std = @import("std");

const checkpoint = @import("checkpoint.zig");
const lib = @import("lib.zig");
const utils = @import("utils.zig");

pub const Attention = struct {
    const Self = @This();

    input_buffer: []f32,
    output_buffer: []f32,
    scores_buffer: []f32,
    queries_buffer: []f32,
    keys_buffer: []f32,
    values_buffer: []f32,
    key_cache: []f32,
    value_cache: []f32,

    pub fn init(self: *Self, allocator: std.mem.Allocator, config: *const checkpoint.Config) !void {
        const kv_dim = (config.dim * config.n_kv_heads) / config.n_heads;

        self.input_buffer = try allocator.alloc(f32, config.dim);
        self.output_buffer = try allocator.alloc(f32, config.dim);
        self.scores_buffer = try allocator.alloc(f32, config.n_heads * config.seq_len);
        self.queries_buffer = try allocator.alloc(f32, config.dim);
        self.keys_buffer = try allocator.alloc(f32, kv_dim);
        self.values_buffer = try allocator.alloc(f32, kv_dim);

        // per-layer key/value cache covering the full sequence length
        self.key_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim);
        self.value_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim);
    }

    pub fn deinit(self: *const Self, allocator: std.mem.Allocator) void {
        allocator.free(self.input_buffer);
        allocator.free(self.output_buffer);
        allocator.free(self.scores_buffer);
        allocator.free(self.queries_buffer);
        allocator.free(self.keys_buffer);
        allocator.free(self.values_buffer);
        allocator.free(self.key_cache);
        allocator.free(self.value_cache);
    }

    pub fn forward(
        self: *const Self,
        config: *const checkpoint.Config,
        weights: *const checkpoint.Weights,
        pos: usize,
        layer: usize,
    ) !void {
        @setFloatMode(.Optimized);

        const dim = config.dim;
        const n_heads = config.n_heads;
        const seq_len = config.seq_len;
        const kv_dim = self.keys_buffer.len;
        const query_weights_dim = dim * dim;
        const kv_weights_dim = dim * kv_dim;

        // project the input into query, key and value vectors for this layer,
        // optionally running the three matmuls on separate threads
        try lib.matmul3(
            .{
                self.queries_buffer,
                self.input_buffer,
                weights.query[(layer * query_weights_dim)..][0..query_weights_dim],
            },
            .{
                self.keys_buffer,
                self.input_buffer,
                weights.key[(layer * kv_weights_dim)..][0..kv_weights_dim],
            },
            .{
                self.values_buffer,
                self.input_buffer,
                weights.value[(layer * kv_weights_dim)..][0..kv_weights_dim],
            },
            dim >= 4096,
        );

        const head_size = dim / n_heads;

        // apply rotary position embeddings to queries and keys in place
        lib.rope(pos, head_size, self.queries_buffer, self.keys_buffer);

        const kv_cache_dim = seq_len * kv_dim;
        const kv_cache_offset = layer * kv_cache_dim;

        // store the freshly computed keys and values in the cache slot for this position
        @memcpy(
            self.key_cache[(kv_cache_offset + pos * kv_dim)..][0..self.keys_buffer.len],
            self.keys_buffer,
        );

        @memcpy(
            self.value_cache[(kv_cache_offset + pos * kv_dim)..][0..self.values_buffer.len],
            self.values_buffer,
        );

        // scaled dot-product attention, one query head at a time
        for (0..n_heads) |query_head| {
            self.compute_attention(query_head, head_size, config, pos, kv_cache_offset, kv_dim);
        }

        // final output projection of the concatenated head outputs
        lib.matmul(
            self.output_buffer,
            self.input_buffer,
            weights.attention_output[(layer * dim * dim)..][0..(dim * dim)],
        );
    }

    fn compute_attention(
        self: *const Self,
        query_head: usize,
        head_size: usize,
        config: *const checkpoint.Config,
        current_position: usize,
        kv_cache_offset: usize,
        kv_dim: usize,
    ) void {
        const n_groups = config.n_heads / config.n_kv_heads;
        const head_size_sqrt = std.math.sqrt(@as(f32, @floatFromInt(head_size)));
        const query_head_offset = query_head * head_size;
        const query_head_group = query_head / n_groups;
        const key_value_head_offset = query_head_group * head_size;

        // get the query vector for this head
        const query = self.queries_buffer[query_head_offset..][0..head_size];

        // attention scores for this head
        const attention_weights = self.scores_buffer[(query_head * config.seq_len)..];

        // iterate over all timesteps, including the current one
        for (0..(current_position + 1)) |position| {
            // get the key vector for this head and at this timestep
            const key = self.key_cache[(kv_cache_offset + position * kv_dim + key_value_head_offset)..][0..head_size];

            // calculate the attention score as the dot product of q and k
            // save the score to the attention buffer
            attention_weights[position] = lib.dotProduct(query, key) / head_size_sqrt;
        }

        // softmax the scores to get attention weights, from 0..pos inclusively
        utils.softmax(attention_weights[0..(current_position + 1)]);

        // weighted sum of the values, store back into intermediate_buffer
        const intermediate_buffer = self.input_buffer[query_head_offset..][0..head_size];

        @memset(intermediate_buffer, 0);

        for (0..(current_position + 1)) |position| {
            // get the value vector for this head and at this timestep
            const value = self.value_cache[(kv_cache_offset + position * kv_dim + key_value_head_offset)..];

            // get the attention weight for this timestep
            const attention_weight = attention_weights[position];

            // accumulate the weighted value into intermediate_buffer
            for (0..head_size) |i| {
                intermediate_buffer[i] += attention_weight * value[i];
            }
        }
    }
};
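For reference, compute_attention above is standard scaled dot-product attention with grouped-query sharing: query head h reads the key/value head of its group (h / n_groups), scores each cached timestep t as score[t] = dot(q, key[t]) / sqrt(head_size), softmaxes the scores for timesteps 0..pos into weights a[t], and writes the weighted sum over the cached values, sum over t of a[t] * value[t], back into this head's slice of input_buffer.

utils.zig is not among the files shown on this page, so the softmax called above is not visible. Below is a minimal in-place sketch of what utils.softmax is assumed to do; the signature is inferred from the call site and is an assumption, not the repository's actual code.

const std = @import("std");

// Hypothetical stand-in for utils.softmax: numerically stable softmax over a slice, in place.
pub fn softmax(values: []f32) void {
    std.debug.assert(values.len > 0);

    // subtract the maximum before exponentiating to avoid overflow
    var max_value = values[0];

    for (values[1..]) |value| {
        max_value = @max(max_value, value);
    }

    var sum: f32 = 0;

    for (values) |*value| {
        value.* = std.math.exp(value.* - max_value);
        sum += value.*;
    }

    // normalize so the weights sum to 1
    for (values) |*value| {
        value.* /= sum;
    }
}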
@@ -0,0 +1,7 @@
const linear_algebra = @import("lib/linear_algebra.zig");

pub const dotProduct = linear_algebra.dotProduct;
pub const matmul = linear_algebra.matmul;
pub const matmul2 = linear_algebra.matmul2;
pub const matmul3 = linear_algebra.matmul3;
pub const rope = @import("lib/rope.zig").rope;
@@ -0,0 +1,82 @@
const std = @import("std");

const max_vector_len: comptime_int = 16;
const min_vector_len: comptime_int = 4;

pub fn dotProduct(a: []const f32, b: []const f32) f32 {
    @setFloatMode(.Optimized);

    std.debug.assert(a.len == b.len);

    const rest_len = a.len % max_vector_len;

    std.debug.assert(rest_len % min_vector_len == 0);

    var buffer_1: @Vector(max_vector_len, f32) = @splat(0.0);
    var index: usize = 0;

    // main loop: accumulate 16-wide SIMD products
    while (index < a.len - rest_len) : (index += max_vector_len) {
        buffer_1 +=
            @as(@Vector(max_vector_len, f32), a[index..][0..max_vector_len].*) *
            @as(@Vector(max_vector_len, f32), b[index..][0..max_vector_len].*);
    }

    var result = @reduce(.Add, buffer_1);

    // tail loop: handle the remaining elements 4 at a time
    if (rest_len > 0) {
        var buffer_2: @Vector(min_vector_len, f32) = @splat(0.0);

        index = a.len - rest_len;

        while (index < a.len) : (index += min_vector_len) {
            buffer_2 +=
                @as(@Vector(min_vector_len, f32), a[index..][0..min_vector_len].*) *
                @as(@Vector(min_vector_len, f32), b[index..][0..min_vector_len].*);
        }

        result += @reduce(.Add, buffer_2);
    }

    return result;
}

pub fn matmul(result: []f32, a: []const f32, b: []const f32) void {
    std.debug.assert(b.len >= result.len * a.len); // TODO: enforce == instead of >=

    for (result, 0..) |*entry, i| {
        entry.* = dotProduct(a, b[(i * a.len)..][0..a.len]);
    }
}

// Runs two independent matmuls, on separate threads when enough cores are available.
pub fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void {
    const cpu_count = std.Thread.getCpuCount() catch 1;

    if (multi_threaded and cpu_count > 2) {
        const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
        const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);

        thread_1.join();
        thread_2.join();
    } else {
        @call(.auto, matmul, args_1);
        @call(.auto, matmul, args_2);
    }
}

// Same as matmul2, but for three independent matmuls.
pub fn matmul3(args_1: anytype, args_2: anytype, args_3: anytype, multi_threaded: bool) !void {
    const cpu_count = std.Thread.getCpuCount() catch 1;

    if (multi_threaded and cpu_count > 3) {
        const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
        const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);
        const thread_3 = try std.Thread.spawn(.{}, matmul, args_3);

        thread_1.join();
        thread_2.join();
        thread_3.join();
    } else {
        @call(.auto, matmul, args_1);
        @call(.auto, matmul, args_2);
        @call(.auto, matmul, args_3);
    }
}
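matmul treats b as a row-major matrix with result.len rows of a.len columns each, and thanks to the tail loop dotProduct only requires each row length to be a multiple of min_vector_len (4), not of 16. Below is a minimal test sketch of that layout convention; the test file itself is hypothetical (not part of this commit) and is assumed to sit in the same lib/ directory as linear_algebra.zig.

const std = @import("std");
const linear_algebra = @import("linear_algebra.zig");

test "matmul multiplies a 2x4 row-major matrix by a length-4 vector" {
    // a.len = 4 satisfies dotProduct's assertion that the tail is a multiple of 4
    const a = [_]f32{ 1, 2, 3, 4 };
    const b = [_]f32{
        1, 0, 0, 0,
        0, 2, 0, 0,
    };
    var result = [_]f32{ 0, 0 };

    linear_algebra.matmul(&result, &a, &b);

    try std.testing.expectEqual(@as(f32, 1), result[0]);
    try std.testing.expectEqual(@as(f32, 4), result[1]);
}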
@@ -0,0 +1,40 @@
const std = @import("std");

// RoFormer: Enhanced Transformer with Rotary Position Embedding (https://arxiv.org/abs/2104.09864)
pub fn rope(
    pos: usize,
    head_size: usize,
    queries_buffer: []f32,
    keys_buffer: []f32,
) void {
    @setFloatMode(.Optimized);

    std.debug.assert(keys_buffer.len <= queries_buffer.len);

    var index: usize = 0;

    while (index < queries_buffer.len) : (index += 2) {
        const head_index: f32 = @floatFromInt(index % head_size);

        const frequency: f32 =
            1 / std.math.pow(f32, 10000, head_index / @as(f32, @floatFromInt(head_size)));

        const rotation_scaling_factor: f32 = @as(f32, @floatFromInt(pos)) * frequency;
        const real_rotation_value: f32 = std.math.cos(rotation_scaling_factor);
        const imag_rotation_value: f32 = std.math.sin(rotation_scaling_factor);

        const query_0 = queries_buffer[index];
        const query_1 = queries_buffer[index + 1];

        queries_buffer[index] = query_0 * real_rotation_value - query_1 * imag_rotation_value;
        queries_buffer[index + 1] = query_0 * imag_rotation_value + query_1 * real_rotation_value;

        if (index < keys_buffer.len) {
            const key_0 = keys_buffer[index];
            const key_1 = keys_buffer[index + 1];

            keys_buffer[index] = key_0 * real_rotation_value - key_1 * imag_rotation_value;
            keys_buffer[index + 1] = key_0 * imag_rotation_value + key_1 * real_rotation_value;
        }
    }
}
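For reference, the loop above rotates each consecutive (even, odd) pair of entries by a position-dependent angle. With i = index % head_size, the angle is

    theta = pos / 10000^(i / head_size)

and each pair is rotated as

    x[index]     -> x[index] * cos(theta) - x[index + 1] * sin(theta)
    x[index + 1] -> x[index] * sin(theta) + x[index + 1] * cos(theta)

Queries and keys are rotated with the same angles; because keys_buffer is shorter than queries_buffer when n_kv_heads < n_heads, the key update is guarded by the index < keys_buffer.len check.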