From e7dc2a5578e9a65225092413056b52cbb5dbd647 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Sun, 28 Jul 2024 23:55:43 -0400 Subject: [PATCH] Support Mistral Nemo, Llama 3.1, better bf16 support --- .../tjake/jlama/math/FloatConversions.java | 2 - .../github/tjake/jlama/math/VectorMath.java | 4 +- .../tjake/jlama/model/AbstractModel.java | 16 ++++++-- .../jlama/model/CausalSelfAttention.java | 41 ++++++++++++++++--- .../tjake/jlama/model/TransformerBlock.java | 27 +++++++++--- .../tjake/jlama/model/bert/BertModel.java | 2 +- .../tjake/jlama/model/gemma/GemmaModel.java | 2 +- .../tjake/jlama/model/gpt2/GPT2Model.java | 2 +- .../tjake/jlama/model/llama/LlamaConfig.java | 8 ++-- .../tjake/jlama/model/llama/LlamaModel.java | 2 +- .../jlama/model/mistral/MistralConfig.java | 6 ++- .../jlama/model/mixtral/MixtralModel.java | 2 +- .../tjake/jlama/safetensors/Config.java | 37 ++++++++++++++++- .../safetensors/tokenizer/PromptSupport.java | 2 +- .../tjake/jlama/tensor/AbstractTensor.java | 4 +- .../jlama/tensor/BFloat16BufferTensor.java | 28 +++++++++++-- .../tjake/jlama/tensor/FloatBufferTensor.java | 17 ++++++-- .../operations/PanamaTensorOperations.java | 12 ++++-- .../github/tjake/jlama/util/DebugSupport.java | 26 ++++++++++++ .../tjake/jlama/model/TestCorrectness.java | 19 +++++++++ .../github/tjake/jlama/model/TestModels.java | 41 +++++++++++-------- .../src/test/resources/logback-test.xml | 21 ++++++++++ 22 files changed, 261 insertions(+), 60 deletions(-) create mode 100644 jlama-core/src/main/java/com/github/tjake/jlama/util/DebugSupport.java create mode 100644 jlama-tests/src/test/resources/logback-test.xml diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/math/FloatConversions.java b/jlama-core/src/main/java/com/github/tjake/jlama/math/FloatConversions.java index aaade07..5f7b66b 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/math/FloatConversions.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/math/FloatConversions.java @@ -33,8 +33,6 @@ public static float bFloat16ToFloat32(short raw) { } public static short float32ToBFloat16(float n) { - // if (true) - // return (short) ((Float.floatToRawIntBits(n) >> 16) & 0xffff); int nbits = Float.floatToRawIntBits(n); // 32 bits has 1 sign bit, 8 exponent bits, 23 mantissa bits int s = (nbits >>> 16) & 0x8000; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java index 1df41a2..d4c82d6 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java @@ -32,7 +32,9 @@ public class VectorMath { public static void pfor(int start, int end, IntConsumer action) { PhysicalCoreExecutor.instance .get() - .execute(() -> IntStream.range(start, end).parallel().forEach(action)); + .execute(() -> IntStream.range(start, end) + .parallel() + .forEach(action)); } public static void pchunk(int offset, int length, BiIntConsumer action) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 75f7bcc..a381c09 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -29,6 +29,7 @@ import com.github.tjake.jlama.tensor.Q8ByteBufferTensor; import com.github.tjake.jlama.tensor.TensorShape; import 
com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; +import com.github.tjake.jlama.util.DebugSupport; import com.github.tjake.jlama.util.Pair; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; @@ -44,11 +45,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static com.github.tjake.jlama.util.DebugSupport.debug; + public abstract class AbstractModel implements Generator { private static final Logger logger = LoggerFactory.getLogger(AbstractModel.class); - static final boolean DEBUG = false; - public enum InferenceType { INPUT_TO_EMBEDDING(true, false, false), OUTPUT_TO_TOKEN(false, true, false), @@ -113,6 +114,11 @@ protected AbstractModel( workingMemoryQType = DType.F32; } + // FIXME: This is a hack to avoid Q8BF16 evals when the model itself is BF16 + if (modelDType == DType.BF16 && workingMemoryQType != DType.BF16 && modelQType.isEmpty()) { + workingMemoryQType = DType.BF16; + } + if (workingMemoryQType != workingMemoryDType) { boolean supportsQType; AbstractTensor tmp = makeTensor(Q8ByteBufferTensor.BLOCK_SIZE); @@ -203,6 +209,9 @@ public AbstractTensor forward( Optional>> tensorReducer) { AbstractTensor embedding = embedInput.inputTokenToEmbedding(token_id, pos); + debug("EMBEDDING TOKEN", token_id); + debug("TOKEN POSITION", pos); + for (int i = c.layerStart(); i < c.layerEnd(); i++) { AbstractTensor kvlayer = kvbuf.slice(true, i); AbstractTensor ref = embedding; // reference so we can free @@ -217,7 +226,6 @@ protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, Abstrac AbstractTensor last = null; for (int i = 0; i < token_ids.length; i++) { if (last != null) last.close(); - last = forward(token_ids[i], startPos + i, kvbuf); } @@ -329,7 +337,7 @@ public Response generate( long start = System.currentTimeMillis(); long promptStart = start; // Batch Process Prompt - AbstractTensor last = DEBUG + AbstractTensor last = DebugSupport.isDebug() ?
batchForwardSlow(promptTokens, startPos, kvmem) : batchForward(promptTokens, startPos, kvmem); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java index 6faa0da..a5a966f 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java @@ -17,12 +17,18 @@ import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.safetensors.Config; +import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.tensor.AbstractTensor; +import com.github.tjake.jlama.tensor.BFloat16BufferTensor; +import com.github.tjake.jlama.tensor.FloatBufferTensor; +import com.github.tjake.jlama.tensor.operations.TensorOperations; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.google.common.base.Preconditions; import java.util.*; import java.util.function.Consumer; +import static com.github.tjake.jlama.util.DebugSupport.debug; + public class CausalSelfAttention { private final AbstractModel m; private final Config c; @@ -40,6 +46,8 @@ public class CausalSelfAttention { private final AbstractTensor outputProjectionWeights; private final float attentionScale; + private final int attentionLength; + private final boolean attentionQVSizeMismatch; private final AbstractTensor[] qkvResults; private final AbstractTensor[] qkvWeights; @@ -105,7 +113,9 @@ public CausalSelfAttention( this.outputProjectionBias = outputProjectionBias; this.outputProjectionWeights = outputProjectionWeights; + this.attentionLength = c.numberOfHeads * c.headSize; + this.attentionQVSizeMismatch = c.embeddingLength != attentionLength; this.attentionScale = (float) (1.0 / StrictMath.sqrt(c.headSize)); this.qkvResults = new AbstractTensor[3]; @@ -120,13 +130,13 @@ public AbstractTensor forward( Preconditions.checkArgument(input.dims() == 2 && input.shape().last() == c.embeddingLength); int batchSize = input.shape().first(); - try (AbstractTensor queryBatch = m.makeFullTensor(batchSize, c.embeddingLength); + try (AbstractTensor queryBatch = m.makeFullTensor(batchSize, attentionLength); AbstractTensor tmpKeyBatch = m.makeFullTensor(batchSize, c.kvLength); AbstractTensor tmpValBatch = m.makeFullTensor(batchSize, c.kvLength); - AbstractTensor valueBatch = m.makeFullTensor(batchSize, c.embeddingLength)) { + AbstractTensor valueBatch = m.makeFullTensor(batchSize, attentionLength)) { if (c.isGQA) { - VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> { + VectorMath.pchunk(0, attentionLength, (chunkStart, chunkLength) -> { TensorOperationsProvider.get() .dotProductChunk( queryBatch, @@ -186,6 +196,10 @@ public AbstractTensor forward( valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get() .accumulate(tmpValBatch, bias, c.kvSegmentStart(), c.kvSegmentLength())); + debug("query", queryBatch, 0); + debug("key", tmpKeyBatch, 0); + debug("value", tmpValBatch, 0); + // This is our memory of the key and value vectors for each position for (int position = startPosition, bi = 0; position < startPosition + batchSize; position++, bi++) { int finalPostion = position; @@ -222,12 +236,12 @@ public AbstractTensor forward( tmpKey, tmpKey.getOffset(0, c.kvSegmentStart()), key.getOffset(0, c.kvSegmentStart()), - c.kvSegmentLength()); + c.kvLength); val.copyFrom( tmpVal, tmpVal.getOffset(0, c.kvSegmentStart()), val.getOffset(0, c.kvSegmentStart()), - 
c.kvSegmentLength()); + c.kvLength); } // apply RoPE if present (accounting for huggingface permutation) @@ -241,6 +255,11 @@ public AbstractTensor forward( for (int h = c.headStart(); h < c.headEnd(); h++) { // get the q vectors for this head int offset = h * c.headSize; + + // skip if we are out of bounds + if (offset >= query.shape().last()) + break; + int goffset = c.maybeMapToGroupHead(h) * c.headSize; // rotate q by the freq theta and freq r for (int i = offset, g = goffset; i < (offset + headPiece); i++, g++) { @@ -257,6 +276,8 @@ public AbstractTensor forward( for (int h = c.groupHeadStart(); h < c.groupHeadEnd(); h++) { // get the k vectors for this head int offset = h * c.headSize; + if (offset >= key.shape().last()) + break; // rotate k by the freq theta and freq r for (int i = offset; i < (offset + headPiece); i++) { float k00 = key.get(0, i); @@ -289,14 +310,20 @@ public AbstractTensor forward( } } } + debug("query+rope", query, finalPostion); + debug("key+rope", key, finalPostion); }); + // Attention VectorMath.pfor(c.headStart(), c.headEnd(), h -> { try (AbstractTensor attn = m.makeFullTensor(1, kvp.shape().first())) { int xoffset = c.maybeMapToGroupHead(h) * c.headSize; int yoffset = h * c.headSize; + if (yoffset >= query.shape().last()) + return; + // compute attention scores by multiplying query and key for every position TensorOperationsProvider.get() .batchDotProduct(attn, query, kvp, yoffset, xoffset, c.headSize, 0, finalPostion + 1); @@ -312,6 +339,8 @@ public AbstractTensor forward( }); } + debug("after_attention", valueBatch, 0); + // matmul the projection and sum into input // input += c_proj_weight @ ybuf + c_proj_bias AbstractTensor result = m.makeFullTensor(batchSize, c.embeddingLength); @@ -323,7 +352,7 @@ public AbstractTensor forward( vq, outputProjectionWeights, c.embeddingSegmentStart(), - c.embeddingSegmentLength(), + attentionQVSizeMismatch ? 
attentionLength : c.embeddingSegmentLength(), chunkStart, chunkSize); }); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java index ec84c4c..9ca040d 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java @@ -18,6 +18,7 @@ import com.github.tjake.jlama.model.functions.FeedForward; import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; +import com.github.tjake.jlama.util.DebugSupport; import com.github.tjake.jlama.util.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -27,11 +28,14 @@ import java.util.function.BiFunction; import java.util.function.Consumer; +import static com.github.tjake.jlama.util.DebugSupport.debug; + public class TransformerBlock { private static final Logger logger = LoggerFactory.getLogger(TransformerBlock.class); private final AbstractModel model; + final int layerIndex; final Optional preAttentionNorm; final CausalSelfAttention attention; final LayerNorm postAttentionNorm; @@ -40,11 +44,13 @@ public class TransformerBlock { public TransformerBlock( AbstractModel model, + int layerIndex, LayerNorm preAttentionNorm, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock) { this.model = model; + this.layerIndex = layerIndex; this.preAttentionNorm = Optional.of(preAttentionNorm); this.attention = attention; @@ -56,11 +62,13 @@ public TransformerBlock( public TransformerBlock( AbstractModel model, + int layerIndex, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock, LayerNorm postFFNorm) { this.model = model; + this.layerIndex = layerIndex; this.preAttentionNorm = Optional.empty(); this.attention = attention; @@ -81,38 +89,44 @@ public AbstractTensor forward( Optional>> normReducer, Optional>> tensorReducer) { - if (AbstractModel.DEBUG) - logger.debug("embedding: {}" + embedding); + debug("input_emb", embedding, layerIndex); AbstractTensor lnemb = preAttentionNorm.map(ln -> ln.forward(embedding, normReducer)).orElse(embedding); - if (AbstractModel.DEBUG) - logger.debug("lnemb: {}" + lnemb); + + debug("ln_emb", lnemb, layerIndex); AbstractTensor postAttention; try (AbstractTensor qlnemb = model.maybeQuantize(lnemb)) { postAttention = attention.forward(qlnemb, position, kvBuffer, tensorReducer); } - if (AbstractModel.DEBUG) - logger.debug("postAttention: {}" + postAttention); + debug("post_attn", postAttention, layerIndex); // residual connection TensorOperationsProvider.get() .accumulate( postAttention, embedding, model.c.embeddingSegmentStart(), model.c.embeddingSegmentLength()); + debug("post_attn_res", postAttention, layerIndex); + AbstractTensor lnemb2 = postAttentionNorm.forward(postAttention, normReducer); + + debug("ln_emb2", lnemb2, layerIndex); + AbstractTensor postFF; try (AbstractTensor qlnemb2 = model.maybeQuantize(lnemb2)) { postFF = ffBlock.forward(qlnemb2, tensorReducer); + debug("post_ff", postFF, layerIndex); } // residual connection TensorOperationsProvider.get() .accumulate(postFF, postAttention, model.c.embeddingSegmentStart(), model.c.embeddingSegmentLength()); + debug("post_ff_res", postFF, layerIndex); + // Release any tmp buffers if (lnemb != embedding) lnemb.close(); @@ -122,6 +136,7 @@ public AbstractTensor forward( return postFFNorm .map(ln -> { AbstractTensor lnout = 
ln.forward(postFF, normReducer); + debug("ln_out", lnout, layerIndex); postFF.close(); return lnout; }) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java index d8cbc7f..9e2ef89 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java @@ -114,7 +114,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { LayerNorm postMlpNorm = new LayerNorm( this, weights.load(b + "output.LayerNorm.bias"), weights.load(b + "output.LayerNorm.weight")); - transformerBlocks[i] = new TransformerBlock(this, attention, postAttentionNorm, mlpBlock, postMlpNorm); + transformerBlocks[i] = new TransformerBlock(this, i, attention, postAttentionNorm, mlpBlock, postMlpNorm); } return transformerBlocks; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java index c370421..66dca35 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java @@ -97,7 +97,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { weights.load(prefix + "up_proj.weight", c.offset()).quantize(qType)); // w3 transformerBlocks[i] = new TransformerBlock( - this, + this, i, new RMSNorm( this, weights.load(base + "input_layernorm.weight", c.offset()) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java index ce680a9..d1f844e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java @@ -97,7 +97,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { LayerNorm layerNorm1 = new LayerNorm(this, weights.load(b + "ln_1.bias"), weights.load(b + "ln_1.weight")); LayerNorm layerNorm2 = new LayerNorm(this, weights.load(b + "ln_2.bias"), weights.load(b + "ln_2.weight")); - transformerBlocks[i] = new TransformerBlock(this, layerNorm1, attention, layerNorm2, mlpBlock); + transformerBlocks[i] = new TransformerBlock(this, i, layerNorm1, attention, layerNorm2, mlpBlock); } return transformerBlocks; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java index c11ce6e..a5a191e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java @@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; + +import java.util.List; import java.util.Map; public class LlamaConfig extends Config { @@ -34,7 +36,7 @@ public LlamaConfig( @JsonProperty("rms_norm_eps") float layerNormEps, @JsonProperty("vocab_size") int vocabularySize, @JsonProperty("bos_token_id") int bosToken, - @JsonProperty("eos_token_id") int eosToken, + @JsonProperty("eos_token_id") Object eosToken, @JsonProperty("hidden_act") ActivationFunction.Type activationFunction, @JsonProperty("rope_theta") Double ropeFreqsTheta, @JsonProperty("rope_scaling") Map ropeScaling) { @@ -48,9 +50,9 @@ public 
LlamaConfig( layerNormEps, vocabularySize, bosToken, - eosToken, + eosToken instanceof List ? ((List)eosToken).get(((List)eosToken).size() - 1) : (Integer) eosToken, //for llama3.1 activationFunction, ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta, - ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor"))); + ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor"))); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java index abd2e78..d05edfa 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java @@ -110,7 +110,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { weights.load(prefix + "up_proj.weight", c.offset()).quantize(qType)); // w3 transformerBlocks[i] = new TransformerBlock( - this, + this, i, new RMSNorm( this, weights.load(base + "input_layernorm.weight", c.offset()) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java index bee3973..3dc80f2 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java @@ -34,7 +34,8 @@ public MistralConfig( @JsonProperty("bos_token_id") int bosToken, @JsonProperty("eos_token_id") int eosToken, @JsonProperty("hidden_act") ActivationFunction.Type activationFunction, - @JsonProperty("rope_theta") Double ropeTheta) { + @JsonProperty("rope_theta") Double ropeTheta, + @JsonProperty("head_dim") Integer headSize) { super( contextLength, embeddingLength, @@ -48,6 +49,7 @@ public MistralConfig( eosToken, activationFunction, ropeTheta, - 1.0); + 1.0, + headSize == null ? 
embeddingLength / numberOfHeads : headSize); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java index 9842901..3b2a060 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java @@ -98,7 +98,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { expertUpWeights); // w3 transformerBlocks[i] = new TransformerBlock( - this, + this, i, new RMSNorm( this, weights.load(base + "input_layernorm.weight", c.offset()) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java index b414fa0..db9833b 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java @@ -58,6 +58,7 @@ public class Config { public final TensorCache tensorCache; + public Config( int contextLength, int embeddingLength, @@ -72,6 +73,38 @@ public Config( ActivationFunction.Type activationFunction, Double ropeFreqsTheta, Double ropeScalingFactor) { + this( + contextLength, + embeddingLength, + hiddenLength, + numberOfHeads, + numberOfKeyValueHeads, + numberOfLayers, + layerNormEps, + vocabularySize, + bosToken, + eosToken, + activationFunction, + ropeFreqsTheta, + ropeScalingFactor, + embeddingLength / numberOfHeads); + } + + public Config( + int contextLength, + int embeddingLength, + int hiddenLength, + int numberOfHeads, + int numberOfKeyValueHeads, + int numberOfLayers, + float layerNormEps, + int vocabularySize, + int bosToken, + int eosToken, + ActivationFunction.Type activationFunction, + Double ropeFreqsTheta, + Double ropeScalingFactor, + Integer headSize) { this.contextLength = contextLength; this.embeddingLength = embeddingLength; this.hiddenLength = hiddenLength; @@ -83,7 +116,7 @@ public Config( this.bosToken = bosToken; this.eosToken = eosToken; this.tensorCache = TensorCache.instance; - this.headSize = embeddingLength / numberOfHeads; + this.headSize = headSize; this.headGroupSize = numberOfHeads / numberOfKeyValueHeads; this.kvLength = numberOfKeyValueHeads * headSize; this.isGQA = numberOfKeyValueHeads < numberOfHeads; @@ -91,7 +124,7 @@ public Config( this.ropeFreqs = ropeFreqsTheta == null ? Optional.empty() : Optional.of(VectorMath.precomputeFreqsCis( - embeddingLength / numberOfHeads, + headSize, contextLength, ropeFreqsTheta, ropeScalingFactor == null ? 
1.0 : ropeScalingFactor)); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java index 6e9765e..8bece47 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java @@ -151,7 +151,7 @@ public String build() { "eos_token", m.eosToken(), "bos_token", - m.bosToken())); + "")); // We add the BOS ourselves } } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java index 02b51e6..e322483 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java @@ -319,10 +319,10 @@ public TensorInfo save(FileChannel out) throws IOException { } public void debug(String id) { - if (false) { + if (true) { double tmp = 0.0; for (int i = 0; i < size(); i++) { - tmp += get(i); + tmp += get(0, i); } System.out.println(String.format("%s = %.5f", id, tmp)); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java index ba9623a..426709d 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java @@ -149,13 +149,35 @@ public void clear() { @Override public String toString() { - float[] sample = new float[Math.min(100, b.remaining())]; + float[] sample = new float[Math.min(10, b.remaining())]; for (int i = 0; i < sample.length; i++) { sample[i] = FloatConversions.bFloat16ToFloat32(b.get(i)); } + + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < sample.length; i++) { + sb.append(String.format("%8.4f", sample[i])); + if (i < sample.length - 1) { + sb.append(", "); + } + } + + for (int i = 0; i < sample.length; i++) { + sample[i] = FloatConversions.bFloat16ToFloat32(b.get(i + (shape.first()/2))); + } + + StringBuffer sb2 = new StringBuffer(); + for (int i = 0; i < sample.length; i++) { + sb2.append(String.format("%8.4f", sample[i])); + if (i < sample.length - 1) { + sb2.append(", "); + } + } + + return "BFloat16BufferTensor{" + "name='" + name + '\'' + ", shape=" - + shape + ", b=" - + Arrays.toString(sample) + "...}"; + + shape + ",\n b=" + + sb + "..." 
+ sb2 + "}"; } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java index 6d3f4f4..912a82b 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java @@ -17,6 +17,7 @@ import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; +import com.github.tjake.jlama.util.DebugSupport; import com.github.tjake.jlama.util.UnsafeDirectByteBuffer; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; @@ -24,6 +25,8 @@ import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.util.Arrays; +import java.util.List; + import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorSpecies; import org.slf4j.Logger; @@ -182,11 +185,19 @@ public void clear() { @Override public String toString() { - float[] sample = new float[Math.min(10, b.remaining())]; + float[] sample = new float[DebugSupport.isDebug() ? b.remaining() : Math.min(10, b.remaining())]; b.duplicate().get(sample); + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < sample.length; i++) { + sb.append(String.format("%8.4f", sample[i])); + if (i < sample.length - 1) { + sb.append(", "); + } + } + return "FloatBufferTensor{" + "name='" + name + '\'' + " shape=" - + shape + ", b=" - + Arrays.toString(sample) + "...}"; + + shape + ",\nb={" + + sb + "...}"; } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java index b61c01a..5725359 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java @@ -100,7 +100,8 @@ public void batchDotProduct( Preconditions.checkArgument(a.dims() == 2 && b.dims() == 2 && result.dims() == 2); Preconditions.checkArgument(a.shape().dim(0) == result.shape().dim(0), "BAD M"); // Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N"); - // Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K"); + // This check breaks for GQA + // Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K" + a.shape() + " " + b.shape() + " " + columnLength); int M = a.shape().dim(0); int N = rowChunkSize; // b.shape().dim(0); @@ -1721,11 +1722,11 @@ protected BiIntConsumer initMatmul1x1() { FloatVector.SPECIES_PREFERRED, i, aoffset + FloatVector.SPECIES_PREFERRED.length()); ShortVector sb = b.getVector(ShortVector.SPECIES_PREFERRED, j, boffset); - FloatVector vb0 = sb.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0) + FloatVector vb0 = sb.convertShape(VectorOperators.ZERO_EXTEND_S2I, IntVector.SPECIES_PREFERRED, 0) .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT) .reinterpretAsFloats(); - FloatVector vb1 = sb.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1) + FloatVector vb1 = sb.convertShape(VectorOperators.ZERO_EXTEND_S2I, IntVector.SPECIES_PREFERRED, 1) .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT) .reinterpretAsFloats(); @@ -1821,6 +1822,11 @@ public AbstractTensor quantize(AbstractTensor t, DType qtype, int offset, int le public BFloat16BufferTensor quantizeBF16(FloatBufferTensor ft, final int 
offset, int length) { + // Need this until we have a proper quantization routine with rounding; see: + // https://github.com/pytorch/pytorch/blob/7c1fbc7fe9cb8ddd5c913b4b3a9e94d00cb055ee/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h#L47 + if (true) + return new BFloat16BufferTensor(ft); + // Up to caller to release BFloat16BufferTensor qft = (BFloat16BufferTensor) TensorCache.instance.get(DType.BF16, ft.shape()); int batchSize = ft.shape().first(); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/DebugSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/DebugSupport.java new file mode 100644 index 0000000..5653301 --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/DebugSupport.java @@ -0,0 +1,26 @@ +package com.github.tjake.jlama.util; + +import com.github.tjake.jlama.tensor.AbstractTensor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class DebugSupport { + private static final boolean DEBUG = false; + private static final Logger logger = LoggerFactory.getLogger(DebugSupport.class); + + public static boolean isDebug() { + return DEBUG; + } + + public static void debug(String name, AbstractTensor t, int layer) { + if (DEBUG) { + logger.debug("Layer: {} - {} - {}", layer, name, t); + } + } + + public static void debug(String msg, Object t) { + if (DEBUG) { + logger.debug("{} - {}", msg, t); + } + } +} diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java index 6708ff2..d412862 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java @@ -225,6 +225,25 @@ public void testGptTokenizer() throws IOException { Assert.assertEquals(p, d); } + @Test + public void testNemoTokenizer() throws IOException { + String modelPrefix = "../models/Mistral-Nemo-Instruct-2407"; + Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); + + Tokenizer tokenizer = new GPT2Tokenizer(Paths.get(modelPrefix)); + + String p = "Hello!"; + + long[] actual = tokenizer.encode(p); + long[] expected = new long[] {22177, 1033}; + + String d = tokenizer.decode(actual); + System.out.println(d); + + Assert.assertArrayEquals(expected, actual); + Assert.assertEquals(p, d); + } + @Test public void testNeoTokenizer() throws IOException { String modelPrefix = "../models/deepseek-coder-1.3b-base"; diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index cd4e780..11cf267 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -30,9 +30,13 @@ import com.github.tjake.jlama.model.llama.LlamaConfig; import com.github.tjake.jlama.model.llama.LlamaModel; import com.github.tjake.jlama.model.llama.LlamaTokenizer; +import com.github.tjake.jlama.model.mistral.MistralConfig; +import com.github.tjake.jlama.model.mistral.MistralModel; import com.github.tjake.jlama.model.mixtral.MixtralConfig; import com.github.tjake.jlama.model.mixtral.MixtralModel; import com.github.tjake.jlama.safetensors.*; +import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer; +import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.tensor.AbstractTensor; import
com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; @@ -66,7 +70,7 @@ public class TestModels { static { System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0"); - // System.setProperty("jlama.force_panama_tensor_operations", "true"); + //System.setProperty("jlama.force_panama_tensor_operations", "true"); } private static final Logger logger = LoggerFactory.getLogger(TestModels.class); @@ -94,19 +98,16 @@ public void GPT2Run() throws IOException { @Test public void LlamaRun() throws Exception { - String modelPrefix = "../models/Llama-2-7b-chat-hf-jlama-Q4"; + String modelPrefix = "../models/Meta-Llama-3.1-8B-Instruct"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) { LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class); - LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.empty()); - String prompt0 = - "Antibiotics are a type of medication used to treat bacterial infections. They work by either killing the bacteria or preventing them from reproducing, " + "allowing the body’s immune system to fight off the infection. Antibiotics are usually taken orally in the form of pills, capsules, or liquid solutions, " + "or sometimes administered intravenously. They are not effective against viral infections, and using them inappropriately can lead to antibiotic resistance. Explain the above in one sentence:"; - String prompt1 = "The theory of relativity states that"; - model.generate(UUID.randomUUID(), prompt0, 0.7f, 256, false, makeOutHandler()); + LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.BF16, Optional.empty()); + + String p = model.promptSupport().get().newBuilder().addUserMessage("Tell me a joke.").build(); + model.generate(UUID.randomUUID(), p, 0.3f, 256, false, makeOutHandler()); } } @@ -126,15 +127,21 @@ public void DeepCoderRun() throws Exception { @Test public void MistralRun() throws Exception { - String modelPrefix = "../models/Mistral-7B-v0.1"; + String modelPrefix = "../models/Mistral-Nemo-Instruct-2407"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) { - LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); - Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class); - LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.empty()); - String prompt = "Simply put, the theory of relativity states that"; - model.generate(UUID.randomUUID(), prompt, 0.7f, 64, false, makeOutHandler()); + BPETokenizer tokenizer = new GPT2Tokenizer(Paths.get(modelPrefix)); + Config c = om.readValue(new File(modelPrefix + "/config.json"), MistralConfig.class); + MistralModel model = new MistralModel(c, weights, tokenizer, DType.F32, DType.BF16, Optional.empty()); + String prompt = "Tell me a joke."; + String p = model.promptSupport().isEmpty() ? prompt : model.promptSupport() .get() .newBuilder() .addUserMessage(prompt) .build(); + model.generate(UUID.randomUUID(), p, 0.0f, 64, false, makeOutHandler());
[/INST]Assistant", 0.0f, 64, false, makeOutHandler()); } } @@ -171,14 +178,14 @@ public void GemmaRun() throws Exception { SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) { GemmaTokenizer tokenizer = new GemmaTokenizer(Paths.get(modelPrefix)); GemmaConfig c = om.readValue(new File(modelPrefix + "/config.json"), GemmaConfig.class); - GemmaModel model = new GemmaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.empty()); + GemmaModel model = new GemmaModel(c, weights, tokenizer, DType.F32, DType.BF16, Optional.empty()); String prompt = "Tell me a joke."; String p = model.promptSupport() .get() .newBuilder() .addUserMessage(prompt) .build(); - model.generate(UUID.randomUUID(), p, 0.7f, 256, false, makeOutHandler()); + model.generate(UUID.randomUUID(), p, 0.3f, 256, false, makeOutHandler()); } } diff --git a/jlama-tests/src/test/resources/logback-test.xml b/jlama-tests/src/test/resources/logback-test.xml new file mode 100644 index 0000000..dc1b553 --- /dev/null +++ b/jlama-tests/src/test/resources/logback-test.xml @@ -0,0 +1,21 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + tensor_debug.log + + %msg%n + + + + + + + + + + + \ No newline at end of file