From 92d0fefd303c59fac969bb23e0a101530842d031 Mon Sep 17 00:00:00 2001 From: Clemens Akens Date: Tue, 12 Sep 2023 10:28:53 +0200 Subject: [PATCH] Minor refactoring --- src/attention.zig | 10 +++++----- src/matrix.zig | 6 ++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/attention.zig b/src/attention.zig index d67f9a8..752ad36 100644 --- a/src/attention.zig +++ b/src/attention.zig @@ -110,8 +110,8 @@ pub fn forward(self: *const Self, layer: usize, position: usize) !void { const output_projection_matrix = self.output_projection_matrices[layer]; const multi_head_query = self.multi_head_query; - const multi_head_key = self.getCacheSlice(.key, layer, position, null); - const multi_head_value = self.getCacheSlice(.value, layer, position, null); + const multi_head_key = self.sliceCache(.key, layer, position, null); + const multi_head_value = self.sliceCache(.value, layer, position, null); query_projection_matrix.multiplyVector(self.input_vector, multi_head_query); key_projection_matrix.multiplyVector(self.input_vector, multi_head_key); @@ -180,7 +180,7 @@ fn computeGroupedQueryAttention( const next_position = current_position + 1; for (0..next_position) |position| { - const key_vector = self.getCacheSlice(.key, layer, position, query_group); + const key_vector = self.sliceCache(.key, layer, position, query_group); self.scores[position] = vector.dot(query_vector, key_vector) / self.head_size_sqrt; } @@ -192,7 +192,7 @@ fn computeGroupedQueryAttention( @memset(attention_values, 0); for (0..next_position) |position| { - const value_vector = self.getCacheSlice(.value, layer, position, query_group); + const value_vector = self.sliceCache(.value, layer, position, query_group); const weight = self.scores[position]; @@ -204,7 +204,7 @@ fn computeGroupedQueryAttention( const CacheType = enum { key, value }; -fn getCacheSlice( +fn sliceCache( self: *const Self, cache_type: CacheType, layer: usize, diff --git a/src/matrix.zig b/src/matrix.zig index d77b23e..7744cc8 100644 --- a/src/matrix.zig +++ b/src/matrix.zig @@ -46,12 +46,14 @@ pub fn slice( } pub fn multiplyVector(self: *const Self, input_vector: []const f32, output_vector: []f32) void { + const m_rows = self.m_rows; const n_cols = self.n_cols; + const row_major_data = self.row_major_data; std.debug.assert(input_vector.len == n_cols); - std.debug.assert(output_vector.len == self.m_rows); + std.debug.assert(output_vector.len == m_rows); for (output_vector, 0..) |*element, row| { - element.* = vector.dot(self.row_major_data[(row * n_cols)..][0..n_cols], input_vector); + element.* = vector.dot(row_major_data[(row * n_cols)..][0..n_cols], input_vector); } }