Skip to content

Commit

Permalink
Use max_sequence_length as default for n_steps and small refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Sep 5, 2023
1 parent d7fdc5c commit e8de3a7
Show file tree
Hide file tree
Showing 13 changed files with 233 additions and 268 deletions.
115 changes: 44 additions & 71 deletions src/attention.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,65 +3,51 @@ const Self = @This();
const std = @import("std");
const lib = @import("lib.zig");
const Checkpoint = @import("checkpoint.zig");
const matrix = @import("matrix.zig");

allocator: std.mem.Allocator,
checkpoint: Checkpoint,
seq_len: usize,
sequence_length: usize,
input_buffer: []f32,
output_buffer: []f32,
scores_buffer: []f32,
queries_buffer: []f32,
keys_buffer: []f32,
values_buffer: []f32,
key_cache: []f32,
value_cache: []f32,

pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, seq_len: usize) !Self {
const dim = checkpoint.dim;
const kv_dim = checkpoint.kv_dim;
const kv_cache_dim = checkpoint.n_layers * seq_len * kv_dim;

const input_buffer = try allocator.alloc(f32, dim);
pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, sequence_length: usize) !Self {
const embedding_size = checkpoint.embedding_size;
const input_buffer = try allocator.alloc(f32, embedding_size);

errdefer allocator.free(input_buffer);

const output_buffer = try allocator.alloc(f32, dim);
const output_buffer = try allocator.alloc(f32, embedding_size);

errdefer allocator.free(output_buffer);

const scores_buffer = try allocator.alloc(f32, checkpoint.n_heads * seq_len);
const scores_buffer = try allocator.alloc(f32, checkpoint.n_query_heads * sequence_length);

errdefer allocator.free(scores_buffer);

const queries_buffer = try allocator.alloc(f32, dim);
const queries_buffer = try allocator.alloc(f32, embedding_size);

errdefer allocator.free(queries_buffer);

const keys_buffer = try allocator.alloc(f32, kv_dim);

errdefer allocator.free(keys_buffer);

const values_buffer = try allocator.alloc(f32, kv_dim);

errdefer allocator.free(values_buffer);

const key_cache = try allocator.alloc(f32, kv_cache_dim);
const key_value_size = checkpoint.n_query_head_groups * checkpoint.query_head_size;
const key_value_cache_size = checkpoint.n_layers * sequence_length * key_value_size;
const key_cache = try allocator.alloc(f32, key_value_cache_size);

errdefer allocator.free(key_cache);

const value_cache = try allocator.alloc(f32, kv_cache_dim);
const value_cache = try allocator.alloc(f32, key_value_cache_size);

return Self{
.allocator = allocator,
.checkpoint = checkpoint,
.seq_len = seq_len,
.sequence_length = sequence_length,
.input_buffer = input_buffer,
.output_buffer = output_buffer,
.scores_buffer = scores_buffer,
.queries_buffer = queries_buffer,
.keys_buffer = keys_buffer,
.values_buffer = values_buffer,
.key_cache = key_cache,
.value_cache = value_cache,
};
Expand All @@ -72,55 +58,41 @@ pub fn deinit(self: *const Self) void {
self.allocator.free(self.output_buffer);
self.allocator.free(self.scores_buffer);
self.allocator.free(self.queries_buffer);
self.allocator.free(self.keys_buffer);
self.allocator.free(self.values_buffer);
self.allocator.free(self.key_cache);
self.allocator.free(self.value_cache);
}

pub fn forward(self: *const Self, pos: usize, layer: usize) !void {
const checkpoint = self.checkpoint;
const kv_dim = checkpoint.kv_dim;
const weights = checkpoint.weights;

try weights.attention_queries_matrix.multiplyVector(
try weights.attention_query_matrices.multiplyVector(
layer,
self.input_buffer,
self.queries_buffer,
);

try weights.attention_keys_matrix.multiplyVector(
layer,
self.input_buffer,
self.keys_buffer,
);
const query_head_size = checkpoint.query_head_size;
const key_value_size = checkpoint.n_query_head_groups * query_head_size;
const key_value_cache_offset = layer * (self.sequence_length * key_value_size);

try weights.attention_values_matrix.multiplyVector(
layer,
self.input_buffer,
self.values_buffer,
);
const key_cache = self.key_cache[key_value_cache_offset..];
const keys_buffer = key_cache[(pos * key_value_size)..][0..key_value_size];

lib.rope(pos, checkpoint.head_size, self.queries_buffer, self.keys_buffer);
const value_cache = self.value_cache[key_value_cache_offset..];
const values_buffer = value_cache[(pos * key_value_size)..][0..key_value_size];

const kv_cache_dim = self.seq_len * kv_dim;
const kv_cache_layer_offset = layer * kv_cache_dim;
try weights.attention_key_matrices.multiplyVector(layer, self.input_buffer, keys_buffer);

@memcpy(
self.key_cache[(kv_cache_layer_offset + pos * kv_dim)..][0..self.keys_buffer.len],
self.keys_buffer,
);
lib.rope(pos, query_head_size, self.queries_buffer, keys_buffer);

@memcpy(
self.value_cache[(kv_cache_layer_offset + pos * kv_dim)..][0..self.values_buffer.len],
self.values_buffer,
);
try weights.attention_value_matrices.multiplyVector(layer, self.input_buffer, values_buffer);

for (0..checkpoint.n_heads) |head| {
self.compute_weighted_values(pos, head, kv_cache_layer_offset);
for (0..checkpoint.n_query_heads) |query_head| {
self.compute_weighted_values(pos, query_head, key_cache, value_cache);
}

try weights.attention_output_matrix.multiplyVector(
try weights.attention_output_matrices.multiplyVector(
layer,
self.input_buffer,
self.output_buffer,
Expand All @@ -130,40 +102,41 @@ pub fn forward(self: *const Self, pos: usize, layer: usize) !void {
fn compute_weighted_values(
self: *const Self,
pos: usize,
head: usize,
kv_cache_layer_offset: usize,
query_head: usize,
key_cache: []const f32,
value_cache: []const f32,
) void {
@setFloatMode(.Optimized);

const checkpoint = self.checkpoint;
const kv_dim = checkpoint.kv_dim;
const head_size = checkpoint.head_size;

const group = head / checkpoint.n_groups;
const kv_head_offset = group * head_size;
const head_offset = head * head_size;
const query = self.queries_buffer[head_offset..][0..head_size];
const scores = self.scores_buffer[(head * self.seq_len)..];
const n_query_head_groups = checkpoint.n_query_head_groups;
const query_head_group = query_head / (checkpoint.n_query_heads / n_query_head_groups);
const query_head_size = checkpoint.query_head_size;
const query_head_offset = query_head * query_head_size;
const query = self.queries_buffer[query_head_offset..][0..query_head_size];
const key_value_size = n_query_head_groups * query_head_size;
const key_value_head_offset = query_head_group * query_head_size;
const scores = self.scores_buffer[(query_head * self.sequence_length)..];

for (0..(pos + 1)) |prev_pos| {
const kv_cache_head_offset = kv_cache_layer_offset + prev_pos * kv_dim + kv_head_offset;
const key = self.key_cache[kv_cache_head_offset..][0..head_size];
const key_value_cache_offset = prev_pos * key_value_size + key_value_head_offset;
const key = key_cache[key_value_cache_offset..][0..query_head_size];

scores[prev_pos] = lib.dot(query, key) / checkpoint.head_size_sqrt;
scores[prev_pos] = lib.dot(query, key) / checkpoint.query_head_size_sqrt;
}

lib.softmax(scores[0..(pos + 1)]);

const weighted_values = self.input_buffer[head_offset..][0..head_size];
const weighted_values = self.input_buffer[query_head_offset..][0..query_head_size];

@memset(weighted_values, 0);

for (0..(pos + 1)) |prev_pos| {
const kv_cache_head_offset = kv_cache_layer_offset + prev_pos * kv_dim + kv_head_offset;
const value = self.value_cache[kv_cache_head_offset..];
const key_value_cache_offset = prev_pos * key_value_size + key_value_head_offset;
const value = value_cache[key_value_cache_offset..];
const weight = scores[prev_pos];

for (0..head_size) |index| {
for (0..query_head_size) |index| {
weighted_values[index] += weight * value[index];
}
}
Expand Down
6 changes: 2 additions & 4 deletions src/chat.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ allocator: std.mem.Allocator,
transformer: Transformer,
tokenizer: Tokenizer,
sampler: Sampler,
n_steps: usize,
user_prompt: []const u8,
system_prompt: []const u8,

Expand All @@ -34,7 +33,6 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
.transformer = transformer,
.tokenizer = tokenizer,
.sampler = sampler,
.n_steps = cli.n_steps,
.user_prompt = cli.prompt,
.system_prompt = cli.system_prompt,
};
Expand Down Expand Up @@ -69,7 +67,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
allocator.free(prompt_tokens);
};

for (0..self.n_steps) |pos| {
for (0..self.transformer.sequence_length) |pos| {
try self.transformer.forward(token, pos);

if (token == bos_token and user_turn) {
Expand Down Expand Up @@ -131,7 +129,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
user_prompt_tokens_index += 1;

if (next_token == 0) {
next_token = self.sampler.sample(self.transformer.logits_vector);
next_token = self.sampler.sample(self.transformer.logits);
}

if (next_token == eos_token) {
Expand Down
Loading

0 comments on commit e8de3a7

Please sign in to comment.