Skip to content

Commit

Permalink
Minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 12, 2023
1 parent 86d613e commit 14bcf90
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/vector.zig
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ pub fn dot(input_a: []const f32, input_b: []const f32) f32 {
}

// Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467
pub fn rmsnorm(input: []const f32, weight: []const f32, output: []f32) void {
pub fn rmsnorm(input: []const f32, weights: []const f32, output: []f32) void {
@setFloatMode(.Optimized);

std.debug.assert(output.len == input.len);
std.debug.assert(output.len == weight.len);
std.debug.assert(output.len == weights.len);

var rms_scaling_factor: f32 = 0;

Expand All @@ -73,7 +73,7 @@ pub fn rmsnorm(input: []const f32, weight: []const f32, output: []f32) void {
rms_scaling_factor = 1 / std.math.sqrt(rms_scaling_factor);

for (output, 0..) |*element, index| {
element.* = weight[index] * rms_scaling_factor * input[index];
element.* = weights[index] * rms_scaling_factor * input[index];
}
}

Expand Down

0 comments on commit 14bcf90

Please sign in to comment.