Skip to content

Commit

Permalink
Minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 12, 2023
1 parent 7825892 commit a50ffac
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 48 deletions.
3 changes: 1 addition & 2 deletions src/ffn.zig
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// SwiGLU activation function: https://arxiv.org/abs/2002.05202

const Self = @This();

const std = @import("std");
Expand Down Expand Up @@ -47,6 +45,7 @@ pub fn deinit(self: *const Self) void {
self.output_buffer.deinit();
}

// SwiGLU activation function: https://arxiv.org/abs/2002.05202
pub fn forward(self: *const Self, layer: usize) !void {
@setFloatMode(.Optimized);

Expand Down
13 changes: 5 additions & 8 deletions src/tensor.zig
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,17 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
};
}

/// Multiplies this tensor, viewed as a matrix of `data.len / sub_tensor_sizes[0]`
/// rows of length `sub_tensor_sizes[0]`, by the column vector `input`, writing
/// one dot product per row into `output` (in place, no allocation).
///
/// Asserts that `input.len` equals the row length and that `output.len`
/// equals the number of rows.
pub fn multiplyVector(self: *const Self, input: []const f32, output: []f32) void {
    comptime if (n_dims < 2) @compileError("n_dims < 2");

    const data = self.data;
    const sub_tensor_size = self.sub_tensor_sizes[0];

    std.debug.assert(input.len == sub_tensor_size);
    std.debug.assert(output.len == data.len / sub_tensor_size);

    for (output, 0..) |*value, index| {
        // Each output element is the dot product of one row slice with `input`.
        value.* = vector.dot(data[(index * sub_tensor_size)..][0..sub_tensor_size], input);
    }
}
};
Expand Down
58 changes: 20 additions & 38 deletions src/vector.zig
Original file line number Diff line number Diff line change
@@ -1,38 +1,21 @@
const std = @import("std");

/// Splits `data` into contiguous, non-overlapping sub-slices of length
/// `vector_size` and returns them as a newly allocated slice of slices.
///
/// `T` must be a slice type (it needs `len` and range indexing); the returned
/// sub-slices alias `data`, they do not copy it. Caller owns the returned
/// slice and must free it with the same allocator. Asserts that `data.len`
/// is an exact multiple of `vector_size`.
pub fn slice(
    comptime T: type,
    allocator: std.mem.Allocator,
    vector_size: usize,
    data: T,
) ![]T {
    std.debug.assert(data.len % vector_size == 0);

    // Never reassigned after allocation; elements are written through the
    // capture pointer below, which does not require a mutable binding.
    const vectors = try allocator.alloc(T, data.len / vector_size);

    for (vectors, 0..) |*vector, vector_index| {
        vector.* = data[(vector_index * vector_size)..][0..vector_size];
    }

    return vectors;
}

/// Adds `input_b` element-wise into `input_a`, in place.
/// Asserts that both slices have the same length.
pub fn add(input_a: []f32, input_b: []const f32) void {
    @setFloatMode(.Optimized);

    std.debug.assert(input_a.len == input_b.len);

    // Parallel capture over both slices avoids the manual index lookup.
    for (input_a, input_b) |*element, other| {
        element.* += other;
    }
}

pub fn argmax(vector: []f32) usize {
pub fn argmax(input: []f32) usize {
var max_index: usize = 0;
var max_element: f32 = vector[max_index];
var max_element: f32 = input[max_index];

for (1..vector.len) |index| {
const element = vector[index];
for (1..input.len) |index| {
const element = input[index];

if (element > max_element) {
max_index = index;
Expand All @@ -43,14 +26,13 @@ pub fn argmax(vector: []f32) usize {
return max_index;
}

pub fn dot(vector_a: []const f32, vector_b: []const f32) f32 {
pub fn dot(input_a: []const f32, input_b: []const f32) f32 {
@setFloatMode(.Optimized);

// const native_vector_size: usize = comptime @max(std.simd.suggestVectorSize(f32) orelse 4, 4);
const native_vector_size: usize = 4; // TODO: the above code does not run on GitHub CI
const native_vector_size: usize = 4;

std.debug.assert(vector_a.len == vector_b.len);
std.debug.assert(vector_a.len % native_vector_size == 0);
std.debug.assert(input_a.len == input_b.len);
std.debug.assert(input_a.len % native_vector_size == 0);

var result: f32 = 0;
var offset: usize = 0;
Expand All @@ -59,12 +41,12 @@ pub fn dot(vector_a: []const f32, vector_b: []const f32) f32 {

inline while (vector_size >= native_vector_size) : (vector_size /= native_vector_size) {
var vector: @Vector(vector_size, f32) = @splat(0.0);
var rest = (vector_a.len - offset) % vector_size;
var rest = (input_a.len - offset) % vector_size;

while (offset < vector_a.len - rest) : (offset += vector_size) {
while (offset < input_a.len - rest) : (offset += vector_size) {
vector +=
@as(@Vector(vector_size, f32), vector_a[offset..][0..vector_size].*) *
@as(@Vector(vector_size, f32), vector_b[offset..][0..vector_size].*);
@as(@Vector(vector_size, f32), input_a[offset..][0..vector_size].*) *
@as(@Vector(vector_size, f32), input_b[offset..][0..vector_size].*);
}

result += @reduce(.Add, vector);
Expand Down Expand Up @@ -95,18 +77,18 @@ pub fn rmsnorm(input: []const f32, weight: []const f32, output: []f32) void {
}
}

/// Normalizes `input` in place into a probability distribution
/// (all elements non-negative, summing to 1).
///
/// Subtracts the maximum element before exponentiation so that `exp` never
/// overflows for large inputs (the shift cancels out after normalization).
pub fn softmax(input: []f32) void {
    @setFloatMode(.Optimized);

    // Never mutated after initialization, hence `const`.
    const max_element: f32 = std.mem.max(f32, input);
    var sum: f32 = 0;

    for (input) |*element| {
        element.* = std.math.exp(element.* - max_element);
        sum += element.*;
    }

    for (input) |*element| {
        element.* /= sum;
    }
}

0 comments on commit a50ffac

Please sign in to comment.