Minor refactoring
clebert committed Oct 23, 2023
1 parent ebf217f commit 32f865d
Showing 9 changed files with 98 additions and 102 deletions.
46 changes: 23 additions & 23 deletions src/attention.zig
@@ -67,7 +67,7 @@ pub fn forward(self: Self, layer: usize, position: usize) !void {
try key_weight.multiplyVector(self.input, multi_key);
try value_weight.multiplyVector(self.input, multi_value);

self.computeRoPE(position, multi_key.values);
self.computeRoPE(position, multi_key.data);

for (0..self.checkpoint.n_attention_heads) |head| {
try self.computeGQA(layer, position, head);
@@ -77,37 +77,37 @@ pub fn forward(self: Self, layer: usize, position: usize) !void {
}

// Rotary positional embeddings: https://arxiv.org/abs/2104.09864
fn computeRoPE(self: Self, position: usize, multi_key_values: []f32) void {
fn computeRoPE(self: Self, position: usize, multi_key_data: []f32) void {
@setFloatMode(.Optimized);

const multi_query_values = self.multi_query.values;
const multi_query_data = self.multi_query.data;

std.debug.assert(multi_query_values.len % multi_key_values.len == 0);
std.debug.assert(multi_query_data.len % multi_key_data.len == 0);

var index: usize = 0;

while (index < multi_query_values.len) : (index += 2) {
while (index < multi_query_data.len) : (index += 2) {
const head: f32 = @floatFromInt(index % self.head_size);

const frequency =
1 / std.math.pow(f32, 10000, head / @as(f32, @floatFromInt(self.head_size)));

const rotation_scaling_factor: f32 = @as(f32, @floatFromInt(position)) * frequency;
const real_rotation_value: f32 = std.math.cos(rotation_scaling_factor);
const imag_rotation_value: f32 = std.math.sin(rotation_scaling_factor);
const real_rotation: f32 = std.math.cos(rotation_scaling_factor);
const imag_rotation: f32 = std.math.sin(rotation_scaling_factor);

const q_0 = multi_query_values[index];
const q_1 = multi_query_values[index + 1];
const q_0 = multi_query_data[index];
const q_1 = multi_query_data[index + 1];

multi_query_values[index] = q_0 * real_rotation_value - q_1 * imag_rotation_value;
multi_query_values[index + 1] = q_0 * imag_rotation_value + q_1 * real_rotation_value;
multi_query_data[index] = q_0 * real_rotation - q_1 * imag_rotation;
multi_query_data[index + 1] = q_0 * imag_rotation + q_1 * real_rotation;

if (index < multi_key_values.len) {
const k_0 = multi_key_values[index];
const k_1 = multi_key_values[index + 1];
if (index < multi_key_data.len) {
const k_0 = multi_key_data[index];
const k_1 = multi_key_data[index + 1];

multi_key_values[index] = k_0 * real_rotation_value - k_1 * imag_rotation_value;
multi_key_values[index + 1] = k_0 * imag_rotation_value + k_1 * real_rotation_value;
multi_key_data[index] = k_0 * real_rotation - k_1 * imag_rotation;
multi_key_data[index + 1] = k_0 * imag_rotation + k_1 * real_rotation;
}
}
}
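
Note on the rotation above: each consecutive pair of query components (and, over its shorter range, key components) is rotated as a 2-D vector. A sketch in the notation of the RoPE paper, with p the position and d the head size (symbols mine, not from the source):

x'_j     = x_j \cos(p \theta_j) - x_{j+1} \sin(p \theta_j)
x'_{j+1} = x_j \sin(p \theta_j) + x_{j+1} \cos(p \theta_j), \qquad \theta_j = 10000^{-(j \bmod d)/d}, \quad j = 0, 2, 4, \ldots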
@@ -116,7 +116,7 @@ fn computeRoPE(self: Self, position: usize, multi_key_values: []f32) void {
fn computeGQA(self: Self, layer: usize, current_position: usize, head: usize) !void {
@setFloatMode(.Optimized);

const query_values = self.multi_query.values[head * self.head_size ..][0..self.head_size];
const query_data = self.multi_query.data[head * self.head_size ..][0..self.head_size];

const query_group =
head / (self.checkpoint.n_attention_heads / self.checkpoint.n_attention_query_groups);
@@ -125,25 +125,25 @@ fn computeGQA(self: Self, layer: usize, current_position: usize, head: usize) !v

for (0..next_position) |position| {
const multi_key = self.key_cache[layer][position];
const key_values = multi_key.values[query_group * self.head_size ..][0..self.head_size];
const key_data = multi_key.data[query_group * self.head_size ..][0..self.head_size];

self.scores[position] =
try simd.computeScalarProduct(query_values, key_values) / self.head_size_sqrt;
try simd.computeScalarProduct(query_data, key_data) / self.head_size_sqrt;
}

math.softmax(self.scores[0..next_position]);

const attention_values = self.input.values[head * self.head_size ..][0..self.head_size];
const attention_data = self.input.data[head * self.head_size ..][0..self.head_size];

@memset(attention_values, 0);
@memset(attention_data, 0);

for (0..next_position) |position| {
const multi_value = self.value_cache[layer][position];
const value_values = multi_value.values[query_group * self.head_size ..][0..self.head_size];
const value_data = multi_value.data[query_group * self.head_size ..][0..self.head_size];
const weight = self.scores[position];

for (0..self.head_size) |index| {
attention_values[index] += value_values[index] * weight;
attention_data[index] += value_data[index] * weight;
}
}
}
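
computeGQA above is scaled dot-product attention per head, with grouped-query sharing of the cached key/value heads. Sketched in my own notation, where q is the head's query slice, k_t and v_t are the cached key/value slices of its query group at position t, and d is the head size:

score_t = (q \cdot k_t) / \sqrt{d}, \qquad a = softmax(score), \qquad out = \sum_t a_t v_t

For example (hypothetical sizes, not from this diff): with n_attention_heads = 8 and n_attention_query_groups = 2, query_group is head / 4, so heads 0 to 3 read group 0 of the cache and heads 4 to 7 read group 1.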
2 changes: 1 addition & 1 deletion src/chat.zig
@@ -111,7 +111,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
user_prompt_tokens_index += 1;

if (next_token == 0) {
next_token = self.sampler.sample(self.transformer.output.values);
next_token = self.sampler.sample(self.transformer.output.data);
}

if (next_token == eos_token) {
2 changes: 1 addition & 1 deletion src/ffn.zig
@@ -32,7 +32,7 @@ pub fn forward(self: Self, layer: usize) !void {
try up_weight.multiplyVector(self.input, self.hidden);

for (0..self.checkpoint.ffn_hidden_size) |index| {
self.hidden.values[index] *= swish(self.gate.values[index]);
self.hidden.data[index] *= swish(self.gate.data[index]);
}

try down_weight.multiplyVector(self.hidden, self.output);
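
The touched loop in ffn.zig is the gated feed-forward step: assuming swish is the usual z * sigmoid(z) (its definition lies outside this hunk), the function computes FFN(x) = W_down(swish(W_gate x) ⊙ (W_up x)), where ⊙ is the element-wise product carried out in the loop above, followed by the down projection.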
2 changes: 1 addition & 1 deletion src/generator.zig
@@ -54,7 +54,7 @@ pub fn generate(self: *Self, writer: anytype) !void {
next_token = self.prompt_tokens[prompt_tokens_index];
prompt_tokens_index += 1;
} else {
next_token = self.sampler.sample(self.transformer.output.values);
next_token = self.sampler.sample(self.transformer.output.data);
}

if (next_token == bos_token or next_token == eos_token) {
26 changes: 13 additions & 13 deletions src/math.zig
@@ -1,33 +1,33 @@
const std = @import("std");

pub fn argmax(values: []f32) usize {
pub fn argmax(data: []f32) usize {
var max_index: usize = 0;
var max_value: f32 = values[max_index];
var max_element: f32 = data[max_index];

for (1..values.len) |index| {
const value = values[index];
for (1..data.len) |index| {
const element = data[index];

if (value > max_value) {
if (element > max_element) {
max_index = index;
max_value = value;
max_element = element;
}
}

return max_index;
}

pub fn softmax(values: []f32) void {
pub fn softmax(data: []f32) void {
@setFloatMode(.Optimized);

var max_value: f32 = std.mem.max(f32, values);
var max_element: f32 = std.mem.max(f32, data);
var sum: f32 = 0;

for (values) |*value| {
value.* = std.math.exp(value.* - max_value);
sum += value.*;
for (data) |*element| {
element.* = std.math.exp(element.* - max_element);
sum += element.*;
}

for (values) |*value| {
value.* /= sum;
for (data) |*element| {
element.* /= sum;
}
}
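
A minimal test sketch for the two functions above (my own, not part of the commit; it assumes a sibling test file that can @import this module as "math.zig"):

const std = @import("std");
const math = @import("math.zig");

test "softmax normalizes and argmax picks the largest logit" {
    var logits = [_]f32{ 1.0, 3.0, 2.0 };

    math.softmax(&logits);

    var sum: f32 = 0;

    for (logits) |probability| sum += probability;

    // The probabilities sum to ~1 and keep the ordering of the logits.
    try std.testing.expectApproxEqAbs(@as(f32, 1), sum, 1e-6);
    try std.testing.expectEqual(@as(usize, 1), math.argmax(&logits));
}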
22 changes: 11 additions & 11 deletions src/matrix.zig
@@ -40,45 +40,45 @@ const max_thread_count = 24;

pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void {
if (self.thread_count == 0) {
try computeMatrixVectorMultiplication(self.rows, input, output.values);
try computeMatrixVectorMultiplication(self.rows, input, output.data);

return;
}

const n_threads = @min(max_thread_count, self.thread_count);
const thread_chunk_size = output.values.len / n_threads;
const chunk_size = output.data.len / n_threads;

var threads: [max_thread_count]std.Thread = undefined;

for (threads[0..n_threads], 0..) |*thread, index| {
thread.* = try std.Thread.spawn(.{}, computeMatrixVectorMultiplication, .{
self.rows[index * thread_chunk_size ..][0..thread_chunk_size],
self.rows[index * chunk_size ..][0..chunk_size],
input,
output.values[index * thread_chunk_size ..][0..thread_chunk_size],
output.data[index * chunk_size ..][0..chunk_size],
});
}

for (threads[0..n_threads]) |thread| {
thread.join();
}

if (output.values.len % n_threads > 0) {
if (output.data.len % n_threads > 0) {
try computeMatrixVectorMultiplication(
self.rows[n_threads * thread_chunk_size ..],
self.rows[n_threads * chunk_size ..],
input,
output.values[n_threads * thread_chunk_size ..],
output.data[n_threads * chunk_size ..],
);
}
}

fn computeMatrixVectorMultiplication(
rows: []const Vector,
input: Vector,
output_values: []f32,
output_data: []f32,
) !void {
std.debug.assert(rows.len == output_values.len);
std.debug.assert(rows.len == output_data.len);

for (output_values, 0..) |*value, index| {
value.* = try rows[index].computeScalarProduct(input);
for (output_data, 0..) |*element, index| {
element.* = try rows[index].computeScalarProduct(input);
}
}
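
Worked example of the chunking above (hypothetical numbers, not from this diff): with output.data.len == 100 and thread_count == 3, chunk_size is 100 / 3 == 33, so the spawned threads cover rows 0..33, 33..66 and 66..99; because 100 % 3 == 1, the trailing computeMatrixVectorMultiplication call handles the one remaining row on the calling thread.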
70 changes: 33 additions & 37 deletions src/simd.zig
@@ -1,67 +1,63 @@
const std = @import("std");

// Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467
pub fn computeRMSNorm(
input_values: []const f32,
weight_values: []const f32,
output_values: []f32,
) !void {
pub fn computeRMSNorm(input_data: []const f32, weight_data: []const f32, output_data: []f32) !void {
@setFloatMode(.Optimized);

var scaling_factor = try computeScalarProduct(input_values, input_values);
var scaling_factor = try computeScalarProduct(input_data, input_data);

scaling_factor /= @floatFromInt(input_values.len);
scaling_factor /= @floatFromInt(input_data.len);
scaling_factor += 1e-5;
scaling_factor = 1 / std.math.sqrt(scaling_factor);

try computeVectorMultiplication(scaling_factor, input_values, weight_values, output_values);
try computeVectorMultiplication(scaling_factor, input_data, weight_data, output_data);
}

pub fn computeScalarProduct(input_values_1: []const f32, input_values_2: []const f32) !f32 {
pub fn computeScalarProduct(input_data_1: []const f32, input_data_2: []const f32) !f32 {
@setFloatMode(.Optimized);

std.debug.assert(input_values_1.len == input_values_2.len);
std.debug.assert(input_data_1.len == input_data_2.len);

comptime var vector_len = std.atomic.cache_line / @sizeOf(f32);

inline while (vector_len >= 4) : (vector_len /= 2) {
if (input_values_1.len % vector_len == 0) {
var output_values: @Vector(vector_len, f32) = @splat(0);
if (input_data_1.len % vector_len == 0) {
var output_data: @Vector(vector_len, f32) = @splat(0);
var index: usize = 0;

while (index < input_values_1.len) : (index += vector_len) {
output_values +=
@as(@Vector(vector_len, f32), input_values_1[index..][0..vector_len].*) *
@as(@Vector(vector_len, f32), input_values_2[index..][0..vector_len].*);
while (index < input_data_1.len) : (index += vector_len) {
output_data +=
@as(@Vector(vector_len, f32), input_data_1[index..][0..vector_len].*) *
@as(@Vector(vector_len, f32), input_data_2[index..][0..vector_len].*);
}

return @reduce(.Add, output_values);
return @reduce(.Add, output_data);
}
}

return error.UnsupportedVectorSize;
}

pub fn computeVectorAddition(
input_values_1: []const f32,
input_values_2: []const f32,
output_values: []f32,
input_data_1: []const f32,
input_data_2: []const f32,
output_data: []f32,
) !void {
@setFloatMode(.Optimized);

std.debug.assert(input_values_1.len == input_values_2.len);
std.debug.assert(input_values_1.len == output_values.len);
std.debug.assert(input_data_1.len == input_data_2.len);
std.debug.assert(input_data_1.len == output_data.len);

comptime var vector_len = std.atomic.cache_line / @sizeOf(f32);

inline while (vector_len >= 4) : (vector_len /= 2) {
if (input_values_1.len % vector_len == 0) {
if (input_data_1.len % vector_len == 0) {
var index: usize = 0;

while (index < input_values_1.len) : (index += vector_len) {
output_values[index..][0..vector_len].* =
@as(@Vector(vector_len, f32), input_values_1[index..][0..vector_len].*) +
@as(@Vector(vector_len, f32), input_values_2[index..][0..vector_len].*);
while (index < input_data_1.len) : (index += vector_len) {
output_data[index..][0..vector_len].* =
@as(@Vector(vector_len, f32), input_data_1[index..][0..vector_len].*) +
@as(@Vector(vector_len, f32), input_data_2[index..][0..vector_len].*);
}

return;
@@ -73,28 +69,28 @@ pub fn computeVectorAddition,

pub fn computeVectorMultiplication(
scaling_factor: f32,
input_values_1: []const f32,
input_values_2: []const f32,
output_values: []f32,
input_data_1: []const f32,
input_data_2: []const f32,
output_data: []f32,
) !void {
@setFloatMode(.Optimized);

std.debug.assert(input_values_1.len == input_values_2.len);
std.debug.assert(input_values_1.len == output_values.len);
std.debug.assert(input_data_1.len == input_data_2.len);
std.debug.assert(input_data_1.len == output_data.len);

comptime var vector_len = std.atomic.cache_line / @sizeOf(f32);

inline while (vector_len >= 4) : (vector_len /= 2) {
if (input_values_1.len % vector_len == 0) {
if (input_data_1.len % vector_len == 0) {
const scaling_factors: @Vector(vector_len, f32) = @splat(scaling_factor);

var index: usize = 0;

while (index < input_values_1.len) : (index += vector_len) {
output_values[index..][0..vector_len].* =
while (index < input_data_1.len) : (index += vector_len) {
output_data[index..][0..vector_len].* =
scaling_factors *
@as(@Vector(vector_len, f32), input_values_1[index..][0..vector_len].*) *
@as(@Vector(vector_len, f32), input_values_2[index..][0..vector_len].*);
@as(@Vector(vector_len, f32), input_data_1[index..][0..vector_len].*) *
@as(@Vector(vector_len, f32), input_data_2[index..][0..vector_len].*);
}

return;
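
The SIMD helpers above pick their lane count at comptime, starting from std.atomic.cache_line / @sizeOf(f32) and halving down to 4 until it divides the slice length; computeScalarProduct returns error.UnsupportedVectorSize when no candidate fits. A minimal test sketch (my own, not part of the commit; it assumes a sibling test file importing this module as "simd.zig"):

const std = @import("std");
const simd = @import("simd.zig");

test "scalar product over a SIMD-friendly length" {
    const a = [_]f32{ 1, 2, 3, 4, 5, 6, 7, 8 };
    const b = [_]f32{ 2, 2, 2, 2, 2, 2, 2, 2 };

    // A length of 8 is a multiple of at least one candidate lane count (8 or 4),
    // so no error is expected.
    try std.testing.expectEqual(@as(f32, 72), try simd.computeScalarProduct(&a, &b));
}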
2 changes: 1 addition & 1 deletion src/transformer.zig
@@ -34,7 +34,7 @@ pub fn createLeaky(allocator: std.mem.Allocator, args: anytype) !Self {
pub fn forward(self: Self, token: usize, position: usize) !void {
const embedding_weight = self.checkpoint.embedding_weights[token];

@memcpy(self.hidden.values, embedding_weight.values);
@memcpy(self.hidden.data, embedding_weight.data);

for (0..self.checkpoint.n_layers) |layer| {
const attention_norm_weight = self.checkpoint.attention_norm_weights[layer];
