Skip to content

Commit

Permalink
Improve RoPE rotation block
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Aug 16, 2023
1 parent f00893d commit 1bb8aa2
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/transformer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -108,26 +108,26 @@ pub fn decode(
var dim_i: usize = 0;

// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
// https://github.com/karpathy/llama2.c/issues/302#issue-1851956882
while (dim_i < config.dim) : (dim_i += 2) {
const q0 = run_state.q[dim_i];
const q1 = run_state.q[dim_i + 1];
const fcr = freq_cis_real_row[(dim_i % head_size) / 2];
const fci = freq_cis_imag_row[(dim_i % head_size) / 2];

// rotate q
const q0 = run_state.q[dim_i];
const q1 = run_state.q[dim_i + 1];

run_state.q[dim_i] = q0 * fcr - q1 * fci;
run_state.q[dim_i + 1] = q0 * fci + q1 * fcr;
}

dim_i = 0;
// rotate k
if (dim_i < kv_dim) {
const k0 = run_state.k[dim_i];
const k1 = run_state.k[dim_i + 1];

while (dim_i < kv_dim) : (dim_i += 2) {
const k0 = run_state.k[dim_i];
const k1 = run_state.k[dim_i + 1];
const fcr = freq_cis_real_row[(dim_i % head_size) / 2];
const fci = freq_cis_imag_row[(dim_i % head_size) / 2];

run_state.k[dim_i] = k0 * fcr - k1 * fci;
run_state.k[dim_i + 1] = k0 * fci + k1 * fcr;
run_state.k[dim_i] = k0 * fcr - k1 * fci;
run_state.k[dim_i + 1] = k0 * fci + k1 * fcr;
}
}

// save key,value at this time step (pos) to our kv cache
Expand Down

0 comments on commit 1bb8aa2

Please sign in to comment.