Improve Tensor struct
clebert committed Oct 19, 2023
1 parent fa0e8a5 commit 84f4e7f
Showing 10 changed files with 164 additions and 173 deletions.
43 changes: 21 additions & 22 deletions src/attention.zig
@@ -2,8 +2,8 @@ const Self = @This();

const std = @import("std");
const Checkpoint = @import("checkpoint.zig");
+const math = @import("./math.zig");
const Tensor = @import("./tensor.zig").Tensor;
-const vector = @import("./vector.zig");

allocator: std.mem.Allocator,
checkpoint: Checkpoint,
@@ -76,7 +76,7 @@ pub fn deinit(self: *const Self) void {
self.allocator.free(self.scores);
}

-pub fn forward(self: *const Self, layer: usize, position: usize) !void {
+pub fn forward(self: *const Self, layer: usize, position: usize) void {
const weights = self.checkpoint.weights;
const query_matrix = weights.attention_query_matrices.slice(layer);
const key_matrix = weights.attention_key_matrices.slice(layer);
@@ -85,28 +85,28 @@ pub fn forward(self: *const Self, layer: usize, position: usize) !void {
const key_buffer = self.key_cache.slice(layer).slice(position);
const value_buffer = self.value_cache.slice(layer).slice(position);

-query_matrix.multiplyVector(self.input_buffer, self.query_buffer);
-key_matrix.multiplyVector(self.input_buffer, key_buffer);
-value_matrix.multiplyVector(self.input_buffer, value_buffer);
+query_matrix.computeMatrixVectorMultiplication(self.input_buffer, self.query_buffer);
+key_matrix.computeMatrixVectorMultiplication(self.input_buffer, key_buffer);
+value_matrix.computeMatrixVectorMultiplication(self.input_buffer, value_buffer);

self.rope(position, key_buffer);

for (0..self.checkpoint.n_attention_heads) |head| {
self.gqa(layer, position, head);
}

-output_matrix.multiplyVector(self.input_buffer, self.output_buffer);
+output_matrix.computeMatrixVectorMultiplication(self.input_buffer, self.output_buffer);
}

// Rotary positional embeddings: https://arxiv.org/abs/2104.09864
fn rope(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
@setFloatMode(.Optimized);

-std.debug.assert(self.query_buffer.data.len % key_buffer.data.len == 0);
+std.debug.assert(self.query_buffer.values.len % key_buffer.values.len == 0);

var index: usize = 0;

-while (index < self.query_buffer.data.len) : (index += 2) {
+while (index < self.query_buffer.values.len) : (index += 2) {
const head: f32 = @floatFromInt(index % self.head_size);

const frequency =
Expand All @@ -116,18 +116,18 @@ fn rope(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
const real_rotation_value: f32 = std.math.cos(rotation_scaling_factor);
const imag_rotation_value: f32 = std.math.sin(rotation_scaling_factor);

-const q_0 = self.query_buffer.data[index];
-const q_1 = self.query_buffer.data[index + 1];
+const q_0 = self.query_buffer.values[index];
+const q_1 = self.query_buffer.values[index + 1];

-self.query_buffer.data[index] = q_0 * real_rotation_value - q_1 * imag_rotation_value;
-self.query_buffer.data[index + 1] = q_0 * imag_rotation_value + q_1 * real_rotation_value;
+self.query_buffer.values[index] = q_0 * real_rotation_value - q_1 * imag_rotation_value;
+self.query_buffer.values[index + 1] = q_0 * imag_rotation_value + q_1 * real_rotation_value;

-if (index < key_buffer.data.len) {
-const k_0 = key_buffer.data[index];
-const k_1 = key_buffer.data[index + 1];
+if (index < key_buffer.values.len) {
+const k_0 = key_buffer.values[index];
+const k_1 = key_buffer.values[index + 1];

-key_buffer.data[index] = k_0 * real_rotation_value - k_1 * imag_rotation_value;
-key_buffer.data[index + 1] = k_0 * imag_rotation_value + k_1 * real_rotation_value;
+key_buffer.values[index] = k_0 * real_rotation_value - k_1 * imag_rotation_value;
+key_buffer.values[index + 1] = k_0 * imag_rotation_value + k_1 * real_rotation_value;
}
}
}
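For intuition, the loop above rotates each consecutive (even, odd) pair of query/key values as if it were a complex number, by an angle that depends on the token position and the pair's frequency. A minimal standalone sketch of that per-pair rotation (the helper name is illustrative, not part of this commit):

    const std = @import("std");

    // Rotate one (even, odd) pair in place by `angle` radians; this is the
    // same complex multiplication that rope() applies to queries and keys.
    fn rotatePair(pair: *[2]f32, angle: f32) void {
        const real_rotation_value = std.math.cos(angle);
        const imag_rotation_value = std.math.sin(angle);
        const p_0 = pair[0];
        const p_1 = pair[1];

        pair[0] = p_0 * real_rotation_value - p_1 * imag_rotation_value;
        pair[1] = p_0 * imag_rotation_value + p_1 * real_rotation_value;
    }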
@@ -146,22 +146,21 @@ fn gqa(self: *const Self, layer: usize, current_position: usize, head: usize) void {
for (0..next_position) |position| {
const key_vector = self.key_cache.slice(layer).slice(position).slice(query_group);

-self.scores[position] =
-vector.dot(query_vector.data, key_vector.data) / self.head_size_sqrt;
+self.scores[position] = query_vector.computeScalarProduct(key_vector) / self.head_size_sqrt;
}

-vector.softmax(self.scores[0..next_position]);
+math.softmax(self.scores[0..next_position]);

const attention_buffer = self.input_buffer.slice(head);

-@memset(attention_buffer.data, 0);
+@memset(attention_buffer.values, 0);

for (0..next_position) |position| {
const value_vector = self.value_cache.slice(layer).slice(position).slice(query_group);
const weight = self.scores[position];

for (0..self.head_size) |index| {
-attention_buffer.data[index] += value_vector.data[index] * weight;
+attention_buffer.values[index] += value_vector.values[index] * weight;
}
}
}
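Stripped of the cache indexing, gqa() is standard scaled dot-product attention per head: each score is a query-key scalar product divided by sqrt(head_size), the scores are softmaxed, and the result weights a sum over the cached values. A self-contained sketch of the same computation (the helper name and flat-slice signature are illustrative; it assumes scores, keys, and values all have matching lengths):

    const std = @import("std");
    const math = @import("math.zig"); // softmax introduced by this commit

    // Scaled dot-product attention for a single head, mirroring the gqa() loops.
    fn attendOneHead(
        scores: []f32,
        query: []const f32,
        keys: []const []const f32,
        values: []const []const f32,
        output: []f32,
    ) void {
        const scale = std.math.sqrt(@as(f32, @floatFromInt(query.len)));

        // scores[p] = (query . keys[p]) / sqrt(head_size)
        for (keys, scores) |key, *score| {
            var product: f32 = 0;

            for (query, key) |q, k| product += q * k;

            score.* = product / scale;
        }

        math.softmax(scores);

        // output = sum over cached positions of scores[p] * values[p]
        @memset(output, 0);

        for (values, scores) |value, weight| {
            for (output, value) |*out, v| out.* += v * weight;
        }
    }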
4 changes: 2 additions & 2 deletions src/chat.zig
@@ -68,7 +68,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
};

for (0..self.transformer.sequence_length) |position| {
-try self.transformer.forward(token, position);
+self.transformer.forward(token, position);

if (token == bos_token and user_turn) {
var user_prompt = std.ArrayList(u8).init(allocator);
@@ -129,7 +129,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
user_prompt_tokens_index += 1;

if (next_token == 0) {
-next_token = self.sampler.sample(self.transformer.output_buffer.data);
+next_token = self.sampler.sample(self.transformer.output_buffer.values);
}

if (next_token == eos_token) {
1 change: 0 additions & 1 deletion src/checkpoint.zig
@@ -2,7 +2,6 @@ const Self = @This();

const std = @import("std");
const Tensor = @import("./tensor.zig").Tensor;
-const vector = @import("./vector.zig");

allocator: std.mem.Allocator,
embedding_size: usize,
10 changes: 5 additions & 5 deletions src/ffn.zig
@@ -46,22 +46,22 @@ pub fn deinit(self: *const Self) void {
}

// SwiGLU activation function: https://arxiv.org/abs/2002.05202
-pub fn forward(self: *const Self, layer: usize) !void {
+pub fn forward(self: *const Self, layer: usize) void {
@setFloatMode(.Optimized);

const weights = self.checkpoint.weights;
const gate_matrix = weights.ffn_gate_matrices.slice(layer);
const up_matrix = weights.ffn_up_matrices.slice(layer);
const down_matrix = weights.ffn_down_matrices.slice(layer);

-gate_matrix.multiplyVector(self.input_buffer, self.gate_buffer);
-up_matrix.multiplyVector(self.input_buffer, self.hidden_buffer);
+gate_matrix.computeMatrixVectorMultiplication(self.input_buffer, self.gate_buffer);
+up_matrix.computeMatrixVectorMultiplication(self.input_buffer, self.hidden_buffer);

for (0..self.checkpoint.ffn_hidden_size) |index| {
-self.hidden_buffer.data[index] *= swish(self.gate_buffer.data[index]);
+self.hidden_buffer.values[index] *= swish(self.gate_buffer.values[index]);
}

-down_matrix.multiplyVector(self.hidden_buffer, self.output_buffer);
+down_matrix.computeMatrixVectorMultiplication(self.hidden_buffer, self.output_buffer);
}

// Swish activation function: https://arxiv.org/abs/1710.05941
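The swish helper itself is collapsed in this view. For reference, Swish (SiLU) is swish(x) = x * sigmoid(x), so the forward pass above forms the elementwise product of swish(gate) and up before the down projection, as SwiGLU prescribes. A sketch of what such a helper looks like (the actual body below the fold may differ):

    const std = @import("std");

    // Swish (SiLU): x * sigmoid(x) = x / (1 + e^-x)
    fn swish(x: f32) f32 {
        return x / (1 + std.math.exp(-x));
    }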
4 changes: 2 additions & 2 deletions src/generator.zig
@@ -64,7 +64,7 @@ pub fn generate(self: *Self, writer: anytype) !void {
start_time = std.time.milliTimestamp();
}

-try self.transformer.forward(token, position);
+self.transformer.forward(token, position);

if (start_time > 0) {
total_time += std.time.milliTimestamp() - start_time;
@@ -74,7 +74,7 @@
next_token = self.prompt_tokens[prompt_tokens_index];
prompt_tokens_index += 1;
} else {
-next_token = self.sampler.sample(self.transformer.output_buffer.data);
+next_token = self.sampler.sample(self.transformer.output_buffer.values);
}

if (next_token == bos_token or next_token == eos_token) {
33 changes: 33 additions & 0 deletions src/math.zig
@@ -0,0 +1,33 @@
const std = @import("std");

pub fn argmax(values: []f32) usize {
var max_index: usize = 0;
var max_value: f32 = values[max_index];

for (1..values.len) |index| {
const value = values[index];

if (value > max_value) {
max_index = index;
max_value = value;
}
}

return max_index;
}

pub fn softmax(values: []f32) void {
@setFloatMode(.Optimized);

var max_value: f32 = std.mem.max(f32, values);
var sum: f32 = 0;

for (values) |*value| {
value.* = std.math.exp(value.* - max_value);
sum += value.*;
}

for (values) |*value| {
value.* /= sum;
}
}
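Note that softmax subtracts the running maximum before exponentiating; softmax is invariant under constant shifts, so this keeps std.math.exp from overflowing on large logits without changing the result. A test of this shape could sit at the bottom of src/math.zig (illustrative, not part of the commit):

    test "softmax normalizes and preserves the argmax" {
        var values = [_]f32{ 1.0, 2.0, 3.0 };

        softmax(&values);

        var sum: f32 = 0;

        for (values) |value| sum += value;

        try std.testing.expectApproxEqAbs(@as(f32, 1.0), sum, 1e-5);
        try std.testing.expectEqual(@as(usize, 2), argmax(&values));
    }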
6 changes: 3 additions & 3 deletions src/sampler.zig
@@ -2,7 +2,7 @@ const Self = @This();

const builtin = @import("builtin");
const std = @import("std");
const vector = @import("vector.zig");
const math = @import("math.zig");

allocator: std.mem.Allocator,
probability_index_pairs_buffer: []ProbabilityIndexPair,
@@ -29,14 +29,14 @@ pub fn deinit(self: *const Self) void {

pub fn sample(self: *Self, probability_distribution: []f32) usize {
if (self.temperature == 0) {
-return vector.argmax(probability_distribution);
+return math.argmax(probability_distribution);
}

for (probability_distribution) |*probability| {
probability.* /= self.temperature;
}

-vector.softmax(probability_distribution);
+math.softmax(probability_distribution);

if (self.top_p <= 0 or self.top_p >= 1) {
return self.sampleMultinomial(probability_distribution);
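A worked example of the temperature scaling above: logits {1, 2} softmax to roughly {0.27, 0.73}, while at temperature 0.5 they are rescaled to {2, 4} and softmax to roughly {0.12, 0.88}. Temperatures below 1 thus sharpen the distribution toward its argmax (temperature 0 short-circuits to math.argmax directly), while temperatures above 1 flatten it toward uniform.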
104 changes: 83 additions & 21 deletions src/tensor.zig
@@ -1,5 +1,4 @@
const std = @import("std");
-const vector = @import("./vector.zig");

pub fn Tensor(comptime n_dims: comptime_int) type {
comptime if (n_dims < 1) @compileError("n_dims < 1");
@@ -9,62 +8,125 @@ pub fn Tensor(comptime n_dims: comptime_int) type {

allocator: ?std.mem.Allocator,
sub_dims: [n_dims - 1]usize,
-data: []f32,
+values: []f32,

pub fn init(allocator: std.mem.Allocator, dims: [n_dims]usize) !Self {
-const data_size = @reduce(.Mul, @as(@Vector(n_dims, usize), dims));
+const n_values = @reduce(.Mul, @as(@Vector(n_dims, usize), dims));

return .{
.allocator = allocator,
.sub_dims = dims[1..].*,
-.data = try allocator.alloc(f32, data_size),
+.values = try allocator.alloc(f32, n_values),
};
}

pub fn deinit(self: *const Self) void {
if (self.allocator) |allocator| {
-allocator.free(self.data);
+allocator.free(self.values);
}
}

pub fn read(self: *const Self, file: std.fs.File) !void {
-const buffer: [*]u8 = @ptrCast(self.data);
+const values: [*]u8 = @ptrCast(self.values);

-try file.reader().readNoEof(buffer[0 .. self.data.len * @sizeOf(f32)]);
+try file.reader().readNoEof(values[0 .. self.values.len * @sizeOf(f32)]);
}

pub fn write(self: *const Self, file: std.fs.File) !void {
-const buffer: [*]u8 = @ptrCast(self.data);
+const values: [*]u8 = @ptrCast(self.values);

-try file.writer().writeAll(buffer[0 .. self.data.len * @sizeOf(f32)]);
+try file.writer().writeAll(values[0 .. self.values.len * @sizeOf(f32)]);
}

pub fn slice(self: *const Self, index: usize) Tensor(n_dims - 1) {
comptime if (n_dims < 2) @compileError("n_dims < 2");

-const sub_data_size = @reduce(.Mul, @as(@Vector(n_dims - 1, usize), self.sub_dims));
+const n_sub_values = @reduce(.Mul, @as(@Vector(n_dims - 1, usize), self.sub_dims));

return .{
.allocator = null,
.sub_dims = self.sub_dims[1..].*,
-.data = self.data[index * sub_data_size ..][0..sub_data_size],
+.values = self.values[index * n_sub_values ..][0..n_sub_values],
};
}

-pub fn multiplyVector(self: *const Self, input: anytype, output: anytype) void {
-comptime if (n_dims < 2) @compileError("n_dims < 2");
+pub fn computeMatrixVectorMultiplication(
+self: *const Self,
+input: anytype,
+output: anytype,
+) void {
+for (output.values, 0..) |*value, index| {
+value.* = self.slice(index).computeScalarProduct(&input);
+}
+}

+pub fn computeScalarProduct(self: *const Self, other: anytype) f32 {
+if (self.values.len % 32 == 0) {
+return _computeScalarProduct(32, self, other);
+}

+if (self.values.len % 16 == 0) {
+return _computeScalarProduct(16, self, other);
+}

+if (self.values.len % 8 == 0) {
+return _computeScalarProduct(8, self, other);
+}

+return _computeScalarProduct(4, self, other);
+}

+pub fn add(self: *const Self, other: anytype) void {
+@setFloatMode(.Optimized);

+std.debug.assert(self.values.len == other.values.len);

-const sub_data_size = @reduce(.Mul, @as(@Vector(n_dims - 1, usize), self.sub_dims));
+for (self.values, 0..) |*value, index| {
+value.* += other.values[index];
+}
+}

+// Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467
+pub fn computeRMSNorm(self: *const Self, input: anytype, output: anytype) void {
+@setFloatMode(.Optimized);

+std.debug.assert(output.values.len == self.values.len);
+std.debug.assert(output.values.len == input.values.len);

+var rms_scaling_factor: f32 = 0;

+for (input.values) |value| {
+rms_scaling_factor += value * value;
+}

-std.debug.assert(input.data.len == sub_data_size);
-std.debug.assert(output.data.len == self.data.len / sub_data_size);
+rms_scaling_factor /= @floatFromInt(input.values.len);
+rms_scaling_factor += 1e-5;
+rms_scaling_factor = 1 / std.math.sqrt(rms_scaling_factor);

-for (output.data, 0..) |*value, index| {
-value.* = vector.dot(
-self.data[index * sub_data_size ..][0..sub_data_size],
-input.data,
-);
+for (output.values, 0..) |*value, index| {
+value.* = self.values[index] * rms_scaling_factor * input.values[index];
}
}
};
}
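In formula form, computeRMSNorm above computes output_i = self_i * input_i / sqrt(mean(input^2) + 1e-5): the weight tensor (self) scales the input after it has been normalized by its root mean square, matching the RMSNorm paper linked in the comment.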

+fn _computeScalarProduct(
+comptime vector_size: comptime_int,
+input_1: anytype,
+input_2: anytype,
+) f32 {
+@setFloatMode(.Optimized);

+std.debug.assert(input_1.values.len == input_2.values.len);

+var output_values: @Vector(vector_size, f32) = @splat(0.0);
+var index: usize = 0;

+while (index <= input_1.values.len - vector_size) : (index += vector_size) {
+output_values +=
+@as(@Vector(vector_size, f32), input_1.values[index..][0..vector_size].*) *
+@as(@Vector(vector_size, f32), input_2.values[index..][0..vector_size].*);
+}

+return @reduce(.Add, output_values);
+}
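A minimal usage sketch of the reworked Tensor API (the standalone file and the values are illustrative, not part of the commit): a 2x4 matrix of ones multiplied by a 4-vector of 0.5s yields {2, 2}, and with rows of length 4, computeScalarProduct takes its 4-wide SIMD fallback path.

    const std = @import("std");
    const Tensor = @import("tensor.zig").Tensor;

    pub fn main() !void {
        var gpa = std.heap.GeneralPurposeAllocator(.{}){};
        defer _ = gpa.deinit();

        const allocator = gpa.allocator();

        const matrix = try Tensor(2).init(allocator, [_]usize{ 2, 4 });
        defer matrix.deinit();

        const input = try Tensor(1).init(allocator, [_]usize{4});
        defer input.deinit();

        const output = try Tensor(1).init(allocator, [_]usize{2});
        defer output.deinit();

        @memset(matrix.values, 1);
        @memset(input.values, 0.5);

        // Each output value is one matrix row dotted with the input: 4 * (1 * 0.5) = 2.
        matrix.computeMatrixVectorMultiplication(input, output);

        std.debug.print("{d} {d}\n", .{ output.values[0], output.values[1] });
    }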