Attention (WIP refactoring)
clebert committed Aug 22, 2023
1 parent 92e2602 commit e8789a3
Showing 10 changed files with 324 additions and 263 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -19,3 +19,13 @@ Some deviations from the original include:
- Utilization of slices instead of many-item pointers
- For models of 4096+ dimensions, thread pools are utilized to parallelize independent matrix
multiplications

## Papers

- Standard transformer architecture: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- Llama 1: [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)
- Llama 2: [Llama 2: Open Foundation and Fine-Tuned Chat Models](https://arxiv.org/abs/2307.09288)
- Pre-normalization using RMSNorm: [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)
- SwiGLU activation function: [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)
- Rotary positional embeddings: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
- Grouped-query attention: [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245v1)
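The README context above mentions two deviations from the original llama2.c: slices instead of many-item pointers, and thread pools for independent matrix multiplications in large models. As a minimal, hypothetical sketch of the first point (the function names below are illustrative and not part of this repository), contrasting the two pointer styles:

```zig
const std = @import("std");

// llama2.c-style API: a raw many-item pointer plus a separately tracked length.
fn sumManyItemPointer(values: [*]const f32, len: usize) f32 {
    var total: f32 = 0;
    for (0..len) |i| total += values[i];
    return total;
}

// Slice-based API: the length travels with the pointer, and indexing is
// bounds-checked in safe build modes.
fn sumSlice(values: []const f32) f32 {
    var total: f32 = 0;
    for (values) |value| total += value;
    return total;
}

test "both forms agree" {
    const data = [_]f32{ 1, 2, 3, 4 };
    try std.testing.expectEqual(sumManyItemPointer(&data, data.len), sumSlice(&data));
}
```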
158 changes: 158 additions & 0 deletions src/attention.zig
@@ -0,0 +1,158 @@
const std = @import("std");

const checkpoint = @import("checkpoint.zig");
const lib = @import("lib.zig");
const utils = @import("utils.zig");

pub const Attention = struct {
const Self = @This();

input_buffer: []f32,
output_buffer: []f32,
scores_buffer: []f32,
queries_buffer: []f32,
keys_buffer: []f32,
values_buffer: []f32,
key_cache: []f32,
value_cache: []f32,

pub fn init(self: *Self, allocator: std.mem.Allocator, config: *const checkpoint.Config) !void {
const kv_dim = (config.dim * config.n_kv_heads) / config.n_heads;

self.input_buffer = try allocator.alloc(f32, config.dim);
self.output_buffer = try allocator.alloc(f32, config.dim);
self.scores_buffer = try allocator.alloc(f32, config.n_heads * config.seq_len);
self.queries_buffer = try allocator.alloc(f32, config.dim);
self.keys_buffer = try allocator.alloc(f32, kv_dim);
self.values_buffer = try allocator.alloc(f32, kv_dim);
self.key_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim);
self.value_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim);
}

pub fn deinit(self: *const Self, allocator: std.mem.Allocator) void {
allocator.free(self.input_buffer);
allocator.free(self.output_buffer);
allocator.free(self.scores_buffer);
allocator.free(self.queries_buffer);
allocator.free(self.keys_buffer);
allocator.free(self.values_buffer);
allocator.free(self.key_cache);
allocator.free(self.value_cache);
}

pub fn forward(
self: *const Self,
config: *const checkpoint.Config,
weights: *const checkpoint.Weights,
pos: usize,
layer: usize,
) !void {
@setFloatMode(.Optimized);

const dim = config.dim;
const n_heads = config.n_heads;
const seq_len = config.seq_len;
const kv_dim = self.keys_buffer.len;
const query_weights_dim = dim * dim;
const kv_weights_dim = dim * kv_dim;

try lib.matmul3(
.{
self.queries_buffer,
self.input_buffer,
weights.query[(layer * query_weights_dim)..][0..query_weights_dim],
},
.{
self.keys_buffer,
self.input_buffer,
weights.key[(layer * kv_weights_dim)..][0..kv_weights_dim],
},
.{
self.values_buffer,
self.input_buffer,
weights.value[(layer * kv_weights_dim)..][0..kv_weights_dim],
},
dim >= 4096,
);

const head_size = dim / n_heads;

lib.rope(pos, head_size, self.queries_buffer, self.keys_buffer);

const kv_cache_dim = seq_len * kv_dim;
const kv_cache_offset = layer * kv_cache_dim;

@memcpy(
self.key_cache[(kv_cache_offset + pos * kv_dim)..][0..self.keys_buffer.len],
self.keys_buffer,
);

@memcpy(
self.value_cache[(kv_cache_offset + pos * kv_dim)..][0..self.values_buffer.len],
self.values_buffer,
);

for (0..n_heads) |query_head| {
self.compute_attention(query_head, head_size, config, pos, kv_cache_offset, kv_dim);
}

lib.matmul(
self.output_buffer,
self.input_buffer,
weights.attention_output[(layer * dim * dim)..][0..(dim * dim)],
);
}

fn compute_attention(
self: *const Self,
query_head: usize,
head_size: usize,
config: *const checkpoint.Config,
current_position: usize,
kv_cache_offset: usize,
kv_dim: usize,
) void {
const n_groups = config.n_heads / config.n_kv_heads;
const head_size_sqrt = std.math.sqrt(@as(f32, @floatFromInt(head_size)));
const query_head_offset = query_head * head_size;
const query_head_group = query_head / n_groups;
const key_value_head_offset = query_head_group * head_size;

// get the query vector for this head
const query = self.queries_buffer[query_head_offset..][0..head_size];

// attention scores for this head
const attention_weights = self.scores_buffer[(query_head * config.seq_len)..];

// iterate over all timesteps, including the current one
for (0..(current_position + 1)) |position| {
// get the key vector for this head and at this timestep
const key = self.key_cache[(kv_cache_offset + position * kv_dim + key_value_head_offset)..][0..head_size];

// calculate the attention score as the dot product of q and k
// save the score to the attention buffer
attention_weights[position] = lib.dotProduct(query, key) / head_size_sqrt;
}

// softmax the scores over positions 0..pos (inclusive) to get attention weights
utils.softmax(attention_weights[0..(current_position + 1)]);

// accumulate the weighted sum of the values into intermediate_buffer (a slice of input_buffer)
const intermediate_buffer = self.input_buffer[query_head_offset..][0..head_size];

@memset(intermediate_buffer, 0);

for (0..(current_position + 1)) |position| {
// get the value vector for this head and at this timestep
const value = self.value_cache[(kv_cache_offset + position * kv_dim + key_value_head_offset)..];

// get the attention weight for this timestep
const attention_weight = attention_weights[position];

// accumulate the weighted value into intermediate_buffer
for (0..head_size) |i| {
intermediate_buffer[i] += attention_weight * value[i];
}
}
}
};
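For reference, `compute_attention` above follows the standard scaled dot-product attention from "Attention Is All You Need", combined with the grouped-query mapping from the GQA paper: query head `h` reads the key/value cache of head `g(h) = floor(h / (n_heads / n_kv_heads))`. A sketch of what the loop computes, in math form:

```latex
% One query head h attends to positions 0..t; with grouped-query attention,
% head h reads the KV cache of head g(h) = floor(h / (n_heads / n_kv_heads)).
\[
  s_{t,j} = \frac{q_t^{(h)} \cdot k_j^{(g(h))}}{\sqrt{d_{\text{head}}}}, \qquad
  a_{t,j} = \frac{e^{s_{t,j}}}{\sum_{j'=0}^{t} e^{s_{t,j'}}}, \qquad
  o_t^{(h)} = \sum_{j=0}^{t} a_{t,j}\, v_j^{(g(h))}
\]
```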
2 changes: 2 additions & 0 deletions src/checkpoint.zig
@@ -84,10 +84,12 @@ pub fn readFile(
weights.* = Weights{
.token_embedding = token_embedding,
.rms_attention_input = readFloatSlice(&weights_data, config.n_layers * config.dim),

.query = readFloatSlice(&weights_data, config.n_layers * config.dim * (config.n_heads * head_size)),
.key = readFloatSlice(&weights_data, config.n_layers * config.dim * (config.n_kv_heads * head_size)),
.value = readFloatSlice(&weights_data, config.n_layers * config.dim * (config.n_kv_heads * head_size)),
.attention_output = readFloatSlice(&weights_data, config.n_layers * (config.n_heads * head_size) * config.dim),

.rms_ffn_input = readFloatSlice(&weights_data, config.n_layers * config.dim),
.ffn_input_to_hidden = readFloatSlice(&weights_data, config.n_layers * config.dim * config.hidden_dim),
.ffn_hidden_to_output = readFloatSlice(&weights_data, config.n_layers * config.hidden_dim * config.dim),
40 changes: 10 additions & 30 deletions src/feed_forward.zig
@@ -1,6 +1,7 @@
const std = @import("std");

const checkpoint = @import("checkpoint.zig");
const lib = @import("lib.zig");
const utils = @import("utils.zig");

pub const FeedForward = struct {
@@ -32,56 +33,35 @@ pub const FeedForward = struct {
) !void {
@setFloatMode(.Optimized);

const input_buffer = self.input_buffer;
const hidden_buffer = self.hidden_buffer;
const residual_buffer = self.residual_buffer;
const output_buffer = self.output_buffer;
const dim = self.input_buffer.len;
const hidden_dim = self.hidden_buffer.len;

std.debug.assert(input_buffer.len == output_buffer.len);
std.debug.assert(hidden_buffer.len == residual_buffer.len);

const dim = input_buffer.len;
const hidden_dim = hidden_buffer.len;
const weights_size = dim * hidden_dim;
const weights_offset = layer * weights_size;

const input_to_hidden = weights.ffn_input_to_hidden[weights_offset..][0..weights_size];
const input_to_residual = weights.ffn_input_to_residual[weights_offset..][0..weights_size];
const hidden_to_output = weights.ffn_hidden_to_output[weights_offset..][0..weights_size];

try matmul2(
.{ hidden_buffer, input_buffer, input_to_hidden },
.{ residual_buffer, input_buffer, input_to_residual },
try lib.matmul2(
.{ self.hidden_buffer, self.input_buffer, input_to_hidden },
.{ self.residual_buffer, self.input_buffer, input_to_residual },
dim >= 4096,
);

for (0..hidden_dim) |i| {
hidden_buffer[i] = silu(hidden_buffer[i]) * residual_buffer[i];
self.hidden_buffer[i] = silu(self.hidden_buffer[i]) * self.residual_buffer[i];
}

utils.matmul(output_buffer, hidden_buffer, hidden_to_output);
lib.matmul(self.output_buffer, self.hidden_buffer, hidden_to_output);
}
};

// https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
// GLU Variants Improve Transformer (https://arxiv.org/abs/2002.05202)
inline fn silu(x: f32) f32 {
return x * sigmoid(x);
}

inline fn sigmoid(x: f32) f32 {
return 1 / (1 + @exp(-x));
}

fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void {
const cpu_count = std.Thread.getCpuCount() catch 1;

if (multi_threaded and cpu_count > 2) {
const thread_1 = try std.Thread.spawn(.{}, utils.matmul, args_1);
const thread_2 = try std.Thread.spawn(.{}, utils.matmul, args_2);

thread_1.join();
thread_2.join();
} else {
@call(.auto, utils.matmul, args_1);
@call(.auto, utils.matmul, args_2);
}
}
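The forward pass in this file implements the SwiGLU feed-forward block referenced in the README. A sketch of the computation, where $W_1$, $W_3$, and $W_2$ stand for the `ffn_input_to_hidden`, `ffn_input_to_residual`, and `ffn_hidden_to_output` weights respectively (this naming correspondence is my reading of the code, not stated in the commit):

```latex
% hidden = SiLU(x W1) * (x W3),  output = hidden W2
\[
  \mathrm{FFN}_{\text{SwiGLU}}(x) = \bigl(\mathrm{SiLU}(x W_1) \odot x W_3\bigr)\, W_2, \qquad
  \mathrm{SiLU}(z) = z\,\sigma(z) = \frac{z}{1 + e^{-z}}
\]
```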
7 changes: 7 additions & 0 deletions src/lib.zig
@@ -0,0 +1,7 @@
const linear_algebra = @import("lib/linear_algebra.zig");

pub const dotProduct = linear_algebra.dotProduct;
pub const matmul = linear_algebra.matmul;
pub const matmul2 = linear_algebra.matmul2;
pub const matmul3 = linear_algebra.matmul3;
pub const rope = @import("lib/rope.zig").rope;
82 changes: 82 additions & 0 deletions src/lib/linear_algebra.zig
@@ -0,0 +1,82 @@
const std = @import("std");

const max_vector_len: comptime_int = 16;
const min_vector_len: comptime_int = 4;

pub fn dotProduct(a: []const f32, b: []const f32) f32 {
@setFloatMode(.Optimized);

std.debug.assert(a.len == b.len);

const rest_len = a.len % max_vector_len;

std.debug.assert(rest_len % min_vector_len == 0);

var buffer_1: @Vector(max_vector_len, f32) = @splat(0.0);
var index: usize = 0;

while (index < a.len - rest_len) : (index += max_vector_len) {
buffer_1 +=
@as(@Vector(max_vector_len, f32), a[index..][0..max_vector_len].*) *
@as(@Vector(max_vector_len, f32), b[index..][0..max_vector_len].*);
}

var result = @reduce(.Add, buffer_1);

if (rest_len > 0) {
var buffer_2: @Vector(min_vector_len, f32) = @splat(0.0);

index = a.len - rest_len;

while (index < a.len) : (index += min_vector_len) {
buffer_2 +=
@as(@Vector(min_vector_len, f32), a[index..][0..min_vector_len].*) *
@as(@Vector(min_vector_len, f32), b[index..][0..min_vector_len].*);
}

result += @reduce(.Add, buffer_2);
}

return result;
}

pub fn matmul(result: []f32, a: []const f32, b: []const f32) void {
std.debug.assert(b.len >= result.len * a.len); // TODO: enforce == instead of >=

for (result, 0..) |*entry, i| {
entry.* = dotProduct(a, b[(i * a.len)..][0..a.len]);
}
}

pub fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void {
const cpu_count = std.Thread.getCpuCount() catch 1;

if (multi_threaded and cpu_count > 2) {
const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);

thread_1.join();
thread_2.join();
} else {
@call(.auto, matmul, args_1);
@call(.auto, matmul, args_2);
}
}

pub fn matmul3(args_1: anytype, args_2: anytype, args_3: anytype, multi_threaded: bool) !void {
const cpu_count = std.Thread.getCpuCount() catch 1;

if (multi_threaded and cpu_count > 3) {
const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);
const thread_3 = try std.Thread.spawn(.{}, matmul, args_3);

thread_1.join();
thread_2.join();
thread_3.join();
} else {
@call(.auto, matmul, args_1);
@call(.auto, matmul, args_2);
@call(.auto, matmul, args_3);
}
}
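A minimal usage sketch for the SIMD `dotProduct` above, written as a hypothetical test file placed next to `linear_algebra.zig` (not part of this commit). Input lengths must be multiples of 4, per the asserts in `dotProduct`:

```zig
const std = @import("std");
const la = @import("linear_algebra.zig");

test "dotProduct matches a scalar reference" {
    // 20 elements: one full 16-wide chunk plus a 4-wide remainder,
    // exercising both vector paths.
    var a: [20]f32 = undefined;
    var b: [20]f32 = undefined;
    var expected: f32 = 0;

    for (&a, &b, 0..) |*x, *y, i| {
        x.* = @floatFromInt(i);
        y.* = 0.5;
        expected += x.* * y.*;
    }

    try std.testing.expectApproxEqAbs(expected, la.dotProduct(&a, &b), 1e-5);
}
```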
40 changes: 40 additions & 0 deletions src/lib/rope.zig
@@ -0,0 +1,40 @@
const std = @import("std");

// RoFormer: Enhanced Transformer with Rotary Position Embedding (https://arxiv.org/abs/2104.09864)
pub fn rope(
pos: usize,
head_size: usize,
queries_buffer: []f32,
keys_buffer: []f32,
) void {
@setFloatMode(.Optimized);

std.debug.assert(keys_buffer.len <= queries_buffer.len);

var index: usize = 0;

while (index < queries_buffer.len) : (index += 2) {
const head_index: f32 = @floatFromInt(index % head_size);

const frequency: f32 =
1 / std.math.pow(f32, 10000, head_index / @as(f32, @floatFromInt(head_size)));

const rotation_scaling_factor: f32 = @as(f32, @floatFromInt(pos)) * frequency;
const real_rotation_value: f32 = std.math.cos(rotation_scaling_factor);
const imag_rotation_value: f32 = std.math.sin(rotation_scaling_factor);

const query_0 = queries_buffer[index];
const query_1 = queries_buffer[index + 1];

queries_buffer[index] = query_0 * real_rotation_value - query_1 * imag_rotation_value;
queries_buffer[index + 1] = query_0 * imag_rotation_value + query_1 * real_rotation_value;

if (index < keys_buffer.len) {
const key_0 = keys_buffer[index];
const key_1 = keys_buffer[index + 1];

keys_buffer[index] = key_0 * real_rotation_value - key_1 * imag_rotation_value;
keys_buffer[index + 1] = key_0 * imag_rotation_value + key_1 * real_rotation_value;
}
}
}
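The loop above applies the RoFormer rotation pairwise to the query and key vectors. A sketch of the underlying formula, where $d$ is `head_size`, $p$ is `pos`, and $2i$ corresponds to `index % head_size` in the code:

```latex
% Each (even, odd) component pair within a head is rotated by a
% position-dependent angle theta_i.
\[
  \begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
  =
  \begin{pmatrix}
    \cos\theta_i & -\sin\theta_i \\
    \sin\theta_i & \phantom{-}\cos\theta_i
  \end{pmatrix}
  \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
  \qquad
  \theta_i = \frac{p}{10000^{\,2i/d}}
\]
```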