From 1bb8aa2b67e3a504134913d24750addbbff1500e Mon Sep 17 00:00:00 2001 From: Clemens Akens Date: Wed, 16 Aug 2023 15:34:30 +0200 Subject: [PATCH] Improve RoPE rotation block --- src/transformer.zig | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/transformer.zig b/src/transformer.zig index ff57722..27bb720 100644 --- a/src/transformer.zig +++ b/src/transformer.zig @@ -108,26 +108,26 @@ pub fn decode( var dim_i: usize = 0; // RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head + // https://github.com/karpathy/llama2.c/issues/302#issue-1851956882 while (dim_i < config.dim) : (dim_i += 2) { - const q0 = run_state.q[dim_i]; - const q1 = run_state.q[dim_i + 1]; const fcr = freq_cis_real_row[(dim_i % head_size) / 2]; const fci = freq_cis_imag_row[(dim_i % head_size) / 2]; + // rotate q + const q0 = run_state.q[dim_i]; + const q1 = run_state.q[dim_i + 1]; + run_state.q[dim_i] = q0 * fcr - q1 * fci; run_state.q[dim_i + 1] = q0 * fci + q1 * fcr; - } - dim_i = 0; + // rotate k + if (dim_i < kv_dim) { + const k0 = run_state.k[dim_i]; + const k1 = run_state.k[dim_i + 1]; - while (dim_i < kv_dim) : (dim_i += 2) { - const k0 = run_state.k[dim_i]; - const k1 = run_state.k[dim_i + 1]; - const fcr = freq_cis_real_row[(dim_i % head_size) / 2]; - const fci = freq_cis_imag_row[(dim_i % head_size) / 2]; - - run_state.k[dim_i] = k0 * fcr - k1 * fci; - run_state.k[dim_i + 1] = k0 * fci + k1 * fcr; + run_state.k[dim_i] = k0 * fcr - k1 * fci; + run_state.k[dim_i + 1] = k0 * fci + k1 * fcr; + } } // save key,value at this time step (pos) to our kv cache