From 946a46953319b899e609d8539bc746bd87a65f80 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Sun, 15 Sep 2024 00:44:26 -0400 Subject: [PATCH] Add basic paged attention for kv cache (no copy on write) --- .../tjake/jlama/model/AbstractModel.java | 27 +- .../jlama/model/CausalSelfAttention.java | 47 ++- .../github/tjake/jlama/model/LayerNorm.java | 9 +- .../com/github/tjake/jlama/model/RMSNorm.java | 7 +- .../tjake/jlama/model/TransformerBlock.java | 14 +- .../tjake/jlama/model/bert/BertModel.java | 6 +- .../tjake/jlama/model/gemma/GemmaModel.java | 1 + .../tjake/jlama/model/gpt2/GPT2Model.java | 1 + .../tjake/jlama/model/llama/LlamaModel.java | 1 + .../jlama/model/mixtral/MixtralModel.java | 1 + .../tjake/jlama/tensor/KvBufferCache.java | 318 +++++++++++++++--- .../operations/NaiveTensorOperations.java | 5 +- .../operations/PanamaTensorOperations.java | 191 ++++++----- .../tensor/operations/TensorOperations.java | 13 +- .../operations/NativeTensorOperations.java | 9 +- .../com/github/tjake/jlama/net/Worker.java | 4 +- .../github/tjake/jlama/model/TestModels.java | 12 +- .../tensor/operations/TestOperations.java | 22 ++ 18 files changed, 473 insertions(+), 215 deletions(-) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 38e6c1f..dfad236 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -44,7 +44,6 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; -import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -190,8 +189,8 @@ protected AbstractTensor maybeQuantize(AbstractTensor t) { return t2; } - protected AbstractTensor forward(int token_id, int pos, AbstractTensor kvbuf) { - return forward(token_id, pos, 
kvbuf, Optional.empty(), Optional.empty()); + protected AbstractTensor forward(int token_id, int pos, KvBufferCache.KvBuffer kvbuf) { + return forward(token_id, pos, kvbuf, Optional.empty()); } /** @@ -207,8 +206,7 @@ protected AbstractTensor forward(int token_id, int pos, AbstractTensor kvbuf) { public AbstractTensor forward( int token_id, int pos, - AbstractTensor kvbuf, - Optional>> normReducer, + KvBufferCache.KvBuffer kvbuf, Optional>> tensorReducer ) { AbstractTensor embedding = embedInput.inputTokenToEmbedding(token_id, pos); @@ -217,16 +215,15 @@ public AbstractTensor forward( debug("TOKEN POSITION", pos); for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) { - AbstractTensor kvlayer = kvbuf.slice(true, i); AbstractTensor ref = embedding; // reference so we can free - embedding = transformerBlocks[i].forward(embedding, pos, kvlayer, normReducer, tensorReducer); + embedding = transformerBlocks[i].forward(embedding, pos, kvbuf, tensorReducer); ref.close(); } return embedding; } - protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, AbstractTensor kvbuf) { + protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf) { AbstractTensor last = null; for (int i = 0; i < token_ids.length; i++) { if (last != null) last.close(); @@ -236,13 +233,12 @@ protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, Abstrac return last; } - protected AbstractTensor batchForward(int[] token_ids, int startPos, AbstractTensor kvbuf) { + protected AbstractTensor batchForward(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf) { AbstractTensor embedding = embedInput.batchInputsToEmbeddings(token_ids, startPos); for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) { - AbstractTensor kvlayer = kvbuf.slice(true, i); AbstractTensor ref = embedding; // reference so we can free - embedding = transformerBlocks[i].forward(embedding, startPos, kvlayer, Optional.empty(), 
Optional.empty()); + embedding = transformerBlocks[i].forward(embedding, startPos, kvbuf, Optional.empty()); ref.close(); } @@ -303,11 +299,10 @@ public Response generate( encoded = Arrays.copyOfRange(encoded, 1, encoded.length); } - Preconditions.checkArgument(encoded.length < c.contextLength); + Preconditions.checkArgument(encoded.length < c.contextLength && encoded.length < ntokens, "Prompt exceeds max tokens"); - AbstractTensor kvmem = kvBufferCache.getKvBuffer(sessionId); // k and v for context window - Integer startPos = (Integer) kvmem.getMetadata(KvBufferCache.TOKEN_COUNT); // Number of tokens in the buffer - if (startPos == null) startPos = 0; + KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(sessionId); // k and v for context window + int startPos = kvmem.getCurrentContextPosition(); // Number of tokens in the buffer logger.debug("Starting at token {} for session {}", startPos, sessionId); @@ -366,7 +361,7 @@ public Response generate( if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature); output.close(); - kvmem.setMetadata(KvBufferCache.TOKEN_COUNT, i); + kvmem.incrementContextPosition(); // Model may tell us it's done if (c.eosTokens.contains(next)) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java index 944b715..1c2c876 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java @@ -20,6 +20,7 @@ import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.safetensors.Config; import com.github.tjake.jlama.tensor.AbstractTensor; +import com.github.tjake.jlama.tensor.KvBufferCache; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.google.common.base.Preconditions; import java.util.*; @@ -28,6 +29,7 @@ public 
class CausalSelfAttention { private final AbstractModel m; private final Config c; + private final int layerIndex; private final DistributedContext dctx; private final Optional queryAttnBias; private final Optional keyAttnBias; @@ -44,20 +46,20 @@ public class CausalSelfAttention { private final float attentionScale; private final int attentionLength; - private final boolean attentionQVSizeMismatch; private final AbstractTensor[] qkvResults; private final AbstractTensor[] qkvWeights; public CausalSelfAttention( AbstractModel m, + int layerIndex, AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights, AbstractTensor outputProjectionWeights ) { this( - m, + m, layerIndex, Optional.empty(), Optional.empty(), Optional.empty(), @@ -71,6 +73,7 @@ public CausalSelfAttention( public CausalSelfAttention( AbstractModel m, + int layerIndex, AbstractTensor queryAttnBias, AbstractTensor keyAttnBias, AbstractTensor valueAttnBias, @@ -82,6 +85,7 @@ public CausalSelfAttention( ) { this( m, + layerIndex, Optional.of(queryAttnBias), Optional.of(keyAttnBias), Optional.of(valueAttnBias), @@ -95,6 +99,7 @@ public CausalSelfAttention( public CausalSelfAttention( AbstractModel m, + int layerIndex, Optional queryAttnBias, Optional keyAttnBias, Optional valueAttnBias, @@ -105,6 +110,7 @@ public CausalSelfAttention( AbstractTensor outputProjectionWeights ) { this.m = m; + this.layerIndex = layerIndex; this.c = m.c; this.dctx = m.c.dctx(); this.queryAttnBias = queryAttnBias; @@ -118,7 +124,6 @@ public CausalSelfAttention( this.outputProjectionWeights = outputProjectionWeights; this.attentionLength = c.numberOfHeads * c.headSize; - this.attentionQVSizeMismatch = c.embeddingLength != attentionLength; this.attentionScale = (float) (1.0 / StrictMath.sqrt(c.headSize)); this.qkvResults = new AbstractTensor[3]; @@ -128,7 +133,7 @@ public CausalSelfAttention( public AbstractTensor forward( AbstractTensor input, int startPosition, - AbstractTensor kvMem, + 
KvBufferCache.KvBuffer kvMem, Optional>> tensorReducer ) { Preconditions.checkArgument(input.dims() == 2 && input.shape().last() == c.embeddingLength); @@ -214,11 +219,12 @@ public AbstractTensor forward( for (int position = startPosition, bi = 0; position < startPosition + batchSize; position++, bi++) { int finalPostion = position; - AbstractTensor kvp = kvMem.slice(true, 0); - AbstractTensor vvp = kvMem.slice(true, 1); + AbstractTensor key = kvMem.getKeyTensorForPosition(layerIndex, position); + AbstractTensor val = kvMem.getValTensorForPosition(layerIndex, position); + + AbstractTensor[] kvp = kvMem.getKeyTensorsUptoPosition(layerIndex, position); + AbstractTensor[] vvp = kvMem.getValTensorsUptoPosition(layerIndex, position); - AbstractTensor key = kvp.slice(position); - AbstractTensor val = vvp.slice(position); AbstractTensor tmpKey = tmpKeyBatch.slice(bi); AbstractTensor tmpVal = tmpValBatch.slice(bi); @@ -318,21 +324,34 @@ public AbstractTensor forward( // Attention VectorMath.pfor(dctx.headStart, dctx.headEnd, h -> { - try (AbstractTensor attn = m.makeDenseTensor(1, kvp.shape().first())) { - int xoffset = c.maybeMapToGroupHead(h) * c.headSize; - int yoffset = h * c.headSize; + int xoffset = c.maybeMapToGroupHead(h) * c.headSize; + int yoffset = h * c.headSize; - if (yoffset >= query.shape().last()) return; + if (yoffset >= query.shape().last()) return; + try (AbstractTensor attn = m.makeDenseTensor(1, finalPostion + 1)) { // compute attention scores by multiplying query and key for every position - TensorOperationsProvider.get().batchDotProduct(attn, query, kvp, yoffset, xoffset, c.headSize, 0, finalPostion + 1); + // Do this for each page + for (int i = 0; i < kvp.length; i++) { + int len = kvp[i].shape().first(); + int offset = i * len; + int size = i == kvp.length - 1 ? 
(finalPostion + 1) - offset : len; + TensorOperationsProvider.get().batchDotProduct(attn, query, kvp[i], yoffset, xoffset, c.headSize, offset, 0, size); + } + TensorOperationsProvider.get().scale(attentionScale, attn, 0, finalPostion + 1); // softmax the scores to get attention weights, from 0..pos inclusively VectorMath.softMax(attn, 0, finalPostion + 1); // apply adjusted attention weights to value vectors - TensorOperationsProvider.get().saxpy(attn, vvp, value, xoffset, yoffset, c.headSize, finalPostion + 1); + // do this for each page + for (int i = 0; i < vvp.length; i++) { + int len = vvp[i].shape().first(); + int offset = i * len; + int size = i == vvp.length - 1 ? (finalPostion + 1) - offset : len; + TensorOperationsProvider.get().saxpy(attn, vvp[i], value, xoffset, yoffset, c.headSize, offset, 0, size); + } } }); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/LayerNorm.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/LayerNorm.java index e5cee2e..2d7515e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/LayerNorm.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/LayerNorm.java @@ -34,21 +34,16 @@ public LayerNorm(AbstractModel m, AbstractTensor bias, AbstractTensor weights) { } public AbstractTensor forward(AbstractTensor input) { - return forward(input, Optional.empty()); - } - - public AbstractTensor forward(AbstractTensor input, Optional>> reducer) { Preconditions.checkArgument(input.shape().dims() == 2); int size = input.shape().last(); Preconditions.checkArgument(size == m.c.embeddingLength); - return forward(input, 0, m.c.embeddingLength, reducer); + return forward(input, 0, m.c.embeddingLength); } public AbstractTensor forward( AbstractTensor input, int offset, - int length, - Optional>> reducer + int length ) { int batchSize = input.shape().first(); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/RMSNorm.java 
b/jlama-core/src/main/java/com/github/tjake/jlama/model/RMSNorm.java index 2669197..2859a52 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/RMSNorm.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/RMSNorm.java @@ -16,9 +16,7 @@ package com.github.tjake.jlama.model; import com.github.tjake.jlama.tensor.AbstractTensor; -import com.github.tjake.jlama.util.Pair; -import java.util.Optional; -import java.util.function.BiFunction; + public class RMSNorm extends LayerNorm { private final float weightAdjustment; @@ -36,8 +34,7 @@ public RMSNorm(AbstractModel m, AbstractTensor weights, float weightAdjustment) public AbstractTensor forward( AbstractTensor input, int offset, - int length, - Optional>> reducer + int length ) { int batchSize = input.shape().first(); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java index cbc4e33..98c991e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java @@ -19,6 +19,7 @@ import com.github.tjake.jlama.model.functions.FeedForward; import com.github.tjake.jlama.tensor.AbstractTensor; +import com.github.tjake.jlama.tensor.KvBufferCache; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.Pair; import java.util.List; @@ -78,21 +79,20 @@ public TransformerBlock( this.postFFNorm = Optional.of(postFFNorm); } - public AbstractTensor forward(AbstractTensor embedding, int position, AbstractTensor kvBuffer) { - return forward(embedding, position, kvBuffer, Optional.empty(), Optional.empty()); + public AbstractTensor forward(AbstractTensor embedding, int position, KvBufferCache.KvBuffer kvBuffer) { + return forward(embedding, position, kvBuffer, Optional.empty()); } public AbstractTensor forward( AbstractTensor embedding, int 
position, - AbstractTensor kvBuffer, - Optional>> normReducer, + KvBufferCache.KvBuffer kvBuffer, Optional>> tensorReducer ) { debug("input_emb", embedding, layerIndex); - AbstractTensor lnemb = preAttentionNorm.map(ln -> ln.forward(embedding, normReducer)).orElse(embedding); + AbstractTensor lnemb = preAttentionNorm.map(ln -> ln.forward(embedding)).orElse(embedding); debug("ln_emb", lnemb, layerIndex); @@ -109,7 +109,7 @@ public AbstractTensor forward( debug("post_attn_res", postAttention, layerIndex); - AbstractTensor lnemb2 = postAttentionNorm.forward(postAttention, normReducer); + AbstractTensor lnemb2 = postAttentionNorm.forward(postAttention); debug("ln_emb2", lnemb2, layerIndex); @@ -131,7 +131,7 @@ public AbstractTensor forward( postAttention.close(); return postFFNorm.map(ln -> { - AbstractTensor lnout = ln.forward(postFF, normReducer); + AbstractTensor lnout = ln.forward(postFF); debug("ln_out", lnout, layerIndex); postFF.close(); return lnout; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java index fe5aba0..1b06cc9 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java @@ -25,10 +25,12 @@ import com.github.tjake.jlama.safetensors.WeightLoader; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.tensor.AbstractTensor; +import com.github.tjake.jlama.tensor.KvBufferCache; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import java.util.Arrays; import java.util.Optional; +import java.util.UUID; public class BertModel extends AbstractModel { @@ -100,6 +102,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { AbstractTensor outputWeight = weights.load(prefix + "output.dense.weight"); CausalSelfAttention attention = new CausalSelfAttention( this, + 
i, keyBias, queryBias, valueBias, @@ -147,8 +150,7 @@ public float[] embed(String input) { Preconditions.checkArgument(encoded.length < c.contextLength); float[] outputEmbedding = new float[c.embeddingLength]; - try (AbstractTensor kvmem = makeDenseTensor(c.dctx().numberOfLayers, 2, encoded.length, c.embeddingLength)) { // 2 for key and value - + try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(UUID.randomUUID())) { int promptLength = encoded.length; float avgp = 1.0f / promptLength; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java index a9b55e1..8bc8bb5 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java @@ -85,6 +85,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { String prefix = base + "self_attn."; CausalSelfAttention attention = new CausalSelfAttention( this, + i, weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType), weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType), weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType), diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java index 8e96c26..2913537 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java @@ -75,6 +75,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { AbstractTensor[] attnWeights = weights.load(prefix + "c_attn.weight").transpose().split(3, 0); CausalSelfAttention attention = new CausalSelfAttention( this, + i, attnBias[0], attnBias[1], attnBias[2], diff --git 
a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java index 85e8b55..1f43f75 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java @@ -97,6 +97,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { String prefix = base + "self_attn."; CausalSelfAttention attention = new CausalSelfAttention( this, + i, weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType), weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType), weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType), diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java index 984febd..93faeba 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java @@ -74,6 +74,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { String prefix = base + "self_attn."; CausalSelfAttention attention = new CausalSelfAttention( this, + i, weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType), weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType), weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType), diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java index a0b9d60..3beaf14 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java @@ -20,6 +20,11 @@ import com.github.tjake.jlama.safetensors.Config; import 
com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.util.Pair; +import com.google.common.base.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; import java.io.IOError; import java.io.IOException; import java.io.RandomAccessFile; @@ -28,19 +33,21 @@ import java.nio.ShortBuffer; import java.nio.channels.FileChannel; import java.nio.file.Paths; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; /** * A cache for key-value buffers used in the model. * @see com.github.tjake.jlama.model.functions.Generator */ public class KvBufferCache { - - public static final String TOKEN_COUNT = "TOKEN_COUNT"; - - private final ConcurrentMap> kvBufferCache; + private static final Logger logger = LoggerFactory.getLogger(KvBufferCache.class); + private final ConcurrentMap kvBufferCache; private final AbstractModel model; public KvBufferCache(AbstractModel model) { @@ -48,62 +55,265 @@ public KvBufferCache(AbstractModel model) { this.model = model; } - public AbstractTensor getKvBuffer(UUID session) { - return kvBufferCache.computeIfAbsent(session, this::makeKvBuffer).right; + public KvBuffer getKvBuffer(UUID session) { + return kvBufferCache.computeIfAbsent(session, s -> new KvBuffer(s, 1 << 24)); //16MB per page } - private Pair makeKvBuffer(UUID session) { - TensorShape s; - Config c = model.getConfig(); - DistributedContext dctx = c.dctx(); - // FIXME: Max size should be configurable - int[] rawShape = new int[] { dctx.numberOfLayers, 2, Math.min(1024, c.contextLength), c.kvLength }; - - // Adjust the shape to be relative to the kv cache size (in case of GQA) - if (c.kvLength != dctx.kvSegmentLength) { - Pair kvOffset = Pair.of(dctx.kvSegmentStart, dctx.kvSegmentEnd); - s = TensorShape.sparseColumn(rawShape, kvOffset); - } else 
{ - s = TensorShape.of(rawShape); - } - - // If we don't have a working directory, just use a FloatBufferTensor - if (model.getConfig().workingDirectory().isEmpty()) { - return Pair.of(null, AbstractTensor.make(model.getWorkingDType(), s)); - } - - // Otherwise, create a file-backed tensor - try { - RandomAccessFile raf = new RandomAccessFile( - Paths.get(model.getConfig().workingDirectory().get().toString(), session.toString()).toFile(), - "rw" - ); - long bytes = s.size() * model.getWorkingDType().size(); - raf.setLength(bytes); - - AbstractTensor t; - if (model.getWorkingDType() == DType.F32) { - FloatBuffer fb = raf.getChannel() - .map(FileChannel.MapMode.READ_WRITE, 0, bytes) - .order(ByteOrder.LITTLE_ENDIAN) - .asFloatBuffer(); - - t = new FloatBufferTensor(fb, s, true); - } else if (model.getWorkingDType() == DType.BF16) { - ShortBuffer sb = raf.getChannel() - .map(FileChannel.MapMode.READ_WRITE, 0, bytes) - .order(ByteOrder.LITTLE_ENDIAN) - .asShortBuffer(); - - t = new BFloat16BufferTensor("kvmem", sb, s, true); + class KvPageContext { + public final int numberOfLayerPages; + public final int numberOfContextPages; + private final int layersPerPage; + private final int contextLengthPerPage; + private final UUID session; + + public final TensorShape pageShape; + + public KvPageContext(UUID session, int numberOfLayerPages, int numberOfContextPages, int layersPerPage, int contextLengthPerPage) { + this.session = session; + this.numberOfLayerPages = numberOfLayerPages; + this.numberOfContextPages = numberOfContextPages; + this.layersPerPage = layersPerPage; + this.contextLengthPerPage = contextLengthPerPage; + + if (numberOfLayerPages < 1) + throw new IllegalArgumentException("totalPageCount must be >= 1"); + + if (numberOfContextPages < 1) + throw new IllegalArgumentException("numberOfContextPages must be >= 1"); + + if (layersPerPage < 1) + throw new IllegalArgumentException("layersPerPage must be >= 1"); + + if (contextLengthPerPage < 1) + throw new 
IllegalArgumentException("contextLengthPerPage must be >= 1"); + + TensorShape s; + Config c = model.getConfig(); + DistributedContext dctx = c.dctx(); + int[] rawShape = new int[] {layersPerPage, 2, contextLengthPerPage, c.kvLength}; + + // Adjust the shape to be relative to the kv cache size (in case of GQA) + if (c.kvLength != dctx.kvSegmentLength) { + Pair kvOffset = Pair.of(dctx.kvSegmentStart, dctx.kvSegmentEnd); + s = TensorShape.sparseColumn(rawShape, kvOffset); } else { - throw new UnsupportedOperationException("Only F32/BF16 is supported for now"); + s = TensorShape.of(rawShape); } - return Pair.of(raf, t); + this.pageShape = s; + } + } + + /** + * A Page of a key-value buffer. + * Rather than allocating one giant buffer for the entire key-value buffer, we allocate slices of the buffer + * as needed. This allows us to keep the memory usage low, and also allows us to allocate very large contexts. + */ + class KvBufferPage implements AutoCloseable { + private final AbstractTensor tensor; + + private final KvPageContext pageCtx; + private final String pageId; + + private final RandomAccessFile raf; + + KvBufferPage(KvPageContext pageCtx, String pageId) { + this.pageCtx = pageCtx; + this.pageId = pageId + ; + + if (model.getConfig().workingDirectory().isEmpty()) { + this.raf = null; + this.tensor = AbstractTensor.make(model.getWorkingDType(), pageCtx.pageShape); + } else { + try { + raf = new RandomAccessFile( + Paths.get(model.getConfig().workingDirectory().get().toString(), pageCtx.session.toString(), pageId).toFile(), + "rw"); + long bytes = pageCtx.pageShape.size() * model.getWorkingDType().size(); + raf.setLength(bytes); + + AbstractTensor t; + if (model.getWorkingDType() == DType.F32) { + FloatBuffer fb = raf.getChannel() + .map(FileChannel.MapMode.READ_WRITE, 0, bytes) + .order(ByteOrder.LITTLE_ENDIAN) + .asFloatBuffer(); + + t = new FloatBufferTensor(fb, pageCtx.pageShape, true); + } else if (model.getWorkingDType() == DType.BF16) { + ShortBuffer sb = 
raf.getChannel() + .map(FileChannel.MapMode.READ_WRITE, 0, bytes) + .order(ByteOrder.LITTLE_ENDIAN) + .asShortBuffer(); + + t = new BFloat16BufferTensor("kvmem", sb, pageCtx.pageShape, true); + } else { + throw new UnsupportedOperationException("Only F32/BF16 is supported for now"); + } + + this.tensor = t; + + } catch (IOException e) { + throw new IOError(e); + } + } + } + + public AbstractTensor getTensor() { + return tensor; + } + + @Override + public void close() throws IOException { + if (raf != null) { + raf.close(); + } + } + } + + public class KvBuffer implements AutoCloseable { + private UUID session; + private final AtomicInteger currentContextPosition = new AtomicInteger(0); + private final KvBufferPage[][] pages; + + private final KvPageContext pageContext; + + KvBuffer(UUID session, int maxPageSizeInBytes) { + this.session = session; + this.pageContext = computePageSize(maxPageSizeInBytes); + this.pages = new KvBufferPage[pageContext.numberOfLayerPages][pageContext.numberOfContextPages]; + } + + public int getCurrentContextPosition() { + return currentContextPosition.get(); + } + + public void setCurrentContextPosition(int position) { + currentContextPosition.set(position); + } + + public void incrementContextPosition() { + currentContextPosition.incrementAndGet(); + } + + public KvPageContext computePageSize(long maxPageSizeInBytes) { + Config c = model.getConfig(); + DType workingDType = model.getWorkingDType(); + long s = 2L * workingDType.size() * c.dctx().kvSegmentLength; // Size per layer per context + + Preconditions.checkArgument(maxPageSizeInBytes > s, "maxPageSizeInBytes must be greater than the size of a single layer"); + + int N = c.dctx().numberOfLayers; + int C = c.contextLength; + + int optimalLayersPerPage = 1; + int optimalContextLengthPerPage = 1; + long maxProduct = 0; + + // Try partitioning by layers + for (int x = N; x >= 1; x--) { + long y = maxPageSizeInBytes / (x * s); + + if (y >= 1 && y <= C) { + long product = x * y; + + if 
(product > maxProduct) { + optimalLayersPerPage = x; + optimalContextLengthPerPage = (int) y; + maxProduct = product; + } + // Break if product starts decreasing + if (product < maxProduct) { + break; + } + } + } + + // Try partitioning by context length + for (int y = C; y >= 1; y--) { + long x = maxPageSizeInBytes / (y * s); + + if (x >= 1 && x <= N) { + long product = x * y; + + if (product > maxProduct) { + optimalLayersPerPage = (int) x; + optimalContextLengthPerPage = y; + maxProduct = product; + } + if (product < maxProduct) { + break; + } + } + } + + // Calculate the number of pages needed + int numberOfLayerPages = (int) Math.ceil((double) N / optimalLayersPerPage); + int numberOfContextPages = (int) Math.ceil((double) C / optimalContextLengthPerPage); + + // Calculate the size of each page + long pageSize = optimalLayersPerPage * optimalContextLengthPerPage * s; + + if (pageSize > maxPageSizeInBytes) { + throw new IllegalArgumentException("Calculation error: pageSize > maxPageSizeInBytes: " + pageSize + " > " + maxPageSizeInBytes); + } + + return new KvPageContext(session, numberOfLayerPages, numberOfContextPages, optimalLayersPerPage, optimalContextLengthPerPage); + } + + @Override + public void close() { + + } + + public AbstractTensor getKeyTensorForPosition(int layerIndex, int position) { + return getTensorForPosition(layerIndex, position, 0); + } + + public AbstractTensor getValTensorForPosition(int layerIndex, int position) { + return getTensorForPosition(layerIndex, position, 1); + } + + private AbstractTensor getTensorForPosition(int layerIndex, int position, int index) { + // Calculate page indices and relative indices + int layerPageIndex = layerIndex / pageContext.layersPerPage; + int contextPageIndex = position / pageContext.contextLengthPerPage; + int relativeLayerIndex = layerIndex % pageContext.layersPerPage; + int relativeContextIndex = position % pageContext.contextLengthPerPage; + + KvBufferPage page = 
pages[layerPageIndex][contextPageIndex]; + if (page == null) { + page = new KvBufferPage(pageContext, "L" + layerPageIndex + "C" + contextPageIndex); + pages[layerPageIndex][contextPageIndex] = page; + } + + return page.getTensor().slice(true, relativeLayerIndex, index, relativeContextIndex); + } + + public AbstractTensor[] getKeyTensorsUptoPosition(int layerIndex, int upperBound) { + return getTensorsUptoPosition(layerIndex, 0, upperBound); + } + + public AbstractTensor[] getValTensorsUptoPosition(int layerIndex, int upperBound) { + return getTensorsUptoPosition(layerIndex, 1, upperBound); + } + + private AbstractTensor[] getTensorsUptoPosition(int layerIndex, int index, int upperBound) { + int layerPageIndex = layerIndex / pageContext.layersPerPage; + int contextPageIndex = upperBound / pageContext.contextLengthPerPage; + int relativeLayerIndex = layerIndex % pageContext.layersPerPage; + + KvBufferPage[] layerPages = pages[layerPageIndex]; + + AbstractTensor[] tensors = new AbstractTensor[contextPageIndex + 1]; + + for (int i = 0; i <= contextPageIndex; i++) { + KvBufferPage page = layerPages[i]; + tensors[i] = page.getTensor().slice(true, relativeLayerIndex, index); + } - } catch (IOException e) { - throw new IOError(e); + return tensors; } } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/NaiveTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/NaiveTensorOperations.java index 3043c00..cec2d24 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/NaiveTensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/NaiveTensorOperations.java @@ -78,6 +78,7 @@ public void batchDotProduct( int aColumnOffset, int bColumnOffset, int columnLength, + int rRowOffset, int bRowOffset, int rowChunkSize ) { @@ -86,9 +87,9 @@ public void batchDotProduct( int bRowLimit = bRowOffset + rowChunkSize; for (int i = 0; i < a.shape().first(); i++) { - for 
(int j = bRowOffset; j < bRowLimit; j++) { + for (int j = bRowOffset, r = rRowOffset; j < bRowLimit; j++, r++) { float d = dotProduct(a.slice(i), b.slice(j), aColumnOffset, bColumnOffset, columnLength); - result.set(d, i, j); + result.set(d, i, r); } } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java index 644ed66..e787198 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java @@ -99,11 +99,13 @@ public void batchDotProduct( int aColumnOffset, int bColumnOffset, int columnLength, + int rOffset, int bRowOffset, int rowChunkSize ) { Preconditions.checkArgument(a.dims() == 2 && b.dims() == 2 && result.dims() == 2); Preconditions.checkArgument(a.shape().dim(0) == result.shape().dim(0), "BAD M"); + Preconditions.checkArgument(rOffset >= bRowOffset, "Result offset must be >= b row offset"); // Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N"); // This check breaks for GQA // Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K" + a.shape() + " " + b.shape() + " " @@ -115,26 +117,26 @@ public void batchDotProduct( Gemmer gemm = switch (a.dType()) { case F32 -> switch (b.dType()) { - case F32 -> new GemmerF32(K, a, b, result, aColumnOffset, bColumnOffset); - case BF16 -> new GemmerF32BF16(K, a, b, result, aColumnOffset, bColumnOffset); + case F32 -> new GemmerF32(K, a, b, result, aColumnOffset, bColumnOffset, rOffset); + case BF16 -> new GemmerF32BF16(K, a, b, result, aColumnOffset, bColumnOffset, rOffset); case Q4 -> switch (vectorType) { - case AVX_256 -> new GemmerF32Q4_256(K, a, b, result, aColumnOffset, bColumnOffset); - case AVX_512 -> new GemmerF32Q4_512(K, a, b, result, aColumnOffset, bColumnOffset); + case 
AVX_256 -> new GemmerF32Q4_256(K, a, b, result, aColumnOffset, bColumnOffset, rOffset); + case AVX_512 -> new GemmerF32Q4_512(K, a, b, result, aColumnOffset, bColumnOffset, rOffset); default -> throw new UnsupportedOperationException(vectorType.name()); }; default -> throw new UnsupportedOperationException(b.dType().name()); }; case I8 -> switch (b.dType()) { case Q4 -> switch (vectorType) { - case AVX_256 -> new GemmerI8Q4_256(K, a, b, result, aColumnOffset, bColumnOffset); - case AVX_512 -> new GemmerI8Q4_512(K, a, b, result, aColumnOffset, bColumnOffset); - case ARM_128 -> new GemmerI8Q4_arm(K, a, b, result, aColumnOffset, bColumnOffset); + case AVX_256 -> new GemmerI8Q4_256(K, a, b, result, aColumnOffset, bColumnOffset, rOffset); + case AVX_512 -> new GemmerI8Q4_512(K, a, b, result, aColumnOffset, bColumnOffset, rOffset); + case ARM_128 -> new GemmerI8Q4_arm(K, a, b, result, aColumnOffset, bColumnOffset, rOffset); default -> throw new UnsupportedOperationException(vectorType.name()); }; default -> throw new UnsupportedOperationException(b.dType().name()); }; case BF16 -> switch (b.dType()) { - case BF16 -> new GemmerBF16(K, a, b, result, aColumnOffset, bColumnOffset); + case BF16 -> new GemmerBF16(K, a, b, result, aColumnOffset, bColumnOffset, rOffset); default -> throw new UnsupportedOperationException(b.dType().name()); }; default -> throw new UnsupportedOperationException(a.dType().name() + " " + b.dType().name()); @@ -152,8 +154,8 @@ private class GemmerF32Q4_256 extends Gemmer { final Q4ByteBufferTensor b; final FloatBufferTensor a; - GemmerF32Q4_256(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth) { - super(k, ta, tb, c, ith, nth); + GemmerF32Q4_256(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) { + super(k, ta, tb, c, ith, nth, rOffset); this.a = (FloatBufferTensor) ta; this.b = (Q4ByteBufferTensor) tb; @@ -244,7 +246,7 @@ protected BiIntConsumer initMatmul1x1() { acc = 
af3.fma(high1, acc); } - c.set(acc.reduceLanes(VectorOperators.ADD), i, j); + c.set(acc.reduceLanes(VectorOperators.ADD), i, j + rOffset); }; } @@ -416,8 +418,8 @@ protected BiIntConsumer initMatmul1x4() { } */ } - c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + 0); - c.set(acc1.reduceLanes(VectorOperators.ADD), i, j + 1); + c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + 0 + rOffset); + c.set(acc1.reduceLanes(VectorOperators.ADD), i, j + 1 + rOffset); // c.set(acc2.reduceLanes(VectorOperators.ADD), i, j + 2); // c.set(acc3.reduceLanes(VectorOperators.ADD), i, j + 3); }; @@ -433,8 +435,8 @@ private class GemmerF32Q4_512 extends Gemmer { final Q4ByteBufferTensor b; final FloatBufferTensor a; - GemmerF32Q4_512(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth) { - super(k, ta, tb, c, ith, nth); + GemmerF32Q4_512(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) { + super(k, ta, tb, c, ith, nth, rOffset); this.a = (FloatBufferTensor) ta; this.b = (Q4ByteBufferTensor) tb; @@ -507,7 +509,7 @@ protected BiIntConsumer initMatmul1x1() { acc = af1.fma(high0, acc); } - c.set(acc.reduceLanes(VectorOperators.ADD), i, j); + c.set(acc.reduceLanes(VectorOperators.ADD), i, j + rOffset); }; } @@ -567,10 +569,10 @@ protected final BiIntConsumer initMatmul4x1() { acc3 = af31.fma(high0, acc3); } - c.set(acc0.reduceLanes(VectorOperators.ADD), i + 0, j); - c.set(acc1.reduceLanes(VectorOperators.ADD), i + 1, j); - c.set(acc2.reduceLanes(VectorOperators.ADD), i + 2, j); - c.set(acc3.reduceLanes(VectorOperators.ADD), i + 3, j); + c.set(acc0.reduceLanes(VectorOperators.ADD), i + 0, j + rOffset); + c.set(acc1.reduceLanes(VectorOperators.ADD), i + 1, j + rOffset); + c.set(acc2.reduceLanes(VectorOperators.ADD), i + 2, j + rOffset); + c.set(acc3.reduceLanes(VectorOperators.ADD), i + 3, j + rOffset); }; } @@ -677,10 +679,10 @@ protected BiIntConsumer initMatmul1x4() { } } - 
c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + 0); - c.set(acc1.reduceLanes(VectorOperators.ADD), i, j + 1); - c.set(acc2.reduceLanes(VectorOperators.ADD), i, j + 2); - c.set(acc3.reduceLanes(VectorOperators.ADD), i, j + 3); + c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + 0 + rOffset); + c.set(acc1.reduceLanes(VectorOperators.ADD), i, j + 1 + rOffset); + c.set(acc2.reduceLanes(VectorOperators.ADD), i, j + 2 + rOffset); + c.set(acc3.reduceLanes(VectorOperators.ADD), i, j + 3 + rOffset); }; } } @@ -694,8 +696,8 @@ private class GemmerI8Q4_arm extends Gemmer { final Q8ByteBufferTensor a; final Q4ByteBufferTensor b; - GemmerI8Q4_arm(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int aColumnOffset, int bColumnOffset) { - super(k, ta, tb, c, aColumnOffset, bColumnOffset); + GemmerI8Q4_arm(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) { + super(k, ta, tb, c, aColumnOffset, bColumnOffset, rOffset); this.a = (Q8ByteBufferTensor) ta; this.b = (Q4ByteBufferTensor) tb; @@ -800,7 +802,7 @@ protected BiIntConsumer initMatmul1x1() { } } - c.set(acc.reduceLanes(VectorOperators.ADD), i, j); + c.set(acc.reduceLanes(VectorOperators.ADD), i, j + rOffset); }; } } @@ -814,8 +816,8 @@ private class GemmerI8Q4_256 extends Gemmer { final Q8ByteBufferTensor a; final Q4ByteBufferTensor b; - GemmerI8Q4_256(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int aColumnOffset, int bColumnOffset) { - super(k, ta, tb, c, aColumnOffset, bColumnOffset); + GemmerI8Q4_256(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) { + super(k, ta, tb, c, aColumnOffset, bColumnOffset, rOffset); this.a = (Q8ByteBufferTensor) ta; this.b = (Q4ByteBufferTensor) tb; @@ -911,7 +913,7 @@ protected BiIntConsumer initMatmul1x1() { } } - c.set(acc.reduceLanes(VectorOperators.ADD), i, j); + c.set(acc.reduceLanes(VectorOperators.ADD), i, j + rOffset); }; 
} } @@ -924,8 +926,8 @@ private class GemmerI8Q4_512 extends Gemmer { final Q8ByteBufferTensor a; final Q4ByteBufferTensor b; - GemmerI8Q4_512(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth) { - super(k, ta, tb, c, ith, nth); + GemmerI8Q4_512(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) { + super(k, ta, tb, c, ith, nth, rOffset); this.a = (Q8ByteBufferTensor) ta; this.b = (Q4ByteBufferTensor) tb; @@ -996,7 +998,7 @@ protected BiIntConsumer initMatmul1x1() { acc = scale.fma(r0, acc); } - c.set(acc.reduceLanes(VectorOperators.ADD), i, j); + c.set(acc.reduceLanes(VectorOperators.ADD), i, j + rOffset); }; } @@ -1099,10 +1101,10 @@ protected BiIntConsumer initMatmul1x4() { float r2 = acc2.reduceLanes(VectorOperators.ADD); float r3 = acc3.reduceLanes(VectorOperators.ADD); - c.set(r0, i, j + 0); - c.set(r1, i, j + 1); - c.set(r2, i, j + 2); - c.set(r3, i, j + 3); + c.set(r0, i, j + 0 + rOffset); + c.set(r1, i, j + 1 + rOffset); + c.set(r2, i, j + 2 + rOffset); + c.set(r3, i, j + 3 + rOffset); }; } @@ -1186,10 +1188,10 @@ protected BiIntConsumer initMatmul3x4() { float r10 = acc10.reduceLanes(VectorOperators.ADD); float r11 = acc11.reduceLanes(VectorOperators.ADD); - c.set(r00, i + 0, j + 0); - c.set(r01, i + 0, j + 1); - c.set(r10, i + 1, j + 0); - c.set(r11, i + 1, j + 1); + c.set(r00, i + 0, j + 0 + rOffset); + c.set(r01, i + 0, j + 1 + rOffset); + c.set(r10, i + 1, j + 0 + rOffset); + c.set(r11, i + 1, j + 1 + rOffset); }; } } @@ -1201,8 +1203,8 @@ private class GemmerF32 extends Gemmer { final BiIntConsumer matmul3x4; final BiIntConsumer matmul4x1; - GemmerF32(int k, AbstractTensor a, AbstractTensor b, AbstractTensor c, int ith, int nth) { - super(k, a, b, c, ith, nth); + GemmerF32(int k, AbstractTensor a, AbstractTensor b, AbstractTensor c, int ith, int nth, int rOffset) { + super(k, a, b, c, ith, nth, rOffset); this.matmul1x1 = initMatmul1x1(); this.matmul1x4 = initMatmul1x4(); @@ 
-1248,7 +1250,7 @@ protected BiIntConsumer initMatmul1x1() { FloatVector vb = b.getVector(FloatVector.SPECIES_PREFERRED, j, boffset).reinterpretAsFloats(); vc = va.fma(vb, vc); } - c.set(vc.reduceLanes(VectorOperators.ADD), i, j); + c.set(vc.reduceLanes(VectorOperators.ADD), i, j + rOffset); }; } @@ -1277,10 +1279,10 @@ protected BiIntConsumer initMatmul1x4() { vc3 = va.fma(vb3, vc3); } - c.set(vc0.reduceLanes(VectorOperators.ADD), i, j + 0); - c.set(vc1.reduceLanes(VectorOperators.ADD), i, j + 1); - c.set(vc2.reduceLanes(VectorOperators.ADD), i, j + 2); - c.set(vc3.reduceLanes(VectorOperators.ADD), i, j + 3); + c.set(vc0.reduceLanes(VectorOperators.ADD), i, j + 0 + rOffset); + c.set(vc1.reduceLanes(VectorOperators.ADD), i, j + 1 + rOffset); + c.set(vc2.reduceLanes(VectorOperators.ADD), i, j + 2 + rOffset); + c.set(vc3.reduceLanes(VectorOperators.ADD), i, j + 3 + rOffset); }; } @@ -1330,20 +1332,20 @@ protected BiIntConsumer initMatmul3x4() { vc23 = va2.fma(vb3, vc23); } - c.set(vc00.reduceLanes(VectorOperators.ADD), i + 0, j + 0); - c.set(vc01.reduceLanes(VectorOperators.ADD), i + 0, j + 1); - c.set(vc02.reduceLanes(VectorOperators.ADD), i + 0, j + 2); - c.set(vc03.reduceLanes(VectorOperators.ADD), i + 0, j + 3); + c.set(vc00.reduceLanes(VectorOperators.ADD), i + 0, j + 0 + rOffset); + c.set(vc01.reduceLanes(VectorOperators.ADD), i + 0, j + 1 + rOffset); + c.set(vc02.reduceLanes(VectorOperators.ADD), i + 0, j + 2 + rOffset); + c.set(vc03.reduceLanes(VectorOperators.ADD), i + 0, j + 3 + rOffset); - c.set(vc10.reduceLanes(VectorOperators.ADD), i + 1, j + 0); - c.set(vc11.reduceLanes(VectorOperators.ADD), i + 1, j + 1); - c.set(vc12.reduceLanes(VectorOperators.ADD), i + 1, j + 2); - c.set(vc13.reduceLanes(VectorOperators.ADD), i + 1, j + 3); + c.set(vc10.reduceLanes(VectorOperators.ADD), i + 1, j + 0 + rOffset); + c.set(vc11.reduceLanes(VectorOperators.ADD), i + 1, j + 1 + rOffset); + c.set(vc12.reduceLanes(VectorOperators.ADD), i + 1, j + 2 + rOffset); + 
c.set(vc13.reduceLanes(VectorOperators.ADD), i + 1, j + 3 + rOffset); - c.set(vc20.reduceLanes(VectorOperators.ADD), i + 2, j + 0); - c.set(vc21.reduceLanes(VectorOperators.ADD), i + 2, j + 1); - c.set(vc22.reduceLanes(VectorOperators.ADD), i + 2, j + 2); - c.set(vc23.reduceLanes(VectorOperators.ADD), i + 2, j + 3); + c.set(vc20.reduceLanes(VectorOperators.ADD), i + 2, j + 0 + rOffset); + c.set(vc21.reduceLanes(VectorOperators.ADD), i + 2, j + 1 + rOffset); + c.set(vc22.reduceLanes(VectorOperators.ADD), i + 2, j + 2 + rOffset); + c.set(vc23.reduceLanes(VectorOperators.ADD), i + 2, j + 3 + rOffset); }; } @@ -1373,10 +1375,10 @@ protected BiIntConsumer initMatmul4x1() { vc3 = va3.fma(vb0, vc3); } - c.set(vc0.reduceLanes(VectorOperators.ADD), i + 0, j); - c.set(vc1.reduceLanes(VectorOperators.ADD), i + 1, j); - c.set(vc2.reduceLanes(VectorOperators.ADD), i + 2, j); - c.set(vc3.reduceLanes(VectorOperators.ADD), i + 3, j); + c.set(vc0.reduceLanes(VectorOperators.ADD), i + 0, j + rOffset); + c.set(vc1.reduceLanes(VectorOperators.ADD), i + 1, j + rOffset); + c.set(vc2.reduceLanes(VectorOperators.ADD), i + 2, j + rOffset); + c.set(vc3.reduceLanes(VectorOperators.ADD), i + 3, j + rOffset); }; } } @@ -1391,8 +1393,8 @@ private class GemmerBF16 extends Gemmer { final BFloat16BufferTensor a; final BFloat16BufferTensor b; - GemmerBF16(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth) { - super(k, ta, tb, c, ith, nth); + GemmerBF16(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) { + super(k, ta, tb, c, ith, nth, rOffset); this.matmul1x1 = initMatmul1x1(); /*this.matmul1x4 = initMatmul1x4(); @@ -1458,7 +1460,7 @@ protected BiIntConsumer initMatmul1x1() { vc = va1.fma(vb1, vc); } float res = vc.reduceLanes(VectorOperators.ADD); - c.set(res, i, j); + c.set(res, i, j + rOffset); }; } @@ -1624,8 +1626,8 @@ private class GemmerF32BF16 extends Gemmer { final FloatBufferTensor a; final BFloat16BufferTensor b; - 
GemmerF32BF16(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth) { - super(k, ta, tb, c, ith, nth); + GemmerF32BF16(int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) { + super(k, ta, tb, c, ith, nth, rOffset); this.matmul1x1 = initMatmul1x1(); /*this.matmul1x4 = initMatmul1x4(); @@ -1684,7 +1686,7 @@ protected BiIntConsumer initMatmul1x1() { vc = va1.fma(vb1, vc); } float res = vc.reduceLanes(VectorOperators.ADD); - c.set(res, i, j); + c.set(res, i, j + rOffset); }; } } @@ -1696,15 +1698,17 @@ private abstract class Gemmer { final AbstractTensor c; final int aColumnOffset; final int bColumnOffset; + final int rOffset; // The id of each thread is called ith and the number of threads is called nth. - Gemmer(int k, AbstractTensor a, AbstractTensor b, AbstractTensor c, int aColumnOffset, int bColumnOffset) { + Gemmer(int k, AbstractTensor a, AbstractTensor b, AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) { this.k = k; this.a = a; this.b = b; this.c = c; this.aColumnOffset = aColumnOffset; this.bColumnOffset = bColumnOffset; + this.rOffset = rOffset; } void matmul(int m0, int m, int n0, int n) { @@ -2545,20 +2549,20 @@ void saxpyF32(float alpha, FloatBufferTensor x, FloatBufferTensor y, int xoffset } @Override - public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int batchSize) { + public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int aOffset, int xOffset, int batchSize) { Preconditions.checkArgument(limit % 2 == 0); switch (x.dType()) { case F32: - saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit, batchSize); + saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit, aOffset, xOffset, batchSize); break; case BF16: switch (y.dType()) { case F32: - saxpyBF16F32(alpha, x, y, xoffset, 
yoffset, limit, batchSize); + saxpyBF16F32(alpha, x, y, xoffset, yoffset, limit, aOffset, xOffset, batchSize); break; case BF16: - saxpyBF16(alpha, x, y, xoffset, yoffset, limit, batchSize); + saxpyBF16(alpha, x, y, xoffset, yoffset, limit, aOffset, xOffset, batchSize); break; default: throw new UnsupportedOperationException(); @@ -2576,16 +2580,19 @@ public void saxpyF32( int xoffset, int yoffset, int limit, + int aOffset, + int xOffset, int batchSize ) { - int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(limit); // Use Nearest multiple of 4 - int batchLimit = batchSize - (batchSize % 4); - int a = 0; + int aLimit = batchSize - (batchSize % 4); + int a = aOffset; + int xi = xOffset; + aLimit += aOffset; - for (; a < batchLimit; a += 4) { + for (; a < aLimit; a += 4, xi += 4) { int xo = xoffset; int yo = yoffset; @@ -2596,10 +2603,10 @@ public void saxpyF32( for (; xo < (xoffset + upperBound) && yo < (yoffset + upperBound); xo += FloatVector.SPECIES_PREFERRED.length(), yo += FloatVector.SPECIES_PREFERRED.length()) { - FloatVector x0 = x.getVector(FloatVector.SPECIES_PREFERRED, a + 0, xo); - FloatVector x1 = x.getVector(FloatVector.SPECIES_PREFERRED, a + 1, xo); - FloatVector x2 = x.getVector(FloatVector.SPECIES_PREFERRED, a + 2, xo); - FloatVector x3 = x.getVector(FloatVector.SPECIES_PREFERRED, a + 3, xo); + FloatVector x0 = x.getVector(FloatVector.SPECIES_PREFERRED, xi + 0, xo); + FloatVector x1 = x.getVector(FloatVector.SPECIES_PREFERRED, xi + 1, xo); + FloatVector x2 = x.getVector(FloatVector.SPECIES_PREFERRED, xi + 2, xo); + FloatVector x3 = x.getVector(FloatVector.SPECIES_PREFERRED, xi + 3, xo); FloatVector vy = y.getVector(FloatVector.SPECIES_PREFERRED, 0, yo); @@ -2613,18 +2620,19 @@ public void saxpyF32( } // tail - for (; a < batchSize; a++) { - saxpyF32(alpha.get(0, a), (FloatBufferTensor) x.slice(a), y, xoffset, yoffset, limit); + for (; a < aOffset + batchSize; a++, xi++) { + saxpyF32(alpha.get(0, a), (FloatBufferTensor) x.slice(xi), y, xoffset, 
yoffset, limit); } } - public void saxpyBF16(AbstractTensor alpha, AbstractTensor xt, AbstractTensor yt, int xoffset, int yoffset, int limit, int batchSize) { + public void saxpyBF16(AbstractTensor alpha, AbstractTensor xt, AbstractTensor yt, int xoffset, int yoffset, int limit, int aOffset, int xOffset, int batchSize) { BFloat16BufferTensor x = (BFloat16BufferTensor) xt; BFloat16BufferTensor y = (BFloat16BufferTensor) yt; - for (int a = 0; a < batchSize; a++) { - saxpyBF16(alpha.get(0, a), (BFloat16BufferTensor) x.slice(a), y, xoffset, yoffset, limit); + int batchLimit = aOffset + batchSize; + for (int a = aOffset, xi = xOffset; a < batchLimit; a++, xi++) { + saxpyBF16(alpha.get(0, a), (BFloat16BufferTensor) x.slice(xi), y, xoffset, yoffset, limit); } } @@ -2635,14 +2643,17 @@ public void saxpyBF16F32( int xoffset, int yoffset, int limit, + int aOffset, + int xOffset, int batchSize ) { BFloat16BufferTensor x = (BFloat16BufferTensor) xt; FloatBufferTensor y = (FloatBufferTensor) yt; - for (int a = 0; a < batchSize; a++) { - saxpyBF16F32(alpha.get(0, a), (BFloat16BufferTensor) x.slice(a), y, xoffset, yoffset, limit); + int batchLimit = aOffset + batchSize; + for (int a = aOffset, xi = xOffset; a < batchLimit; a++, xi++) { + saxpyBF16F32(alpha.get(0, a), (BFloat16BufferTensor) x.slice(xi), y, xoffset, yoffset, limit); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java index 846e699..8442015 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java @@ -49,7 +49,7 @@ default void batchDotProduct( int bColumnOffset, int columnLimit ) { - batchDotProduct(result, a, b, aColumnOffset, bColumnOffset, columnLimit, 0, b.shape().first()); + batchDotProduct(result, a, b, aColumnOffset, bColumnOffset, 
columnLimit, 0, 0, b.shape().first()); } void batchDotProduct( @@ -59,6 +59,7 @@ void batchDotProduct( int aColumnOffset, int bColumnOffset, int columnLimit, + int rRowOffset, int bRowOffset, int rowChunkSize ); @@ -72,7 +73,7 @@ default void dotProductChunk( int rowOffset, int rowChunkSize ) { - batchDotProduct(result, a, b, columnOffset, columnOffset, columnLimit, rowOffset, rowChunkSize); + batchDotProduct(result, a, b, columnOffset, columnOffset, columnLimit, 0, rowOffset, rowChunkSize); } default void dotProductBatchChunk( @@ -108,11 +109,11 @@ default void dotProductBatchChunk( /** * The value computed is Y[i] = (alpha[j] * X[j, i]) + Y[i] */ - default void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int batchSize) { + default void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int aOffset, int xRowOffset, int batchSize) { Preconditions.checkArgument(alpha.shape().last() == x.shape().first() && y.shape().first() == 1); - - for (int i = 0; i < batchSize; i++) { - saxpy(alpha.get(0, i), x.slice(i), y, xoffset, yoffset, limit); + int batchLimit = xRowOffset + batchSize; + for (int xi = xRowOffset; xi < batchLimit; xi++) { + saxpy(alpha.get(0, aOffset++), x.slice(xi), y, xoffset, yoffset, limit); } } diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java index a832238..7822258 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java @@ -88,6 +88,7 @@ public void batchDotProduct( int aColumnOffset, int bColumnOffset, int columnLength, + int rRowOffset, int bRowOffset, int rowChunkSize ) { @@ -99,7 +100,9 @@ public void batchDotProduct( int aOffset = 
at.getOffset(0, aColumnOffset); int bOffset = bt.getOffset(bt.shape().sparseRowOffset(), bColumnOffset); - int rOffset = result.shape().sparseColumnOffset() - bt.shape().sparseRowOffset(); + //Adjusts for both sparse columns and rows this goes negative because we subtract the row offset + //And the row offsets need to add to the result offset + int rOffset = result.shape().sparseColumnOffset() - bt.shape().sparseRowOffset() - rRowOffset; int adjBRowOffset = bRowOffset - bt.shape().sparseRowOffset(); @@ -417,8 +420,8 @@ public void saxpy(float alpha, AbstractTensor x, AbstractTensor y, int xoffset, } @Override - public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int batchSize) { - delegate.saxpy(alpha, x, y, xoffset, yoffset, limit, batchSize); + public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int rOffset, int xOffset, int batchSize) { + delegate.saxpy(alpha, x, y, xoffset, yoffset, limit, rOffset, xOffset, batchSize); } @Override diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java index 733ce7d..ea089df 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java @@ -188,9 +188,7 @@ public void onNext(GenerateResponse generateResponse) { logger.info("Processing token {} at position {} for session {}", token, position, session); - AbstractTensor output = model.forward(token, position, kvBufferCache.getKvBuffer(session), Optional.of((a, b) -> { - return null; - }), Optional.of(t -> { + AbstractTensor output = model.forward(token, position, kvBufferCache.getKvBuffer(session), Optional.of(t -> { CombineRequest.Builder nrb = CombineRequest.newBuilder() .setUuid(generateResponse.getSession()) .setWorkerid(workerIdBytes) diff --git 
a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index bb39f59..6d4db9d 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -231,15 +231,15 @@ public void MixtralRun() throws Exception { @Test public void GemmaRun() throws Exception { - String modelPrefix = "../models/gemma-7b-it"; + String modelPrefix = "../models/Yi-Coder-1.5B-Chat"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) { - GemmaTokenizer tokenizer = new GemmaTokenizer(Paths.get(modelPrefix)); - GemmaConfig c = om.readValue(new File(modelPrefix + "/config.json"), GemmaConfig.class); - GemmaModel model = new GemmaModel(c, weights, tokenizer, DType.F32, DType.BF16, Optional.empty()); - String prompt = "Tell me a joke."; + LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); + LlamaConfig c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class); + LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.BF16, Optional.empty()); + String prompt = "Write a java function that takes a list of integers and returns the sum of all the integers in the list."; PromptContext p = model.promptSupport().get().builder().addUserMessage(prompt).build(); - model.generate(UUID.randomUUID(), p, 0.3f, 256, makeOutHandler()); + model.generate(UUID.randomUUID(), p, 0.3f, c.contextLength, makeOutHandler()); } } diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java b/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java index 62aee3c..aec31f2 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java +++ 
b/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java @@ -452,6 +452,28 @@ public void testBatchDotProduct() { } } + @Test + public void testBatchDotProductWithResultOffset() { + // M == BATCH, N == ROWS, K == SIZE + + FloatBufferTensor c = new FloatBufferTensor(BATCH, ROWS * 2); + FloatBufferTensor c1 = new FloatBufferTensor(BATCH, ROWS * 2); + + FloatBufferTensor a = makeWeights(BATCH, SIZE); // a + FloatBufferTensor b = makeWeights(ROWS, SIZE); // b + + controlOps.batchDotProduct(c, a, b, 0, 0, SIZE, 0, 0, ROWS); + controlOps.batchDotProduct(c, a, b, 0, 0, SIZE, ROWS, 0, ROWS); + + + for (TensorOperations t : opTypes) { + c1.clear(); + t.batchDotProduct(c1, a, b, 0, 0, SIZE, 0, 0, ROWS); + t.batchDotProduct(c1, a, b, 0, 0, SIZE, ROWS, 0, ROWS); + Assert.assertEquals(t.name(), controlOps.sum(c), controlOps.sum(c1), controlOps.sum(c) * 0.01); + } + } + @Test public void testNativeBatchDotProduct() { // M == BATCH, N == ROWS, K == SIZE