Skip to content

Commit

Permalink
Improve SIMD utilization
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 20, 2023
1 parent 809f0b4 commit e3a3880
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 66 deletions.
97 changes: 97 additions & 0 deletions src/simd.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
const std = @import("std");

// Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467
// Scales `input_values` by 1/RMS(input) and by the per-element `weight_values`,
// writing the result into `output_values`.
pub fn computeRMSNorm(
    comptime TValue: type,
    comptime vector_size: comptime_int,
    input_values: []const TValue,
    weight_values: []const TValue,
    output_values: []TValue,
) void {
    @setFloatMode(.Optimized);

    // Sum of squares via the SIMD dot product of the input with itself.
    const sum_of_squares = computeScalarProduct(TValue, vector_size, input_values, input_values);
    // Mean square plus a small epsilon for numerical stability.
    const mean_square = sum_of_squares / @as(f32, @floatFromInt(input_values.len)) + 1e-5;
    const inverse_rms = 1 / std.math.sqrt(mean_square);

    computeVectorMultiplication(
        TValue,
        vector_size,
        inverse_rms,
        input_values,
        weight_values,
        output_values,
    );
}

/// Computes the dot product of two equally-sized slices, accumulating in
/// SIMD lanes of width `vector_size` and reducing to a scalar at the end.
/// Lengths must match and be a multiple of `vector_size`.
pub fn computeScalarProduct(
    comptime TValue: type,
    comptime vector_size: comptime_int,
    values_1: []const TValue,
    values_2: []const TValue,
) f32 {
    @setFloatMode(.Optimized);

    std.debug.assert(values_1.len == values_2.len);
    std.debug.assert(values_1.len % vector_size == 0);

    var lane_sums: @Vector(vector_size, f32) = @splat(0.0);
    var offset: usize = 0;

    while (offset < values_1.len) : (offset += vector_size) {
        const chunk_1: @Vector(vector_size, f32) = values_1[offset..][0..vector_size].*;
        const chunk_2: @Vector(vector_size, f32) = values_2[offset..][0..vector_size].*;

        lane_sums += chunk_1 * chunk_2;
    }

    // Horizontal reduction of the per-lane partial sums.
    return @reduce(.Add, lane_sums);
}

/// Element-wise addition: output_values = input_values_1 + input_values_2,
/// processed in SIMD chunks of `vector_size`. All slices must have equal
/// length, which must be a multiple of `vector_size`. `output_values` may
/// alias an input (each chunk is fully loaded before it is stored).
pub fn computeVectorAddition(
    comptime TValue: type,
    comptime vector_size: comptime_int,
    input_values_1: []const TValue,
    input_values_2: []const TValue,
    output_values: []TValue,
) void {
    @setFloatMode(.Optimized);

    std.debug.assert(input_values_1.len == input_values_2.len);
    // Also check the output length, consistent with computeVectorMultiplication;
    // the original only validated the two inputs.
    std.debug.assert(input_values_1.len == output_values.len);
    std.debug.assert(input_values_1.len % vector_size == 0);

    var index: usize = 0;

    while (index < input_values_1.len) : (index += vector_size) {
        output_values[index..][0..vector_size].* =
            @as(@Vector(vector_size, TValue), input_values_1[index..][0..vector_size].*) +
            @as(@Vector(vector_size, TValue), input_values_2[index..][0..vector_size].*);
    }
}

/// Element-wise product with a uniform scale:
/// output_values = scaling_factor * input_values_1 * input_values_2,
/// processed in SIMD chunks of `vector_size`. All slices must have equal
/// length, which must be a multiple of `vector_size`.
pub fn computeVectorMultiplication(
    comptime TValue: type,
    comptime vector_size: comptime_int,
    scaling_factor: f32,
    input_values_1: []const TValue,
    input_values_2: []const TValue,
    output_values: []TValue,
) void {
    @setFloatMode(.Optimized);

    std.debug.assert(input_values_1.len == input_values_2.len);
    std.debug.assert(input_values_1.len == output_values.len);
    std.debug.assert(input_values_1.len % vector_size == 0);

    // Broadcast the scalar once so each chunk is a pure vector multiply.
    const scale: @Vector(vector_size, f32) = @splat(scaling_factor);

    var offset: usize = 0;

    while (offset < input_values_1.len) : (offset += vector_size) {
        const chunk_1: @Vector(vector_size, TValue) = input_values_1[offset..][0..vector_size].*;
        const chunk_2: @Vector(vector_size, TValue) = input_values_2[offset..][0..vector_size].*;

        output_values[offset..][0..vector_size].* = scale * chunk_1 * chunk_2;
    }
}
93 changes: 29 additions & 64 deletions src/tensor.zig
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const std = @import("std");
const simd = @import("simd.zig");

pub fn Tensor(comptime n_dims: comptime_int) type {
comptime if (n_dims < 1) @compileError("n_dims < 1");
Expand Down Expand Up @@ -50,79 +51,43 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
};
}

