Commit: Feed forward struct

clebert committed Aug 21, 2023
1 parent 2b836a6 commit 1d426bd

Showing 4 changed files with 119 additions and 84 deletions.
10 changes: 5 additions & 5 deletions src/checkpoint.zig
@@ -19,9 +19,9 @@ pub const Weights = struct {
value: []f32, // n_layers * dim * (n_kv_heads * head_size)
attention_output: []f32, // n_layers * (n_heads * head_size) * dim
rms_ffn_input: []f32, // n_layers * dim
- ffn_input: []f32, // n_layers * dim * hidden_dim
- ffn_hidden: []f32, // n_layers * hidden_dim * dim
- ffn_residual: []f32, // n_layers * dim * hidden_dim
+ ffn_hidden: []f32, // w1; n_layers * dim * hidden_dim
+ ffn_output: []f32, // w2; n_layers * hidden_dim * dim
+ ffn_residual: []f32, // w3; n_layers * dim * hidden_dim
rms_final: []f32, // dim
freq_cis_real: []f32, // seq_len * head_size / 2
freq_cis_imag: []f32, // seq_len * head_size / 2
@@ -89,8 +89,8 @@ pub fn readFile(
.value = readFloatSlice(&weights_data, config.n_layers * config.dim * (config.n_kv_heads * head_size)),
.attention_output = readFloatSlice(&weights_data, config.n_layers * (config.n_heads * head_size) * config.dim),
.rms_ffn_input = readFloatSlice(&weights_data, config.n_layers * config.dim),
- .ffn_input = readFloatSlice(&weights_data, config.n_layers * config.dim * config.hidden_dim),
- .ffn_hidden = readFloatSlice(&weights_data, config.n_layers * config.hidden_dim * config.dim),
+ .ffn_hidden = readFloatSlice(&weights_data, config.n_layers * config.dim * config.hidden_dim),
+ .ffn_output = readFloatSlice(&weights_data, config.n_layers * config.hidden_dim * config.dim),
.ffn_residual = readFloatSlice(&weights_data, config.n_layers * config.dim * config.hidden_dim),
.rms_final = readFloatSlice(&weights_data, config.dim),

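Note: each of the three FFN matrices holds dim * hidden_dim floats per layer, so a single per-layer offset addresses all three flat arrays. A minimal sketch of that slice arithmetic (a hypothetical helper for illustration; the commit does the same thing inline in feed_forward.zig):

const FfnSlices = struct { w1: []f32, w2: []f32, w3: []f32 };

fn ffnSlices(weights: *const Weights, layer: usize, dim: usize, hidden_dim: usize) FfnSlices {
    const size = dim * hidden_dim; // identical for all three matrices
    const offset = layer * size; // shared per-layer stride

    return .{
        .w1 = weights.ffn_hidden[offset..][0..size], // dim -> hidden_dim
        .w2 = weights.ffn_output[offset..][0..size], // hidden_dim -> dim
        .w3 = weights.ffn_residual[offset..][0..size], // dim -> hidden_dim (gate)
    };
}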
87 changes: 87 additions & 0 deletions src/feed_forward.zig
@@ -0,0 +1,87 @@
const std = @import("std");

const checkpoint = @import("checkpoint.zig");
const utils = @import("utils.zig");

pub const FeedForward = struct {
    const Self = @This();

    input_buffer: []f32,
    hidden_buffer: []f32,
    residual_buffer: []f32,
    output_buffer: []f32,

    pub fn init(self: *Self, allocator: std.mem.Allocator, config: *const checkpoint.Config) !void {
        self.input_buffer = try allocator.alloc(f32, config.dim);
        self.hidden_buffer = try allocator.alloc(f32, config.hidden_dim);
        self.residual_buffer = try allocator.alloc(f32, config.hidden_dim);
        self.output_buffer = try allocator.alloc(f32, config.dim);
    }

    pub fn deinit(self: *const Self, allocator: std.mem.Allocator) void {
        allocator.free(self.input_buffer);
        allocator.free(self.hidden_buffer);
        allocator.free(self.residual_buffer);
        allocator.free(self.output_buffer);
    }

    // In PyTorch terms: self.w2(F.silu(self.w1(x)) * self.w3(x))
    pub fn forward(
        self: *const Self,
        weights: *const checkpoint.Weights,
        layer: usize,
    ) !void {
        @setFloatMode(.Optimized);

        const input_buffer = self.input_buffer;
        const hidden_buffer = self.hidden_buffer;
        const residual_buffer = self.residual_buffer;
        const output_buffer = self.output_buffer;

        std.debug.assert(input_buffer.len == output_buffer.len);
        std.debug.assert(hidden_buffer.len == residual_buffer.len);

        const dim = input_buffer.len;
        const hidden_dim = hidden_buffer.len;
        const weights_size = dim * hidden_dim;
        const weights_offset = layer * weights_size;
        const hidden_weights = weights.ffn_hidden[weights_offset..][0..weights_size];
        const residual_weights = weights.ffn_residual[weights_offset..][0..weights_size];
        const output_weights = weights.ffn_output[weights_offset..][0..weights_size];

        // first calculate w1(x) and w3(x), in parallel for large models
        try matmul2(
            .{ hidden_buffer, input_buffer, hidden_weights },
            .{ residual_buffer, input_buffer, residual_weights },
            dim >= 4096,
        );

        // silu(w1(x)), then elementwise multiply with w3(x)
        for (0..hidden_dim) |i| {
            hidden_buffer[i] = silu(hidden_buffer[i]) * residual_buffer[i];
        }

        // final matmul to get the output of the FFN
        utils.matmul(output_buffer, hidden_buffer, output_weights);
    }
};

// https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
inline fn silu(x: f32) f32 {
    return x * sigmoid(x);
}

inline fn sigmoid(x: f32) f32 {
    return 1 / (1 + @exp(-x));
}

// Runs two independent matmuls, on separate threads when the model is
// large enough and enough cores are available.
fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void {
    const cpu_count = std.Thread.getCpuCount() catch 1;

    if (multi_threaded and cpu_count > 2) {
        const thread_1 = try std.Thread.spawn(.{}, utils.matmul, args_1);
        const thread_2 = try std.Thread.spawn(.{}, utils.matmul, args_2);

        thread_1.join();
        thread_2.join();
    } else {
        @call(.auto, utils.matmul, args_1);
        @call(.auto, utils.matmul, args_2);
    }
}
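The new struct is the Llama-style SwiGLU feed-forward, w2(silu(w1(x)) * w3(x)), with w1(x) and w3(x) computed in parallel on large models. A sketch of the intended call pattern, mirroring main.zig and transformer.zig below (hidden_state and rms_weights stand in for the actual rmsnorm arguments):

var feed_forward: FeedForward = undefined;

try feed_forward.init(allocator, &config);
defer feed_forward.deinit(allocator);

// Per layer: rmsnorm into input_buffer, run the FFN, then accumulate
// output_buffer back into the residual stream.
utils.rmsnorm(feed_forward.input_buffer, hidden_state, rms_weights);
try feed_forward.forward(&weights, layer);
utils.accum(hidden_state, feed_forward.output_buffer);

Note that matmul2 only spawns threads when the caller passes multi_threaded (dim >= 4096 here) and more than two cores are available; below that threshold the spawn overhead would likely outweigh the parallel gain.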
7 changes: 6 additions & 1 deletion src/main.zig
@@ -2,6 +2,7 @@ const std = @import("std");

const checkpoint = @import("checkpoint.zig");
const cli = @import("cli.zig");
+ const FeedForward = @import("feed_forward.zig").FeedForward;
const tokenizer = @import("tokenizer.zig");
const transformer = @import("transformer.zig");
const utils = @import("utils.zig");
@@ -66,12 +67,16 @@ pub fn main() !void {
var total_decoding_time: i64 = 0;
var total_sampling_time: i64 = 0;

+ var feed_forward: FeedForward = undefined;

+ try feed_forward.init(allocator, &config);

// advance the state machine
for (0..args.n_steps) |pos| {
start_time = std.time.milliTimestamp();

// forward the transformer to get logits for the next token
- try transformer.decode(allocator, token, pos, config, &run_state, &weights);
+ try transformer.decode(allocator, token, pos, config, &run_state, &weights, &feed_forward);

if (pos == 0) {
first_decoding_time = std.time.milliTimestamp() - start_time;
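The init-into-undefined pattern above matches how RunState is set up elsewhere in the codebase. An alternative sketch (a design variant, not what the commit does) would return the struct by value instead:

pub fn init(allocator: std.mem.Allocator, config: *const checkpoint.Config) !FeedForward {
    return .{
        .input_buffer = try allocator.alloc(f32, config.dim),
        .hidden_buffer = try allocator.alloc(f32, config.hidden_dim),
        .residual_buffer = try allocator.alloc(f32, config.hidden_dim),
        .output_buffer = try allocator.alloc(f32, config.dim),
    };
}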
99 changes: 21 additions & 78 deletions src/transformer.zig
@@ -1,24 +1,19 @@
const std = @import("std");

const checkpoint = @import("checkpoint.zig");
+ const FeedForward = @import("feed_forward.zig").FeedForward;
const utils = @import("utils.zig");

pub const RunState = struct {
hidden_state: []f32,

- attention_ffn_input_buffer: []f32,
- attention_ffn_output_buffer: []f32,

+ attention_input_buffer: []f32,
+ attention_output_buffer: []f32,
attention_scores: []f32,
query_buffer: []f32,
key_buffer: []f32,
value_buffer: []f32,
key_cache: []f32,
value_cache: []f32,

- ffn_hidden_buffer: []f32,
- ffn_residual_buffer: []f32,

logits: []f32,
};

@@ -32,8 +27,8 @@ pub fn allocRunState(
run_state.* = RunState{
.hidden_state = try allocator.alloc(f32, config.dim),

- .attention_ffn_input_buffer = try allocator.alloc(f32, config.dim),
- .attention_ffn_output_buffer = try allocator.alloc(f32, config.dim),
+ .attention_input_buffer = try allocator.alloc(f32, config.dim),
+ .attention_output_buffer = try allocator.alloc(f32, config.dim),

.attention_scores = try allocator.alloc(f32, config.n_heads * config.seq_len),
.query_buffer = try allocator.alloc(f32, config.dim),
@@ -42,9 +37,6 @@
.key_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim),
.value_cache = try allocator.alloc(f32, config.n_layers * config.seq_len * kv_dim),

- .ffn_hidden_buffer = try allocator.alloc(f32, config.hidden_dim),
- .ffn_residual_buffer = try allocator.alloc(f32, config.hidden_dim),

.logits = try allocator.alloc(f32, config.vocab_size),
};
}
@@ -56,6 +48,7 @@ pub fn decode(
config: checkpoint.Config,
run_state: *RunState,
weights: *const checkpoint.Weights,
+ feed_forward: *FeedForward,
) !void {
@setFloatMode(.Optimized);

@@ -75,7 +68,7 @@
for (0..config.n_layers) |layer| {
// attention rmsnorm
utils.rmsnorm(
- run_state.attention_ffn_input_buffer,
+ run_state.attention_input_buffer,
run_state.hidden_state,
weights.rms_attention_input[(layer * config.dim)..],
);
@@ -95,37 +88,37 @@

try pool.spawn(utils.matmul, .{
run_state.query_buffer,
- run_state.attention_ffn_input_buffer,
+ run_state.attention_input_buffer,
weights.query[(layer * config.dim * config.dim)..],
});

try pool.spawn(utils.matmul, .{
run_state.key_buffer,
- run_state.attention_ffn_input_buffer,
+ run_state.attention_input_buffer,
weights.key[(layer * config.dim * kv_dim)..],
});

try pool.spawn(utils.matmul, .{
run_state.value_buffer,
- run_state.attention_ffn_input_buffer,
+ run_state.attention_input_buffer,
weights.value[(layer * config.dim * kv_dim)..],
});
} else {
utils.matmul(
run_state.query_buffer,
- run_state.attention_ffn_input_buffer,
+ run_state.attention_input_buffer,
weights.query[(layer * config.dim * config.dim)..],
);

utils.matmul(
run_state.key_buffer,
- run_state.attention_ffn_input_buffer,
+ run_state.attention_input_buffer,
weights.key[(layer * config.dim * kv_dim)..],
);

utils.matmul(
run_state.value_buffer,
- run_state.attention_ffn_input_buffer,
+ run_state.attention_input_buffer,
weights.value[(layer * config.dim * kv_dim)..],
);
}
@@ -170,7 +163,7 @@

// weighted sum of the values, store back into intermediate_buffer
const intermediate_buffer =
- run_state.attention_ffn_input_buffer[(head * head_size)..][0..head_size];
+ run_state.attention_input_buffer[(head * head_size)..][0..head_size];

@memset(intermediate_buffer, 0);

@@ -190,75 +183,25 @@

// final matmul to get the output of the attention
utils.matmul(
- run_state.attention_ffn_output_buffer,
- run_state.attention_ffn_input_buffer,
+ run_state.attention_output_buffer,
+ run_state.attention_input_buffer,
weights.attention_output[(layer * config.dim * config.dim)..],
);

// residual connection back into hidden_state
- utils.accum(run_state.hidden_state, run_state.attention_ffn_output_buffer);
+ utils.accum(run_state.hidden_state, run_state.attention_output_buffer);

// ffn rmsnorm
utils.rmsnorm(
- run_state.attention_ffn_input_buffer,
+ feed_forward.input_buffer,
run_state.hidden_state,
weights.rms_ffn_input[(layer * config.dim)..],
);

- // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
- // first calculate self.w1(x) and self.w3(x)
- if (config.dim >= dim_multithreading_threshold) {
- try pool.init(std.Thread.Pool.Options{
- .allocator = allocator,
- .n_jobs = @max(1, @min(2, std.Thread.getCpuCount() catch 1)),
- });

- defer pool.deinit();

- try pool.spawn(utils.matmul, .{
- run_state.ffn_hidden_buffer,
- run_state.attention_ffn_input_buffer,
- weights.ffn_input[(layer * config.dim * config.hidden_dim)..],
- });

- try pool.spawn(utils.matmul, .{
- run_state.ffn_residual_buffer,
- run_state.attention_ffn_input_buffer,
- weights.ffn_residual[(layer * config.dim * config.hidden_dim)..],
- });
- } else {
- utils.matmul(
- run_state.ffn_hidden_buffer,
- run_state.attention_ffn_input_buffer,
- weights.ffn_input[(layer * config.dim * config.hidden_dim)..][0..(config.dim * config.hidden_dim)],
- );

- utils.matmul(
- run_state.ffn_residual_buffer,
- run_state.attention_ffn_input_buffer,
- weights.ffn_residual[(layer * config.dim * config.hidden_dim)..][0..(config.dim * config.hidden_dim)],
- );
- }

- // F.silu; silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
- for (run_state.ffn_hidden_buffer) |*s| {
- s.* *= 1 / (1 + std.math.exp(-s.*));
- }

- // elementwise multiply with w3(x)
- for (0..config.hidden_dim) |i| {
- run_state.ffn_hidden_buffer[i] *= run_state.ffn_residual_buffer[i];
- }

- // final matmul to get the output of the ffn
- utils.matmul(
- run_state.attention_ffn_output_buffer,
- run_state.ffn_hidden_buffer,
- weights.ffn_hidden[(layer * config.dim * config.hidden_dim)..][0..(config.dim * config.hidden_dim)],
- );
+ try feed_forward.forward(weights, layer);

// residual connection
- utils.accum(run_state.hidden_state, run_state.attention_ffn_output_buffer);
+ utils.accum(run_state.hidden_state, feed_forward.output_buffer);
}

// final rmsnorm
@@ -268,7 +211,7 @@ pub fn decode(
utils.matmul(run_state.logits, run_state.hidden_state, weights.classifier);
}

- pub fn rope(
+ fn rope(
pos: usize,
head_size: usize,
kv_dim: usize,
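With silu and sigmoid now isolated in feed_forward.zig, their behavior is easy to pin down with a unit test. A sketch of one that could be added to that file (hypothetical, not part of this commit):

test "silu" {
    // silu(0) = 0 * sigmoid(0) = 0
    try std.testing.expectEqual(@as(f32, 0), silu(0));

    // sigmoid saturates toward 1, so silu(x) approaches x for large x
    try std.testing.expectApproxEqAbs(@as(f32, 8), silu(8), 0.01);

    // unlike ReLU, silu dips below zero for negative inputs
    try std.testing.expect(silu(-1) < 0);
}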
