Skip to content

Commit

Permalink
Minor refactoring
Browse files: browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 19, 2023
1 parent ef3860f commit b3554e3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/attention.zig
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,17 @@ pub fn forward(self: *const Self, layer: usize, position: usize) void {
key_matrix.computeMatrixVectorMultiplication(self.input_buffer, key_buffer);
value_matrix.computeMatrixVectorMultiplication(self.input_buffer, value_buffer);

-self.rope(position, key_buffer);
+self.computeRoPE(position, key_buffer);

for (0..self.checkpoint.n_attention_heads) |head| {
-self.gqa(layer, position, head);
+self.computeGQA(layer, position, head);
}

output_matrix.computeMatrixVectorMultiplication(self.input_buffer, self.output_buffer);
}

// Rotary positional embeddings: https://arxiv.org/abs/2104.09864
-fn rope(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
+fn computeRoPE(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
@setFloatMode(.Optimized);

std.debug.assert(self.query_buffer.values.len % key_buffer.values.len == 0);
Expand Down Expand Up @@ -133,7 +133,7 @@ fn rope(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
}

// Grouped-query attention: https://arxiv.org/abs/2305.13245v1
-fn gqa(self: *const Self, layer: usize, current_position: usize, head: usize) void {
+fn computeGQA(self: *const Self, layer: usize, current_position: usize, head: usize) void {
@setFloatMode(.Optimized);

const query_vector = self.query_buffer.slice(head);
Expand Down

0 comments on commit b3554e3

Please sign in to comment.