diff --git a/Dockerfile b/Dockerfile
index dd27912..3251906 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -14,6 +14,9 @@ COPY mvnw .
 RUN --mount=type=cache,target=/root/.m2 ./mvnw clean package
 
 FROM openjdk:21-slim
+RUN apt-get update
+RUN apt-get install -y procps curl
+
 COPY inlinerules.json inlinerules.json
 COPY run-cli.sh run-cli.sh
 COPY conf/logback.xml logback.xml
diff --git a/docker-compose.yaml b/docker-compose.yaml
index a932fbe..e1134e1 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -15,19 +15,25 @@ services:
         cpus: 1
     command:
       - cluster-coordinator
-      - --threads=4
-      - --worker-count=16
+      - --threads=2
+      - --worker-count=8
      - /models/Llama-2-7b-chat-hf-jlama-Q4/
     volumes:
       - "./models:/models"
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8080/ui/index.html"]
+      interval: 5s
   jlama-worker:
     image: jlama
     restart: always
+    depends_on:
+      jlama-coordinator:
+        condition: service_healthy
     environment:
       - JLAMA_JVM_ARGS_EXTRA=-Xmx500M -Djava.net.preferIPv4Stack=true -Djlama.use_hostname_as_workerid=true
     deploy:
       mode: replicated
-      replicas: 16
+      replicas: 8
       resources:
         limits:
           cpus: 1
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
index 43e6598..218f8aa 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
@@ -252,7 +252,7 @@ public void generate(UUID sessionId, String prompt, String cleanPrompt, float te
         if (ntokens > c.contextLength)
             ntokens = c.contextLength;
 
-        AbstractTensor kvmem = makeTensor(c.getNumberOfLayers(), ntokens, 2, c.embeddingLength); //k and v are last 2 dims
+        AbstractTensor kvmem = makeTensor(c.getNumberOfLayers(), ntokens, 2, c.kvLength); //k and v are last 2 dims
         AbstractTensor logits = makeTensor(c.vocabularySize);
 
         int[] promptTokens = new int[useEOS ? (1 + encoded.length + 1) : (1 + encoded.length)];
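The kvmem change above is what shrinks the per-session cache for grouped-query models: the last dimension of the cache tensor drops from embeddingLength to the new kvLength. A rough size check, assuming Mistral-7B-style dimensions (the numbers below are illustrative and mirror the Config fields introduced later in this patch, not values read from any checkpoint here):

    // Sketch only: dimensions are assumed, the derivations mirror the new Config fields.
    public class KvCacheSizeSketch {
        public static void main(String[] args) {
            int contextLength = 2048;
            int embeddingLength = 4096;
            int numberOfHeads = 32;
            int numberOfKeyValueHeads = 8;   // num_key_value_heads in config.json

            int headSize = embeddingLength / numberOfHeads;    // 128, as derived in Config
            int kvLength = numberOfKeyValueHeads * headSize;   // 1024, as derived in Config

            // Old cache shape: [layers, tokens, 2, embeddingLength]; new: [layers, tokens, 2, kvLength]
            long oldFloatsPerLayer = 2L * contextLength * embeddingLength;
            long newFloatsPerLayer = 2L * contextLength * kvLength;
            System.out.println(oldFloatsPerLayer + " -> " + newFloatsPerLayer
                    + " cache entries per layer (4x smaller with 32 query heads sharing 8 KV heads)");
        }
    }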
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java
index 0f6b18c..28d2c36 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java
@@ -2,21 +2,13 @@
 import com.github.tjake.jlama.math.VectorMath;
 import com.github.tjake.jlama.safetensors.Config;
-import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.tensor.AbstractTensor;
-import com.github.tjake.jlama.tensor.operations.TensorOperations;
 import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
-import com.github.tjake.jlama.util.Pair;
 import com.google.common.base.Preconditions;
 
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.BiFunction;
+import java.util.*;
 import java.util.function.Consumer;
-import java.util.function.Function;
 
 public class CausalSelfAttention {
     private final AbstractModel m;
@@ -34,7 +26,6 @@ public class CausalSelfAttention {
 
     private final AbstractTensor outputProjectionWeights;
 
-    private final Optional<float[][]> ropeFrequencies;
     private final float attentionScale;
 
@@ -45,21 +36,21 @@ public class CausalSelfAttention {
 
     public CausalSelfAttention(AbstractModel m, AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights,
-                               AbstractTensor outputProjectionWeights, Optional<float[][]> ropeFrequencies)
+                               AbstractTensor outputProjectionWeights)
     {
-        this(m, Optional.empty(), Optional.empty(), Optional.empty(), queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.empty(), outputProjectionWeights, ropeFrequencies);
+        this(m, Optional.empty(), Optional.empty(), Optional.empty(), queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.empty(), outputProjectionWeights);
     }
 
     public CausalSelfAttention(AbstractModel m, AbstractTensor queryAttnBias, AbstractTensor keyAttnBias, AbstractTensor valueAttnBias,
                                AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights,
-                               AbstractTensor outputProjectionBias, AbstractTensor outputProjectionWeights, Optional<float[][]> ropeFrequencies) {
-        this(m, Optional.of(queryAttnBias), Optional.of(keyAttnBias), Optional.of(valueAttnBias), queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.of(outputProjectionBias), outputProjectionWeights, ropeFrequencies);
+                               AbstractTensor outputProjectionBias, AbstractTensor outputProjectionWeights) {
+        this(m, Optional.of(queryAttnBias), Optional.of(keyAttnBias), Optional.of(valueAttnBias), queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.of(outputProjectionBias), outputProjectionWeights);
     }
 
     public CausalSelfAttention(AbstractModel m, Optional<AbstractTensor> queryAttnBias, Optional<AbstractTensor> keyAttnBias, Optional<AbstractTensor> valueAttnBias,
                                AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights,
-                               Optional<AbstractTensor> outputProjectionBias, AbstractTensor outputProjectionWeights, Optional<float[][]> ropeFrequencies)
+                               Optional<AbstractTensor> outputProjectionBias, AbstractTensor outputProjectionWeights)
     {
         this.m = m;
         this.c = m.c;
@@ -75,7 +66,6 @@ public CausalSelfAttention(AbstractModel m, Optional queryAttnBi
 
         this.attentionScale = (float) (1.0 / StrictMath.sqrt(c.headSize));
 
-        this.ropeFrequencies = ropeFrequencies;
         this.flashAttnHeads = new float[c.contextLength][c.numberOfHeads];
 
         this.qkvResults = new AbstractTensor[3];
@@ -88,8 +78,8 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
         try (AbstractTensor flashAttn_m = m.makeTensor(c.numberOfHeads);
              AbstractTensor flashAttn_l = m.makeTensor(c.numberOfHeads);
              AbstractTensor query = m.makeFullTensor(c.embeddingLength);
-             AbstractTensor tmpKey = m.makeFullTensor(c.embeddingLength);
-             AbstractTensor tmpVal = m.makeFullTensor(c.embeddingLength);
+             AbstractTensor tmpKey = m.makeFullTensor(c.kvLength);
+             AbstractTensor tmpVal = m.makeFullTensor(c.kvLength);
              AbstractTensor value = m.makeFullTensor(c.embeddingLength)) {
 
             //This is our memory of the key and value vectors for each position
@@ -98,63 +88,105 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
             AbstractTensor key = kvp.slice(0);
             AbstractTensor val = kvp.slice(1);
-
-            qkvResults[0] = query;
-            qkvResults[1] = tmpKey;
-            qkvResults[2] = tmpVal;
-
-            // compute the query vector
-            VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> {
-                TensorOperationsProvider.get().dotProductBatchChunk(qkvResults, input, qkvWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
-            });
-
+            if (c.isGQA) {
+                VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> {
+                    TensorOperationsProvider.get().dotProductChunk(query, input, queryAttnWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
+                });
+                VectorMath.pchunk(0, c.kvLength, (chunkStart, chunkLength) -> {
+                    TensorOperationsProvider.get().dotProductChunk(tmpKey, input, keyAttnWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
+                    TensorOperationsProvider.get().dotProductChunk(tmpVal, input, valueAttnWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
+                });
+            } else {
+                qkvResults[0] = query;
+                qkvResults[1] = tmpKey;
+                qkvResults[2] = tmpVal;
+
+                // compute the query vector
+                VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> {
+                    TensorOperationsProvider.get().dotProductBatchChunk(qkvResults, input, qkvWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
+                });
+            }
 
             // For distributed sum of tensor
             tensorReducer.ifPresent(func -> func.accept(List.of(query, tmpKey, tmpVal)));
 
             queryAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(query, bias, c.embeddingSegmentStart(), c.embeddingSegmentLength()));
-            keyAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(tmpKey, bias, c.embeddingSegmentStart(), c.embeddingSegmentLength()));
-            valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(tmpVal, bias, c.embeddingSegmentStart(), c.embeddingSegmentLength()));
+            keyAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(tmpKey, bias, c.kvSegmentStart(), c.kvSegmentLength()));
+            valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(tmpVal, bias, c.kvSegmentStart(), c.kvSegmentLength()));
 
-            key.copyFrom(tmpKey, tmpKey.getOffset(c.embeddingSegmentStart()), key.getOffset(c.embeddingSegmentStart()), c.embeddingSegmentLength());
-            val.copyFrom(tmpVal, tmpVal.getOffset(c.embeddingSegmentStart()), val.getOffset(c.embeddingSegmentStart()), c.embeddingSegmentLength());
+            key.copyFrom(tmpKey, tmpKey.getOffset(c.kvSegmentStart()), key.getOffset(c.kvSegmentStart()), c.kvSegmentLength());
+            val.copyFrom(tmpVal, tmpVal.getOffset(c.kvSegmentStart()), val.getOffset(c.kvSegmentStart()), c.kvSegmentLength());
 
             // apply RoPE if present (accounting for huggingface permutation)
             // https://github.com/huggingface/transformers/blob/d533465150532b0c5de167b574e59f64c68b1154/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
-            ropeFrequencies.ifPresent(rf -> {
+            c.ropeFreqs.ifPresent(rf -> {
                 int headPiece = c.headSize / 2;
                 int poffset = position * headPiece;
-                // apply RoPE rotation to the q and k vectors for each head
-                for (int h = c.headStart(); h < c.headEnd(); h++) {
-                    // get the q and k vectors for this head
-                    int offset = h * c.headSize;
-                    // rotate q and k by the freq theta and freq r
-                    for (int i = offset; i < (offset + headPiece); i++) {
-                        float q0 = query.get(i);
-                        float q1 = query.get(i + headPiece);  //hf permutation is 0,64,1,65 etc...
-                        float k0 = key.get(i);
-                        float k1 = key.get(i + headPiece);
-                        float[] f = rf[poffset + i];
-                        float fcr = f[0];
-                        float fci = f[1];
-                        query.set(q0 * fcr - q1 * fci, i);
-                        query.set(q0 * fci + q1 * fcr, i + headPiece);
-                        key.set(k0 * fcr - k1 * fci, i);
-                        key.set(k0 * fci + k1 * fcr, i + headPiece);
+
+                if (c.isGQA) {
+                    // apply RoPE rotation to the q and k vectors for each head
+                    for (int h = c.headStart(); h < c.headEnd(); h++) {
+                        // get the q vectors for this head
+                        int offset = h * c.headSize;
+                        // rotate q by the freq theta and freq r
+                        for (int i = offset; i < (offset + headPiece); i++) {
+                            float q0 = query.get(i);
+                            float q1 = query.get(i + headPiece);  //hf permutation is 0,64,1,65 etc...
+                            float[] f = rf[poffset + i];
+                            float fcr = f[0];
+                            float fci = f[1];
+                            query.set(q0 * fcr - q1 * fci, i);
+                            query.set(q0 * fci + q1 * fcr, i + headPiece);
+                        }
+                    }
+
+                    //float[][] grf = c.groupRopeFreqs.get();
+                    for (int h = c.groupHeadStart(); h < c.groupHeadEnd(); h++) {
+                        // get the k vectors for this head
+                        int offset = h * c.headSize;
+                        // rotate k by the freq theta and freq r
+                        for (int i = offset; i < (offset + headPiece); i++) {
+                            float k0 = key.get(i);
+                            float k1 = key.get(i + headPiece);  //hf permutation is 0,64,1,65 etc...
+                            float[] f = rf[poffset + i];
+                            float fcr = f[0];
+                            float fci = f[1];
+                            key.set(k0 * fcr - k1 * fci, i);
+                            key.set(k0 * fci + k1 * fcr, i + headPiece);
+                        }
+                    }
+                } else {
+                    // apply RoPE rotation to the q and k vectors for each head
+                    for (int h = c.headStart(); h < c.headEnd(); h++) {
+                        // get the q and k vectors for this head
+                        int offset = h * c.headSize;
+                        // rotate q and k by the freq theta and freq r
+                        for (int i = offset; i < (offset + headPiece); i++) {
+                            float q0 = query.get(i);
+                            float q1 = query.get(i + headPiece);  //hf permutation is 0,64,1,65 etc...
+                            float k0 = key.get(i);
+                            float k1 = key.get(i + headPiece);
+                            float[] f = rf[poffset + i];
+                            float fcr = f[0];
+                            float fci = f[1];
+                            query.set(q0 * fcr - q1 * fci, i);
+                            query.set(q0 * fci + q1 * fcr, i + headPiece);
+                            key.set(k0 * fcr - k1 * fci, i);
+                            key.set(k0 * fci + k1 * fcr, i + headPiece);
+                        }
                     }
                 }
             });
 
             // with all key-value entries populated, compute attention
             // the softmax is incrementally aggregated using the flash attention technique
-            AbstractTensor k0 = kvMem.slice(true, 0).slice(0);
+            AbstractTensor k0 = kvMem.slice(true,0).slice(0);
             AbstractTensor v0 = kvMem.slice(true,0).slice(1);
 
             // value is initially the position 0 value for all heads
-            value.copyFrom(v0, v0.getOffset(c.embeddingSegmentStart()), value.getOffset(c.embeddingSegmentStart()), c.embeddingSegmentLength());
-
             //POSITION ZERO
             for (int i = c.headStart(); i < c.headEnd(); i++) {
-                float a = TensorOperationsProvider.get().dotProduct(query, k0, i * c.headSize, i * c.headSize, c.headSize) * attentionScale;
+                value.copyFrom(v0, v0.getOffset(c.maybeMapToGroupHead(i) * c.headSize), value.getOffset(i * c.headSize), c.headSize);
+                float a = TensorOperationsProvider.get().dotProduct(query, k0, i * c.headSize, c.maybeMapToGroupHead(i) * c.headSize, c.headSize) * attentionScale;
                 flashAttn_m.set(a, i);
                 flashAttn_l.set(1, i);
             }
@@ -165,7 +197,7 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
                 //KEY OFFSET
                 AbstractTensor kk = kvMem.slice(true, i + 1).slice(0);
                 for(int h = c.headStart(); h < c.headEnd(); h++){
-                    flashAttnHeads[i][h] = TensorOperationsProvider.get().dotProduct(query, kk, h * c.headSize, h * c.headSize, c.headSize) * attentionScale;
+                    flashAttnHeads[i][h] = TensorOperationsProvider.get().dotProduct(query, kk, h * c.headSize, c.maybeMapToGroupHead(h) * c.headSize, c.headSize) * attentionScale;
                 }
             });
 
@@ -177,12 +209,12 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
                     float a = flashAttnHeads[i][h];
                     if (a > flashAttn_m.get(h)) {
                         float e = (float) Math.exp(flashAttn_m.get(h) - a);
-                        TensorOperationsProvider.get().sxpby(e, vv, value, (h * c.headSize), h * c.headSize, c.headSize);
+                        TensorOperationsProvider.get().sxpby(e, vv, value, c.maybeMapToGroupHead(h) * c.headSize, h * c.headSize, c.headSize);
                         flashAttn_l.set(1 + e * flashAttn_l.get(h), h);
                         flashAttn_m.set(a, h);
                     } else {
                         float e = (float) Math.exp(a - flashAttn_m.get(h));
-                        TensorOperationsProvider.get().saxpy(e, vv, value, (h * c.headSize), h * c.headSize, c.headSize);
+                        TensorOperationsProvider.get().saxpy(e, vv, value, c.maybeMapToGroupHead(h) * c.headSize, h * c.headSize, c.headSize);
                         flashAttn_l.set(flashAttn_l.get(h) + e, h);
                     }
                 }
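Every attention-side edit in this file follows the same rule: query indices stay per-head, while key/value offsets are first routed through c.maybeMapToGroupHead(h) so several query heads share one cached K/V slice. A small standalone sketch of that mapping (head counts are illustrative; the real method is in Config further down):

    // Sketch of how query heads map onto shared KV heads under GQA.
    public class GroupHeadMappingSketch {
        public static void main(String[] args) {
            int numberOfHeads = 32;
            int numberOfKeyValueHeads = 8;
            int headGroupSize = numberOfHeads / numberOfKeyValueHeads; // 4 query heads per KV head

            for (int head = 0; head < numberOfHeads; head++) {
                // same result as Math.floor((double) head / headGroupSize) for non-negative indices
                int kvHead = head / headGroupSize;
                System.out.println("query head " + head + " reads KV head " + kvHead);
            }
        }
    }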
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java
index 11c09a7..c18996b 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java
@@ -13,6 +13,6 @@ public BertConfig( @JsonProperty("max_position_embeddings") int contextLength,
                        @JsonProperty("num_hidden_layers") int numberOfLayers,
                        @JsonProperty("layer_norm_eps") float layerNormEps,
                        @JsonProperty("vocab_size") int vocabularySize) {
-        super(contextLength, embeddingLength, hiddenLength, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, 0, 0);
+        super(contextLength, embeddingLength, hiddenLength, numberOfHeads, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, 0, 0, null);
     }
 }
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java
index a50eb7f..50bef8c 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java
@@ -69,7 +69,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
 
             CausalSelfAttention attention = new CausalSelfAttention(this, keyBias, queryBias, valueBias,
                     keyWeight, queryWeight, valueWeight,
-                    outputBias, outputWeight, Optional.empty());
+                    outputBias, outputWeight);
 
             prefix = b;
             MLPBlock mlpBlock = new MLPBlock(this, ActivationFunction.Type.GELU,
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java
index 023f75e..2b81d2b 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java
@@ -15,6 +15,6 @@ public GPT2Config( @JsonProperty("n_ctx") int contextLength,
                        @JsonProperty("vocab_size") int vocabularySize,
                        @JsonProperty("bos_token_id") int bosToken,
                        @JsonProperty("eos_token_id") int eosToken) {
-        super(contextLength, embeddingLength, embeddingLength * 4, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, bosToken, eosToken);
+        super(contextLength, embeddingLength, embeddingLength * 4, numberOfHeads, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, bosToken, eosToken, null);
     }
 }
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java
index 9415c12..2d2a78b 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java
@@ -50,7 +50,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
             AbstractTensor[] attnWeights = weights.load(prefix + "c_attn.weight").transpose().split(3, 0);
             CausalSelfAttention attention = new CausalSelfAttention(this, attnBias[0], attnBias[1], attnBias[2],
                     attnWeights[0], attnWeights[1], attnWeights[2],
-                    weights.load(prefix + "c_proj.bias"), weights.load(prefix + "c_proj.weight").transpose(), Optional.empty());
+                    weights.load(prefix + "c_proj.bias"), weights.load(prefix + "c_proj.weight").transpose());
 
             prefix = b + "mlp.";
             MLPBlock mlpBlock = new MLPBlock(this, ActivationFunction.Type.GELU,
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java
index b49d43c..987af17 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java
@@ -10,11 +10,12 @@ public class LlamaConfig extends Config {
     public LlamaConfig( @JsonProperty("hidden_size") int embeddingLength,
                         @JsonProperty("intermediate_size") int hiddenLength,
                         @JsonProperty("num_attention_heads") int numberOfHeads,
+                        @JsonProperty("num_key_value_heads") int numberOfKeyValueHeads,
                         @JsonProperty("num_hidden_layers") int numberOfLayers,
                         @JsonProperty("rms_norm_eps") float layerNormEps,
                         @JsonProperty("vocab_size") int vocabularySize,
                         @JsonProperty("bos_token_id") int bosToken,
-                        @JsonProperty("eos_token_id") int eosToken ) {
-        super(2048, embeddingLength, hiddenLength, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, bosToken, eosToken);
+                        @JsonProperty("eos_token_id") int eosToken) {
+        super(2048, embeddingLength, hiddenLength, numberOfHeads, numberOfKeyValueHeads, numberOfLayers, layerNormEps, vocabularySize, bosToken, eosToken, 10000.0);
     }
 }
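For the non-Llama models the constructor calls simply pass numberOfHeads twice and a null rope theta, so the new fields collapse back to the old single-group behavior. A quick sketch of that degenerate case, using assumed GPT-2-small-like numbers (illustrative only):

    // Sketch: no separate KV heads means headGroupSize == 1 and kvLength == embeddingLength.
    public class NonGqaConfigSketch {
        public static void main(String[] args) {
            int embeddingLength = 768;
            int numberOfHeads = 12;
            int numberOfKeyValueHeads = 12;  // the same value passed twice

            int headSize = embeddingLength / numberOfHeads;            // 64
            int headGroupSize = numberOfHeads / numberOfKeyValueHeads; // 1
            int kvLength = numberOfKeyValueHeads * headSize;           // 768 == embeddingLength
            boolean isGQA = numberOfKeyValueHeads < numberOfHeads;     // false -> original code path

            System.out.println("isGQA=" + isGQA + " headGroupSize=" + headGroupSize + " kvLength=" + kvLength);
        }
    }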
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java
index 8793a76..5e5c926 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java
@@ -52,8 +52,6 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
 
         TransformerBlock[] transformerBlocks = new TransformerBlock[c.getNumberOfLayers()];
 
-        float[][] ropeFreqs = VectorMath.precomputeFreqsCis(c.embeddingLength / c.numberOfHeads, c.contextLength, 10000.0 );
-
         IntStream.range(c.layerStart(), c.layerEnd()).parallel().forEach(i -> {
             String base = "model.layers." + i + ".";
             String prefix = base + "self_attn.";
@@ -61,8 +59,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
                     weights.load(prefix + "q_proj.weight", c.offset()).quantize(qType),
                     weights.load(prefix + "k_proj.weight", c.offset()).quantize(qType),
                     weights.load(prefix + "v_proj.weight", c.offset()).quantize(qType),
-                    weights.load(prefix + "o_proj.weight", c.offset()).quantize(qType),
-                    Optional.of(ropeFreqs));
+                    weights.load(prefix + "o_proj.weight", c.offset()).quantize(qType));
 
             prefix = base + "mlp.";
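With precomputeFreqsCis moved out of LlamaModel, the (cos, sin) tables now live on Config and are shared by every layer. For reference, the per-pair rotation that CausalSelfAttention applies with those tables is a plain complex multiply; a minimal sketch with made-up values (the real table entries come from VectorMath.precomputeFreqsCis as shown above):

    // Sketch only: fcr/fci stand in for one rf[poffset + i] entry; q0/q1 are the
    // i-th and (i + headSize/2)-th entries of a query head, per the hf permutation.
    public class RopeRotationSketch {
        public static void main(String[] args) {
            float q0 = 0.25f, q1 = -0.50f;
            float fcr = (float) Math.cos(0.1), fci = (float) Math.sin(0.1);

            float r0 = q0 * fcr - q1 * fci;  // matches query.set(q0 * fcr - q1 * fci, i)
            float r1 = q0 * fci + q1 * fcr;  // matches query.set(q0 * fci + q1 * fcr, i + headPiece)

            System.out.println("rotated pair: " + r0 + ", " + r1);
        }
    }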
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java
index 90eb9ef..20ec5d0 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java
@@ -1,5 +1,6 @@
 package com.github.tjake.jlama.safetensors;
 
+import com.github.tjake.jlama.math.VectorMath;
 import com.github.tjake.jlama.tensor.TensorCache;
 import com.github.tjake.jlama.util.Pair;
@@ -14,13 +15,20 @@ public class Config {
     public final int embeddingLength;
     public final int hiddenLength;
     public final int numberOfHeads;
+    public final int numberOfKeyValueHeads;
     public final int headSize;
+    public final int headGroupSize;
+    public final int kvLength;
+    public final boolean isGQA;
     protected final int numberOfLayers;
     public final float layerNormEps;
     public final int vocabularySize;
     public final int bosToken;
     public final int eosToken;
+    public final Optional<float[][]> ropeFreqs;
+    public final Optional<float[][]> groupRopeFreqs;
+
     private volatile Optional<Pair<Integer, Integer>> offset;
     private volatile File workingDirectory;
@@ -28,8 +36,13 @@ public class Config {
     private volatile int embeddingSegmentStart;
     private volatile int embeddingSegmentLength;
     private volatile int embeddingSegmentEnd;
+    private volatile int kvSegmentStart;
+    private volatile int kvSegmentLength;
+    private volatile int kvSegmentEnd;
     private volatile int headStart;
     private volatile int headEnd;
+    private volatile int groupHeadStart;
+    private volatile int groupHeadEnd;
 
     public final TensorCache tensorCache;
@@ -37,15 +50,18 @@ public Config(int contextLength,
                   int embeddingLength,
                   int hiddenLength,
                   int numberOfHeads,
+                  int numberOfKeyValueHeads,
                   int numberOfLayers,
                   float layerNormEps,
                   int vocabularySize,
                   int bosToken,
-                  int eosToken) {
+                  int eosToken,
+                  Double ropeFreqsTheta) {
         this.contextLength = contextLength;
         this.embeddingLength = embeddingLength;
         this.hiddenLength = hiddenLength;
         this.numberOfHeads = numberOfHeads;
+        this.numberOfKeyValueHeads = numberOfKeyValueHeads;
         this.numberOfLayers = numberOfLayers;
         this.layerNormEps = layerNormEps;
         this.vocabularySize = vocabularySize;
@@ -53,6 +69,16 @@ public Config(int contextLength,
         this.eosToken = eosToken;
         this.tensorCache = TensorCache.instance;
         this.headSize = embeddingLength / numberOfHeads;
+        this.headGroupSize = numberOfHeads / numberOfKeyValueHeads;
+        this.kvLength = numberOfKeyValueHeads * headSize;
+        this.isGQA = numberOfKeyValueHeads < numberOfHeads;
+        if (ropeFreqsTheta != null) {
+            this.ropeFreqs = Optional.of(VectorMath.precomputeFreqsCis(embeddingLength / numberOfHeads, contextLength, ropeFreqsTheta));
+            this.groupRopeFreqs = Optional.of(VectorMath.precomputeFreqsCis(embeddingLength / numberOfKeyValueHeads, contextLength, ropeFreqsTheta));
+        } else {
+            this.ropeFreqs = Optional.empty();
+            this.groupRopeFreqs = Optional.empty();
+        }
         setOffset(null);
     }
@@ -62,8 +88,13 @@ public void setOffset(Pair offset) {
         this.embeddingSegmentStart = this.offset.map(Pair::left).orElse(0);
         this.embeddingSegmentLength = this.offset.map(Pair::right).orElse(embeddingLength);
         this.embeddingSegmentEnd = embeddingSegmentStart + embeddingSegmentLength;
+        this.kvSegmentStart = embeddingSegmentStart / headGroupSize;
+        this.kvSegmentEnd = embeddingSegmentEnd / headGroupSize;
+        this.kvSegmentLength = embeddingSegmentLength / headGroupSize;
         this.headStart = embeddingSegmentStart / headSize;
         this.headEnd = embeddingSegmentEnd / headSize;
+        this.groupHeadStart = kvSegmentStart / headSize;
+        this.groupHeadEnd = kvSegmentEnd / headSize;
     }
 
     public void setWorkingDirectory(File workingDirectory) {
@@ -104,6 +135,14 @@ public int embeddingSegmentLength() {
         return embeddingSegmentLength;
     }
 
+    public int kvSegmentStart() {
+        return kvSegmentStart;
+    }
+
+    public int kvSegmentLength() {
+        return kvSegmentLength;
+    }
+
     public int headStart() {
         return headStart;
     }
@@ -111,4 +150,18 @@ public int headStart() {
     public int headEnd() {
         return headEnd;
     }
+
+    public int maybeMapToGroupHead(int head) {
+        int i = (int) Math.floor((double) head / headGroupSize);
+        //System.out.println("i: " + i + " head: " + head);
+        return i;
+    }
+
+    public int groupHeadStart() {
+        return groupHeadStart;
+    }
+
+    public int groupHeadEnd() {
+        return groupHeadEnd;
+    }
 }
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
index 80412d0..8eac081 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
@@ -80,7 +80,7 @@ public float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int bof
                     case AVX_256 -> QDotProductI8Q4_256((Q8ByteBufferTensor) a, (Q4ByteBufferTensor) b, aoffset, boffset, limit);
                     default -> throw new UnsupportedOperationException();
                 };
-                default -> throw new UnsupportedOperationException();
+                default -> throw new UnsupportedOperationException(b.dType().name());
             };
             case BF16 -> switch (b.dType()) {
                 case F32 -> switch (vectorType) {
diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java
index bc5f3fa..46b0ff4 100644
--- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java
+++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java
@@ -38,8 +38,7 @@ final class RuntimeHelper {
 
     final static SegmentAllocator CONSTANT_ALLOCATOR = (size, align) -> Arena.ofAuto().allocate(size, align);
 
-    static
-    {
+    static {
         if (!JarSupport.maybeLoadLibrary()) {
             System.loadLibrary("jlama");
         }
diff --git a/jlama-net/src/test/java/com/github/tjake/jlama/net/JlamaServiceTest.java b/jlama-net/src/test/java/com/github/tjake/jlama/net/JlamaServiceTest.java
index 42483d4..22fe2be 100644
--- a/jlama-net/src/test/java/com/github/tjake/jlama/net/JlamaServiceTest.java
+++ b/jlama-net/src/test/java/com/github/tjake/jlama/net/JlamaServiceTest.java
@@ -106,7 +106,7 @@ public void testNorm() {
 
     class MockConfig extends Config {
         public MockConfig(int contextLength, int embeddingLength, int hiddenLength, int numberOfHeads, int numberOfLayers, float layerNormEps) {
-            super(contextLength, embeddingLength, hiddenLength, numberOfHeads, numberOfLayers, layerNormEps, 32000, 1, 2);
+            super(contextLength, embeddingLength, hiddenLength, numberOfHeads, numberOfHeads, numberOfLayers, layerNormEps, 32000, 1, 2, null);
         }
     }
diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java
index f3c30e5..e266a00 100644
--- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java
+++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java
@@ -94,6 +94,19 @@ public void LlamaRun() throws Exception {
         }
     }
 
+    @Test
+    public void MistralRun() throws Exception {
+        String modelPrefix = "../models/Mistral-7B-v0.1";
+        Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
+        try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) {
+            LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
+            Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
+            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.empty());
+            String prompt = "Simply put, the theory of relativity states that";
+            model.generate(UUID.randomUUID(), prompt, 0.7f, 256, false, makeOutHandler());
+        }
+    }
+
     @Test
     public void testQuantize() throws Exception {
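The new MistralRun test is the first end-to-end pass through the GQA branch, since Mistral-7B-v0.1 uses fewer KV heads than attention heads; the Assume call skips it when the weights are not on disk. A lighter sketch that only exercises the config plumbing, with the constructor arguments assumed from a Mistral-7B-style config.json rather than asserted by this repo:

    // Sketch: build a LlamaConfig with assumed values and print the derived GQA fields.
    import com.github.tjake.jlama.model.llama.LlamaConfig;

    public class MistralConfigSketch {
        public static void main(String[] args) {
            // hidden_size, intermediate_size, heads, kv heads, layers, rms eps, vocab, bos, eos (assumed)
            LlamaConfig c = new LlamaConfig(4096, 14336, 32, 8, 32, 1e-5f, 32000, 1, 2);
            // Expected for a 32-head / 8-KV-head setup: isGQA=true, headGroupSize=4, kvLength=1024
            System.out.println("isGQA=" + c.isGQA + " headGroupSize=" + c.headGroupSize + " kvLength=" + c.kvLength);
        }
    }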