Skip to content

Commit

Permalink
Minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Sep 12, 2023
1 parent 92d0fef commit 162f1b8
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/vector.zig
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ pub fn slice(
return vectors;
}

pub fn add(a: []f32, b: []const f32) void {
pub fn add(vector_a: []f32, vector_b: []const f32) void {
@setFloatMode(.Optimized);

std.debug.assert(a.len == b.len);
std.debug.assert(vector_a.len == vector_b.len);

for (a, 0..) |*element, index| {
element.* += b[index];
for (vector_a, 0..) |*element, index| {
element.* += vector_b[index];
}
}

Expand All @@ -43,14 +43,14 @@ pub fn argmax(vector: []f32) usize {
return max_index;
}

pub fn dot(a: []const f32, b: []const f32) f32 {
pub fn dot(vector_a: []const f32, vector_b: []const f32) f32 {
@setFloatMode(.Optimized);

// const native_vector_size: usize = comptime @max(std.simd.suggestVectorSize(f32) orelse 4, 4);
const native_vector_size: usize = 4; // TODO: the above code does not run on GitHub CI

std.debug.assert(a.len == b.len);
std.debug.assert(a.len % native_vector_size == 0);
std.debug.assert(vector_a.len == vector_b.len);
std.debug.assert(vector_a.len % native_vector_size == 0);

var result: f32 = 0;
var offset: usize = 0;
Expand All @@ -59,12 +59,12 @@ pub fn dot(a: []const f32, b: []const f32) f32 {

inline while (vector_size >= native_vector_size) : (vector_size /= native_vector_size) {
var vector: @Vector(vector_size, f32) = @splat(0.0);
var rest = (a.len - offset) % vector_size;
var rest = (vector_a.len - offset) % vector_size;

while (offset < a.len - rest) : (offset += vector_size) {
while (offset < vector_a.len - rest) : (offset += vector_size) {
vector +=
@as(@Vector(vector_size, f32), a[offset..][0..vector_size].*) *
@as(@Vector(vector_size, f32), b[offset..][0..vector_size].*);
@as(@Vector(vector_size, f32), vector_a[offset..][0..vector_size].*) *
@as(@Vector(vector_size, f32), vector_b[offset..][0..vector_size].*);
}

result += @reduce(.Add, vector);
Expand Down

0 comments on commit 162f1b8

Please sign in to comment.