From e1264f3888771288ff952184e2f407f1effee446 Mon Sep 17 00:00:00 2001
From: Jake Luciani
Date: Sat, 21 Oct 2023 20:57:18 -0400
Subject: [PATCH] Fix attention bug and other bugs

---
 .../tjake/jlama/model/AbstractModel.java      |  6 ++++
 .../jlama/model/CausalSelfAttention.java      |  7 +++--
 .../github/tjake/jlama/model/MLPBlock.java    |  3 +-
 .../tjake/jlama/model/bert/BertModel.java     |  2 +-
 .../tjake/jlama/model/llama/LlamaModel.java   | 11 +++----
 .../operations/PanamaTensorOperations.java    | 29 ++++++++-----------
 .../github/tjake/jlama/models/TestModels.java |  9 +++---
 7 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
index 260b19e..1bff3ef 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
@@ -34,15 +34,21 @@ public abstract class AbstractModel {
     protected final DType modelDType;
     protected final DType workingDType;
     protected final DType workingQType;
+    protected final Optional<DType> modelQType;
 
     private static final ThreadLocal tmpArray = new ThreadLocal<>();
 
     protected AbstractModel(Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType)
+    {
+        this(c, w, t, workingMemoryDType, workingMemoryQType, Optional.empty());
+    }
+    protected AbstractModel(Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType, Optional<DType> modelQType)
     {
         this.c = c;
         this.weights = w;
         this.tokenizer = t;
         this.modelDType = w.getModelDType();
         this.workingDType = workingMemoryDType;
+        this.modelQType = modelQType;
 
         if (workingMemoryQType != workingMemoryDType) {
             boolean supportsQType;
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java
index 740459b..a9eaadf 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java
@@ -92,7 +92,7 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
 
         queryAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(query, bias));
         keyAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(key, bias));
-        valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(value, bias));
+        valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(val, bias));
 
         // apply RoPE if present (accounting for huggingface permutation)
         // https://github.com/huggingface/transformers/blob/d533465150532b0c5de167b574e59f64c68b1154/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
@@ -122,10 +122,11 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
 
         // with all key-value entries populated, compute attention
         // the softmax is incrementally aggregated using the flash attention technique
-        AbstractTensor k0 = kvMem.slice(0).slice(1);
+        AbstractTensor k0 = kvMem.slice(0).slice(0);
+        AbstractTensor v0 = kvMem.slice(0).slice(1);
 
         // value is initially the first value for all heads
-        value.copyFrom(k0, 0, 0, c.embeddingLength);
+        value.copyFrom(v0, 0, 0, c.embeddingLength);
 
         //POSITION ZERO
         for (int i = 0; i < c.numberOfHeads; i++) {
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/MLPBlock.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/MLPBlock.java
index 77952b3..3ea5697 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/MLPBlock.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/MLPBlock.java
@@ -62,8 +62,9 @@ public AbstractTensor forward(AbstractTensor lnemb) {
             buf.set(w1a, i);
         });
 
-        if (upProjectionWeights != null)
+        if (upProjectionWeights != null) {
             TensorOperationsProvider.get().maccumulate(buf, buf2);
+        }
 
         //matmul the projection and sum into input
         AbstractTensor result = model.makeTensor(model.c.embeddingLength);
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java
index 34f8255..4dab730 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java
@@ -99,7 +99,7 @@ public float[] embed(String input) {
         long[] encoded = tokenizer.encode(input);
         Preconditions.checkArgument(encoded.length < c.contextLength);
 
-        AbstractTensor kvmem = makeTensor(c.numberOfLayers, encoded.length, c.embeddingLength * 2); //k and v are concatenated
+        AbstractTensor kvmem = makeTensor(c.numberOfLayers, encoded.length, 2, c.embeddingLength); // 2 for key and value
 
         int promptLength = encoded.length;
         float avgp = 1.0f/promptLength;
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java
index 309494b..4a71e95 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java
@@ -29,13 +29,14 @@ public class LlamaModel extends AbstractModel {
 
     private final AbstractTensor classificationWeights;
 
-    public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType) {
-        super(config, weights, tokenizer, workingDType, workingQType);
+    public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType, DType modelQType) {
+        super(config, weights, tokenizer, workingDType, workingQType, Optional.ofNullable(modelQType));
 
-        DType qType = DType.Q4;
-
-        logger.info("Quantizing model with {} - Please hold...", qType);
+        DType qType = modelQType != null ? modelQType : this.modelDType;
+        if (modelQType != this.modelDType) {
+            logger.info("Quantizing model with {} - Please hold...", qType);
+        }
 
         this.wte = weights.load("model.embed_tokens.weight").quantize(workingDType); //Don't quantize this, it's used for the embedding layer
         this.outputLayerNorm = new RMSNorm(this, weights.load("model.norm.weight").quantize(qType));
 
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
index ede2a5a..4cb94b0 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
@@ -45,12 +45,12 @@ public String name() {
 
     @Override
     public boolean requiresOffHeapTensor() {
-        return false;
+        return true;
     }
 
    @Override
    public float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int boffset, int limit) {
-        Preconditions.checkArgument(limit % 32 == 0);
+        Preconditions.checkArgument(limit % 2 == 0, "Limit must be a multiple of 2, not " + limit);
 
        return switch (a.dType()) {
            case F32 -> switch (b.dType()) {
@@ -671,11 +671,6 @@ private float dotProductF32Q4_256(FloatBufferTensor a, Q4ByteBufferTensor b, int
        return acc.reduceLanes(VectorOperators.ADD);
    }
 
-    private FloatVector helpF32Q4(FloatVector acc, float scalef, FloatBufferTensor a, Q4ByteBufferTensor b, int aoffset, int boffset) {
-
-        return acc;
-    }
-
    private float dotProductF32Q4_512(FloatBufferTensor a, Q4ByteBufferTensor b, int aoffset, int boffset, int limit) {
        Preconditions.checkArgument(
                boffset % Q4ByteBufferTensor.BLOCK_SIZE == 0 &&
@@ -1062,10 +1057,10 @@ void maccumulateF32(FloatBufferTensor a, FloatBufferTensor b) {
            FloatVector vb = b.getVector(FloatVector.SPECIES_PREFERRED, i);
            a.intoTensor(va.mul(vb), i);
        }
-        
+
        // tail
        for (; i < a.size(); i++) {
-            a.set(a.get(i) * b.get(i));
+            a.set(a.get(i) * b.get(i), i);
        }
    }
 
@@ -1103,7 +1098,7 @@ void accumulateF32(FloatBufferTensor a, FloatBufferTensor b) {
 
        // tail
        for (; i < a.size(); i++) {
-            a.set(a.get(i) + b.get(i));
+            a.set(a.get(i) + b.get(i), i);
        }
    }
 
@@ -1135,7 +1130,7 @@ void accumulateBF16_256(BFloat16BufferTensor a, BFloat16BufferTensor b) {
 
        // tail
        for (; i < a.size(); i++) {
-            a.set(a.get(i) + b.get(i));
+            a.set(a.get(i) + b.get(i), i);
        }
    }
 
@@ -1167,7 +1162,7 @@ void accumulateBF16_512(BFloat16BufferTensor a, BFloat16BufferTensor b) {
 
        // tail
        for (; i < a.size(); i++) {
-            a.set(a.get(i) + b.get(i));
+            a.set(a.get(i) + b.get(i), i);
        }
    }
 
@@ -1193,7 +1188,7 @@ public void scale(float factor, AbstractTensor a, int offset, int length)
 
    public void scaleF32(float factor, FloatBufferTensor a, int offset, int length)
    {
-        int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(offset + length);
+        int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(length) + offset;
        int i = offset;
 
        FloatVector sf = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, factor);
@@ -1210,7 +1205,7 @@ public void scaleF32(float factor, FloatBufferTensor a, int offset, int length)
 
    public void scaleBF16_512(float factor, BFloat16BufferTensor a, int offset, int length)
    {
-        int upperBound = FloatVector.SPECIES_512.loopBound(offset + length);
+        int upperBound = FloatVector.SPECIES_512.loopBound(length) + offset;
        int i = offset;
 
        FloatVector sf = FloatVector.broadcast(FloatVector.SPECIES_512, factor);
@@ -1235,7 +1230,7 @@ public void scaleBF16_512(float factor, BFloat16BufferTensor a, int offset, int
 
    public void scaleBF16_256(float factor, BFloat16BufferTensor a, int offset, int length)
    {
-        int upperBound = FloatVector.SPECIES_256.loopBound(offset + length);
+        int upperBound = FloatVector.SPECIES_256.loopBound(length) + offset;
        int i = offset;
 
        FloatVector sf = FloatVector.broadcast(FloatVector.SPECIES_256, factor);
@@ -1261,7 +1256,7 @@ public void scaleBF16_256(float factor, BFloat16BufferTensor a, int offset, int
    @Override
    public void saxpy(float alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
        Preconditions.checkArgument(x.dType() == y.dType());
-        Preconditions.checkArgument(limit % 8 == 0);
+        Preconditions.checkArgument(limit % 2 == 0);
 
        switch (x.dType()) {
            case F32: saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit); break;
@@ -1370,7 +1365,7 @@ void saxpyBF16_512(float alpha, BFloat16BufferTensor a, BFloat16BufferTensor b,
    @Override
    public void sxpby(float beta, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
        Preconditions.checkArgument(x.dType() == y.dType());
-        Preconditions.checkArgument(limit % 8 == 0);
+        Preconditions.checkArgument(limit % 2 == 0);
 
        switch (x.dType()) {
            case F32: sxpbyF32(beta, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit); break;
diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java
index 97dcabc..993589a 100644
--- a/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java
+++ b/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java
@@ -56,7 +56,7 @@ public void GPT2Run() throws IOException {
 
            String prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, " +
                    "previously unexplored valley, in the Andes Mountains. " +
                    "Even more surprising to the researchers was the fact that the unicorns spoke perfect English.";
-            gpt2.generate(prompt, 0.6f, 256, false, makeOutHandler());
+            gpt2.generate(prompt, 0.8f, 256, false, makeOutHandler());
        }
    }
@@ -67,8 +67,7 @@ public void LlamaRun() throws Exception {
        try (SafeTensorIndex weights = SafeTensorIndex.loadWithWeights(Path.of(modelPrefix))) {
            LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
            Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
-            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8);
-
+            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, DType.Q4);
            String prompt = "Simply put, the theory of relativity states that";
            model.generate(prompt, 0.7f, 256, false, makeOutHandler());
        }
@@ -85,10 +84,10 @@ public void TinyLlamaRun() throws Exception {
 
        Weights weights = SafeTensorSupport.readWeights(bb);
        LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
        Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
-        LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32);
+        LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, DType.F32);
 
        String prompt = "Lily picked up a flower and gave it to";
-        model.generate(prompt, 0.9f, 128, false, makeOutHandler());
+        model.generate(prompt, 0.7f, 128, false, makeOutHandler());
    }
 }
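
Note on the kv-cache layout this fix depends on: after the BertModel change the cache is shaped [layer][position][2][embeddingLength], with sub-slice 0 of a position holding the key vector and sub-slice 1 holding the value vector. The CausalSelfAttention hunk then reads k0 from sub-slice 0 and v0 from sub-slice 1, where previously k0 was taken from sub-slice 1 (the value slot) and the running value was initialized from it. A minimal sketch of that indexing convention, using plain float arrays as a stand-in for jlama's AbstractTensor (the class and helper names below are illustrative, not jlama API):

    // Hypothetical illustration of the kv-cache layout established by this patch:
    // cache[layer][position][0] holds the key vector, cache[layer][position][1] the value vector.
    public class KvCacheLayoutSketch {
        static float[][][][] makeKvCache(int layers, int positions, int embeddingLength) {
            return new float[layers][positions][2][embeddingLength];
        }

        static float[] keyAt(float[][][][] cache, int layer, int position) {
            return cache[layer][position][0]; // slot 0: key
        }

        static float[] valueAt(float[][][][] cache, int layer, int position) {
            return cache[layer][position][1]; // slot 1: value
        }

        public static void main(String[] args) {
            float[][][][] kv = makeKvCache(2, 4, 8);
            keyAt(kv, 0, 0)[0] = 1.0f;   // store part of a key at layer 0, position 0
            valueAt(kv, 0, 0)[0] = 2.0f; // store part of the value at the same position
            // The pre-fix code read the "first key" from slot 1 (the value slot),
            // which is the mix-up the CausalSelfAttention hunk corrects.
            System.out.println(keyAt(kv, 0, 0)[0] + " " + valueAt(kv, 0, 0)[0]); // prints 1.0 2.0
        }
    }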