Skip to content

Commit

Permalink
Minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Sep 12, 2023
1 parent 381f3e5 commit 92d0fef
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
10 changes: 5 additions & 5 deletions src/attention.zig
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ pub fn forward(self: *const Self, layer: usize, position: usize) !void {
const output_projection_matrix = self.output_projection_matrices[layer];

const multi_head_query = self.multi_head_query;
const multi_head_key = self.getCacheSlice(.key, layer, position, null);
const multi_head_value = self.getCacheSlice(.value, layer, position, null);
const multi_head_key = self.sliceCache(.key, layer, position, null);
const multi_head_value = self.sliceCache(.value, layer, position, null);

query_projection_matrix.multiplyVector(self.input_vector, multi_head_query);
key_projection_matrix.multiplyVector(self.input_vector, multi_head_key);
Expand Down Expand Up @@ -180,7 +180,7 @@ fn computeGroupedQueryAttention(
const next_position = current_position + 1;

for (0..next_position) |position| {
const key_vector = self.getCacheSlice(.key, layer, position, query_group);
const key_vector = self.sliceCache(.key, layer, position, query_group);

self.scores[position] = vector.dot(query_vector, key_vector) / self.head_size_sqrt;
}
Expand All @@ -192,7 +192,7 @@ fn computeGroupedQueryAttention(
@memset(attention_values, 0);

for (0..next_position) |position| {
const value_vector = self.getCacheSlice(.value, layer, position, query_group);
const value_vector = self.sliceCache(.value, layer, position, query_group);

const weight = self.scores[position];

Expand All @@ -204,7 +204,7 @@ fn computeGroupedQueryAttention(

const CacheType = enum { key, value };

fn getCacheSlice(
fn sliceCache(
self: *const Self,
cache_type: CacheType,
layer: usize,
Expand Down
6 changes: 4 additions & 2 deletions src/matrix.zig
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ pub fn slice(
}

pub fn multiplyVector(self: *const Self, input_vector: []const f32, output_vector: []f32) void {
const m_rows = self.m_rows;
const n_cols = self.n_cols;
const row_major_data = self.row_major_data;

std.debug.assert(input_vector.len == n_cols);
std.debug.assert(output_vector.len == self.m_rows);
std.debug.assert(output_vector.len == m_rows);

for (output_vector, 0..) |*element, row| {
element.* = vector.dot(self.row_major_data[(row * n_cols)..][0..n_cols], input_vector);
element.* = vector.dot(row_major_data[(row * n_cols)..][0..n_cols], input_vector);
}
}

0 comments on commit 92d0fef

Please sign in to comment.