Introduce Tensor struct
clebert committed Oct 11, 2023
1 parent 162f1b8 commit 5395f45
Showing 11 changed files with 348 additions and 369 deletions.
README.md: 1 change (1 addition & 0 deletions)

@@ -25,6 +25,7 @@ zig build -Doptimize=ReleaseFast run -- stories260K.bin -z tok512.bin -i "Once u
 - Llama 2: [Llama 2: Open Foundation and Fine-Tuned Chat Models](https://arxiv.org/abs/2307.09288)
 - Pre-normalization using RMSNorm: [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)
 - SwiGLU activation function: [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)
+- Swish activation function: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
 - Rotary positional embeddings: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
 - Grouped-query attention: [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245v1)
 - Nucleus sampling: [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
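The newly referenced Swish activation is swish(x) = x * sigmoid(x); with a scaling factor of 1 it is also known as SiLU, the gate used inside SwiGLU. A minimal Zig sketch (illustration only, not code from this commit):

    const std = @import("std");

    /// Swish / SiLU activation: x * sigmoid(x) = x / (1 + e^-x).
    fn swish(x: f32) f32 {
        return x / (1.0 + @exp(-x));
    }

    test "swish is zero at zero and near-identity for large inputs" {
        try std.testing.expectEqual(@as(f32, 0.0), swish(0.0));
        try std.testing.expectApproxEqAbs(@as(f32, 10.0), swish(10.0), 0.001);
    }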
src/attention.zig: 220 changes (75 additions & 145 deletions)

@@ -2,145 +2,106 @@ const Self = @This();
 
 const std = @import("std");
 const Checkpoint = @import("checkpoint.zig");
-const Matrix = @import("./matrix.zig");
+const Tensor = @import("./tensor.zig").Tensor;
 const vector = @import("./vector.zig");
 
 allocator: std.mem.Allocator,
 
-n_heads: usize,
-n_query_groups: usize,
+checkpoint: Checkpoint,
 head_size: usize,
 head_size_sqrt: f32,
-sequence_length: usize,
 
-query_projection_matrices: []const Matrix,
-key_projection_matrices: []const Matrix,
-value_projection_matrices: []const Matrix,
-output_projection_matrices: []const Matrix,
-
-input_vector: []f32,
-output_vector: []f32,
-
-multi_head_query: []f32,
-query_vectors: []const []f32,
-key_cache: []f32,
-value_cache: []f32,
+input_buffer: Tensor(2),
+output_buffer: Tensor(1),
+query_buffer: Tensor(2),
+key_cache: Tensor(4),
+value_cache: Tensor(4),
 scores: []f32,

-pub fn init(
-    allocator: std.mem.Allocator,
-    checkpoint: *const Checkpoint,
-    sequence_length: usize,
-) !Self {
-    const embedding_size = checkpoint.embedding_size;
-    const n_layers = checkpoint.n_layers;
-    const n_heads = checkpoint.n_heads;
-    const n_query_groups = checkpoint.n_query_groups;
-    const weights = checkpoint.weights;
-
-    const input_vector = try allocator.alloc(f32, embedding_size);
-
-    errdefer allocator.free(input_vector);
-
-    const output_vector = try allocator.alloc(f32, embedding_size);
+pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, sequence_length: usize) !Self {
+    const head_size: usize = checkpoint.embedding_size / checkpoint.n_heads;
+    const input_buffer = try Tensor(2).init(allocator, [_]usize{ checkpoint.n_heads, head_size });
 
-    errdefer allocator.free(output_vector);
+    errdefer input_buffer.deinit();
 
-    const head_size: usize = embedding_size / n_heads;
-    const multi_head_query = try allocator.alloc(f32, n_heads * head_size);
+    const output_buffer = try Tensor(1).init(allocator, [_]usize{checkpoint.embedding_size});
 
-    errdefer allocator.free(multi_head_query);
+    errdefer output_buffer.deinit();
 
-    const query_vectors = try vector.slice([]f32, allocator, head_size, multi_head_query);
+    const query_buffer = try Tensor(2).init(allocator, [_]usize{ checkpoint.n_heads, head_size });
 
-    errdefer allocator.free(query_vectors);
+    errdefer query_buffer.deinit();
 
-    const key_value_cache_size = n_layers * sequence_length * n_query_groups * head_size;
-    const key_cache = try allocator.alloc(f32, key_value_cache_size);
+    const key_cache = try Tensor(4).init(
+        allocator,
+        [_]usize{ checkpoint.n_layers, sequence_length, checkpoint.n_query_groups, head_size },
+    );
 
-    errdefer allocator.free(key_cache);
+    errdefer key_cache.deinit();
 
-    const value_cache = try allocator.alloc(f32, key_value_cache_size);
+    const value_cache = try Tensor(4).init(
+        allocator,
+        [_]usize{ checkpoint.n_layers, sequence_length, checkpoint.n_query_groups, head_size },
+    );
 
-    errdefer allocator.free(value_cache);
+    errdefer value_cache.deinit();
 
     const scores = try allocator.alloc(f32, sequence_length);
 
     errdefer allocator.free(scores);
 
     return Self{
         .allocator = allocator,
 
-        .n_heads = n_heads,
-        .n_query_groups = n_query_groups,
+        .checkpoint = checkpoint,
         .head_size = head_size,
         .head_size_sqrt = std.math.sqrt(@as(f32, @floatFromInt(head_size))),
-        .sequence_length = sequence_length,
 
-        .query_projection_matrices = weights.attention_query_projection_matrices,
-        .key_projection_matrices = weights.attention_key_projection_matrices,
-        .value_projection_matrices = weights.attention_value_projection_matrices,
-        .output_projection_matrices = weights.attention_output_projection_matrices,
-
-        .input_vector = input_vector,
-        .output_vector = output_vector,
-
-        .multi_head_query = multi_head_query,
-        .query_vectors = query_vectors,
+        .input_buffer = input_buffer,
+        .output_buffer = output_buffer,
+        .query_buffer = query_buffer,
         .key_cache = key_cache,
         .value_cache = value_cache,
         .scores = scores,
     };
 }

 pub fn deinit(self: *const Self) void {
-    self.allocator.free(self.input_vector);
-    self.allocator.free(self.output_vector);
-    self.allocator.free(self.multi_head_query);
-    self.allocator.free(self.query_vectors);
-    self.allocator.free(self.key_cache);
-    self.allocator.free(self.value_cache);
+    self.input_buffer.deinit();
+    self.output_buffer.deinit();
+    self.query_buffer.deinit();
+    self.key_cache.deinit();
+    self.value_cache.deinit();
     self.allocator.free(self.scores);
 }

 pub fn forward(self: *const Self, layer: usize, position: usize) !void {
-    const query_projection_matrix = self.query_projection_matrices[layer];
-    const key_projection_matrix = self.key_projection_matrices[layer];
-    const value_projection_matrix = self.value_projection_matrices[layer];
-    const output_projection_matrix = self.output_projection_matrices[layer];
-
-    const multi_head_query = self.multi_head_query;
-    const multi_head_key = self.sliceCache(.key, layer, position, null);
-    const multi_head_value = self.sliceCache(.value, layer, position, null);
-
-    query_projection_matrix.multiplyVector(self.input_vector, multi_head_query);
-    key_projection_matrix.multiplyVector(self.input_vector, multi_head_key);
-    value_projection_matrix.multiplyVector(self.input_vector, multi_head_value);
-
-    self.applyRotaryPositionEmbedding(position, multi_head_key);
-
-    for (0..self.n_heads) |head| {
-        self.computeGroupedQueryAttention(layer, position, head);
+    const weights = self.checkpoint.weights;
+    const query_matrix = weights.attention_query_matrices.slice(layer);
+    const key_matrix = weights.attention_key_matrices.slice(layer);
+    const value_matrix = weights.attention_value_matrices.slice(layer);
+    const output_matrix = weights.attention_output_matrices.slice(layer);
+    const key_vectors = self.key_cache.slice(layer).slice(position);
+    const value_vectors = self.value_cache.slice(layer).slice(position);
+
+    query_matrix.multiplyVector(self.input_buffer.data, self.query_buffer.data);
+    key_matrix.multiplyVector(self.input_buffer.data, key_vectors.data);
+    value_matrix.multiplyVector(self.input_buffer.data, value_vectors.data);
+
+    self.rope(position, key_vectors);
+
+    for (0..self.checkpoint.n_heads) |head| {
+        self.gqa(layer, position, head);
     }
 
-    output_projection_matrix.multiplyVector(self.input_vector, self.output_vector);
+    output_matrix.multiplyVector(self.input_buffer.data, self.output_buffer.data);
 }
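Note that the projections now write straight into the Tensor-backed buffers: key_vectors and value_vectors are views into the (layer, position) row of the caches, so the K/V results land in the cache with no separate copy step. A hypothetical caller (sketch only; the transformer driver is not part of this diff, and the names checkpoint, attention, hidden, and position are assumed) would look roughly like this:

    // Per generated token: `hidden` is the embedding-size residual stream.
    for (0..checkpoint.n_layers) |layer| {
        // Fill the attention input with this layer's (normalized) activations.
        @memcpy(attention.input_buffer.data, hidden);

        try attention.forward(layer, position);

        // Fold the attention output back into the residual stream.
        for (hidden, attention.output_buffer.data) |*h, o| h.* += o;
    }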

-// https://arxiv.org/abs/2104.09864
-fn applyRotaryPositionEmbedding(
-    self: *const Self,
-    position: usize,
-    multi_head_key: []f32,
-) void {
+// Rotary positional embeddings: https://arxiv.org/abs/2104.09864
+fn rope(self: *const Self, position: usize, key_vectors: Tensor(2)) void {
     @setFloatMode(.Optimized);
 
-    const multi_head_query = self.multi_head_query;
-
-    std.debug.assert(multi_head_query.len % multi_head_key.len == 0);
+    std.debug.assert(self.query_buffer.data.len % key_vectors.data.len == 0);
 
     var index: usize = 0;
 
-    while (index < multi_head_query.len) : (index += 2) {
+    while (index < self.query_buffer.data.len) : (index += 2) {
         const head: f32 = @floatFromInt(index % self.head_size);
 
         const frequency =
@@ -150,80 +111,49 @@ fn applyRotaryPositionEmbedding(
         const real_rotation_value: f32 = std.math.cos(rotation_scaling_factor);
         const imag_rotation_value: f32 = std.math.sin(rotation_scaling_factor);
 
-        const q_0 = multi_head_query[index];
-        const q_1 = multi_head_query[index + 1];
+        const q_0 = self.query_buffer.data[index];
+        const q_1 = self.query_buffer.data[index + 1];
 
-        multi_head_query[index] = q_0 * real_rotation_value - q_1 * imag_rotation_value;
-        multi_head_query[index + 1] = q_0 * imag_rotation_value + q_1 * real_rotation_value;
+        self.query_buffer.data[index] = q_0 * real_rotation_value - q_1 * imag_rotation_value;
+        self.query_buffer.data[index + 1] = q_0 * imag_rotation_value + q_1 * real_rotation_value;
 
-        if (index < multi_head_key.len) {
-            const k_0 = multi_head_key[index];
-            const k_1 = multi_head_key[index + 1];
+        if (index < key_vectors.data.len) {
+            const k_0 = key_vectors.data[index];
+            const k_1 = key_vectors.data[index + 1];
 
-            multi_head_key[index] = k_0 * real_rotation_value - k_1 * imag_rotation_value;
-            multi_head_key[index + 1] = k_0 * imag_rotation_value + k_1 * real_rotation_value;
+            key_vectors.data[index] = k_0 * real_rotation_value - k_1 * imag_rotation_value;
+            key_vectors.data[index + 1] = k_0 * imag_rotation_value + k_1 * real_rotation_value;
         }
     }
 }
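rope walks the flattened query (and, while in range, key) data in adjacent pairs, treating each pair (x_0, x_1) as the complex number x_0 + i*x_1 and rotating it by a position- and frequency-dependent angle. The core step as a standalone sketch (illustration only, not code from this repository):

    const std = @import("std");

    /// Multiplies the complex number pair[0] + i*pair[1] by e^(i*angle).
    fn rotatePair(pair: *[2]f32, angle: f32) void {
        const re = std.math.cos(angle);
        const im = std.math.sin(angle);
        const x_0 = pair[0];
        const x_1 = pair[1];

        pair[0] = x_0 * re - x_1 * im;
        pair[1] = x_0 * im + x_1 * re;
    }

Because each key is rotated once, when it is written into the cache, while the query is rotated at every step, the dot products in gqa depend only on the relative offset between the two positions, which is the point of rotary embeddings.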

-// https://arxiv.org/abs/1706.03762
-fn computeGroupedQueryAttention(
-    self: *const Self,
-    layer: usize,
-    current_position: usize,
-    head: usize,
-) void {
+// Grouped-query attention: https://arxiv.org/abs/2305.13245v1
+fn gqa(self: *const Self, layer: usize, current_position: usize, head: usize) void {
     @setFloatMode(.Optimized);
 
-    const query_vector = self.query_vectors[head];
-    const query_group = head / (self.n_heads / self.n_query_groups);
+    const query_vector = self.query_buffer.slice(head);
+    const query_group = head / (self.checkpoint.n_heads / self.checkpoint.n_query_groups);
     const next_position = current_position + 1;
 
     for (0..next_position) |position| {
-        const key_vector = self.sliceCache(.key, layer, position, query_group);
+        const key_vector = self.key_cache.slice(layer).slice(position).slice(query_group);
 
-        self.scores[position] = vector.dot(query_vector, key_vector) / self.head_size_sqrt;
+        self.scores[position] =
+            vector.dot(query_vector.data, key_vector.data) / self.head_size_sqrt;
     }
 
     vector.softmax(self.scores[0..next_position]);
 
-    const attention_values = self.input_vector[(head * self.head_size)..][0..self.head_size];
+    const attention_buffer = self.input_buffer.slice(head);
 
-    @memset(attention_values, 0);
+    @memset(attention_buffer.data, 0);
 
     for (0..next_position) |position| {
-        const value_vector = self.sliceCache(.value, layer, position, query_group);
-
+        const value_vector = self.value_cache.slice(layer).slice(position).slice(query_group);
         const weight = self.scores[position];
 
         for (0..self.head_size) |index| {
-            attention_values[index] += value_vector[index] * weight;
+            attention_buffer.data[index] += value_vector.data[index] * weight;
        }
    }
}
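gqa relies on two helpers from src/vector.zig that are outside this excerpt. Plausible implementations, inferred from the call sites above (the actual code may differ):

    const std = @import("std");

    /// Dot product of two equal-length slices.
    pub fn dot(a: []const f32, b: []const f32) f32 {
        std.debug.assert(a.len == b.len);

        var sum: f32 = 0;

        for (a, b) |x, y| sum += x * y;

        return sum;
    }

    /// In-place, numerically stable softmax.
    pub fn softmax(values: []f32) void {
        var max_value = values[0];

        for (values[1..]) |value| max_value = @max(max_value, value);

        var sum: f32 = 0;

        for (values) |*value| {
            value.* = @exp(value.* - max_value);
            sum += value.*;
        }

        for (values) |*value| value.* /= sum;
    }

The removed sliceCache below is the manual offset arithmetic that the chained slice calls in gqa replace: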

-const CacheType = enum { key, value };
-
-fn sliceCache(
-    self: *const Self,
-    cache_type: CacheType,
-    layer: usize,
-    position: usize,
-    query_group: ?usize,
-) []f32 {
-    const cache = if (cache_type == .key) self.key_cache else self.value_cache;
-    const multi_head_cache_size = self.n_query_groups * self.head_size;
-
-    const layer_cache_size = self.sequence_length * multi_head_cache_size;
-    const layer_cache_offset = layer * layer_cache_size;
-    const layer_cache = cache[layer_cache_offset..][0..layer_cache_size];
-
-    const multi_head_cache_offset = position * multi_head_cache_size;
-    const multi_head_cache = layer_cache[multi_head_cache_offset..][0..multi_head_cache_size];
-
-    if (query_group) |group| {
-        return multi_head_cache[(group * self.head_size)..][0..self.head_size];
-    }
-
-    return multi_head_cache;
-}
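src/tensor.zig itself is among the collapsed diffs below, but its API is fully determined by the call sites in this file: Tensor(n_dims).init(allocator, dims), deinit(), a data: []f32 field over the flattened elements, and slice(index), which returns a rank-reduced view along the first dimension. A minimal sketch consistent with that usage (the actual implementation may differ):

    const std = @import("std");

    pub fn Tensor(comptime n_dims: comptime_int) type {
        return struct {
            const Self = @This();

            allocator: ?std.mem.Allocator,
            dims: [n_dims]usize,
            data: []f32,

            pub fn init(allocator: std.mem.Allocator, dims: [n_dims]usize) !Self {
                var size: usize = 1;

                for (dims) |dim| size *= dim;

                return .{
                    .allocator = allocator,
                    .dims = dims,
                    .data = try allocator.alloc(f32, size),
                };
            }

            pub fn deinit(self: Self) void {
                if (self.allocator) |allocator| allocator.free(self.data);
            }

            /// Returns a non-owning view of one sub-tensor along the first dimension.
            pub fn slice(self: Self, index: usize) Tensor(n_dims - 1) {
                std.debug.assert(index < self.dims[0]);

                const sub_dims = self.dims[1..].*;

                var sub_size: usize = 1;

                for (sub_dims) |dim| sub_size *= dim;

                return .{
                    .allocator = null,
                    .dims = sub_dims,
                    .data = self.data[index * sub_size ..][0..sub_size],
                };
            }
        };
    }

With this shape, self.key_cache.slice(layer).slice(position).slice(query_group) addresses the same flat offset that the deleted sliceCache computed by hand: ((layer * sequence_length + position) * n_query_groups + query_group) * head_size.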
src/chat.zig: 2 changes (1 addition & 1 deletion)

@@ -129,7 +129,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
         user_prompt_tokens_index += 1;
 
         if (next_token == 0) {
-            next_token = self.sampler.sample(self.transformer.logits);
+            next_token = self.sampler.sample(self.transformer.logits_buffer.data);
         }
 
         if (next_token == eos_token) {
(The diffs for the remaining 8 changed files, including the new src/tensor.zig, are collapsed and not shown.)

0 comments on commit 5395f45