// In-place element-wise accumulation: self.values += other.values.
// Both tensors must hold the same number of values.
pub fn add(self: Self, other: anytype) void {
    @setFloatMode(.Optimized);

    std.debug.assert(self.values.len == other.values.len);

    // Parallel iteration over both slices; lengths were asserted equal above.
    for (self.values, other.values) |*self_value, other_value| {
        self_value.* += other_value;
    }
}

// Matrix-vector product: each output element is the dot product of one
// row of this tensor (via slice(row_index)) with the input vector.
pub fn computeMatrixVectorMultiplication(self: Self, input: anytype, output: anytype) void {
    for (output.values, 0..) |*output_value, row_index| {
        output_value.* = self.slice(row_index).computeScalarProduct(input);
    }
}

// Dot product of two tensors, dispatching to the widest SIMD vector size
// (32, 16, or 8) that evenly divides the element count, falling back to 4.
pub fn computeScalarProduct(self: Self, other: anytype) f32 {
    inline for (.{ 32, 16, 8 }) |width| {
        if (self.values.len % width == 0) {
            return _computeScalarProduct(width, self, other);
        }
    }

    return _computeScalarProduct(4, self, other);
}

// Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467
// NOTE(review): the scraped diff interleaved lines of the removed old
// implementation into this method body; this is the reconstructed new-side
// method, which delegates to simd.computeRMSNorm at the widest vector size
// (32/16/8, falling back to 4) that evenly divides the element count.
pub fn computeRMSNorm(self: Self, weight: anytype, output: anytype) void {
    if (self.values.len % 32 == 0)
        simd.computeRMSNorm(f32, 32, self.values, weight.values, output.values)
    else if (self.values.len % 16 == 0)
        simd.computeRMSNorm(f32, 16, self.values, weight.values, output.values)
    else if (self.values.len % 8 == 0)
        simd.computeRMSNorm(f32, 8, self.values, weight.values, output.values)
    else
        simd.computeRMSNorm(f32, 4, self.values, weight.values, output.values);
}

rms_scaling_factor /= @floatFromInt(self.values.len);
rms_scaling_factor += 1e-5;
rms_scaling_factor = 1 / std.math.sqrt(rms_scaling_factor);
// Dot product of two tensors, delegating to simd.computeScalarProduct at
// the widest vector size (32/16/8) that evenly divides the element count,
// falling back to 4.
pub fn computeScalarProduct(self: Self, other: anytype) f32 {
    inline for (.{ 32, 16, 8 }) |vector_size| {
        if (self.values.len % vector_size == 0) {
            return simd.computeScalarProduct(f32, vector_size, self.values, other.values);
        }
    }

    return simd.computeScalarProduct(f32, 4, self.values, other.values);
}

for (output.values, 0..) |*value, index| {
value.* = weight.values[index] * rms_scaling_factor * self.values[index];
}
// In-place element-wise accumulation (self += other), delegating to
// simd.computeVectorAddition at the widest vector size (32/16/8) that
// evenly divides the element count, falling back to 4.
pub fn computeVectorAddition(self: Self, other: anytype) void {
    inline for (.{ 32, 16, 8 }) |vector_size| {
        if (self.values.len % vector_size == 0) {
            return simd.computeVectorAddition(f32, vector_size, self.values, other.values, self.values);
        }
    }

    simd.computeVectorAddition(f32, 4, self.values, other.values, self.values);
}
};
}

// SIMD dot product over the `.values` slices of two tensor-like arguments,
// accumulating in lanes of width `vector_size` and reducing at the end.
// Assumes the (equal) length is a multiple of `vector_size`.
fn _computeScalarProduct(
    comptime vector_size: comptime_int,
    input_1: anytype,
    input_2: anytype,
) f32 {
    @setFloatMode(.Optimized);

    std.debug.assert(input_1.values.len == input_2.values.len);

    var lane_sums: @Vector(vector_size, f32) = @splat(0.0);
    var offset: usize = 0;

    while (offset < input_1.values.len) : (offset += vector_size) {
        const chunk_1: @Vector(vector_size, f32) = input_1.values[offset..][0..vector_size].*;
        const chunk_2: @Vector(vector_size, f32) = input_2.values[offset..][0..vector_size].*;

        lane_sums += chunk_1 * chunk_2;
    }

    return @reduce(.Add, lane_sums);
}
4 changes: 2 additions & 2 deletions src/transformer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ pub fn forward(self: Self, token: usize, position: usize) void {
);

self.attention.forward(layer, position);
self.hidden_buffer.add(self.attention.output_buffer);
self.hidden_buffer.computeVectorAddition(self.attention.output_buffer);

self.hidden_buffer.computeRMSNorm(
weights.ffn_norm_vectors.slice(layer),
self.ffn.input_buffer,
);

self.ffn.forward(layer);
self.hidden_buffer.add(self.ffn.output_buffer);
self.hidden_buffer.computeVectorAddition(self.ffn.output_buffer);
}

self.hidden_buffer.computeRMSNorm(weights.output_norm_vector, self.hidden_buffer);
Expand Down

0 comments on commit e3a3880

Please sign in to comment.