From ea9ac07d12ebf2acd7c4f1455a77ddd58ded4392 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Sun, 20 Oct 2024 20:26:18 -0400 Subject: [PATCH] Prep for next release --- .../jlama/cli/commands/ApiServiceCommand.java | 45 ++--- .../jlama/cli/commands/QuantizeCommand.java | 6 +- .../jlama/cli/commands/SimpleBaseCommand.java | 5 +- .../tjake/jlama/math/ActivationFunction.java | 4 +- .../github/tjake/jlama/math/VectorMath.java | 2 +- .../tjake/jlama/model/AbstractModel.java | 13 +- .../tjake/jlama/model/TransformerBlock.java | 177 ++++++++++-------- .../tjake/jlama/model/gemma/GemmaModel.java | 1 - .../jlama/model/gemma2/Gemma2Config.java | 5 +- .../tjake/jlama/model/gemma2/Gemma2Model.java | 1 - .../tjake/jlama/model/llama/LlamaModel.java | 2 +- .../jlama/model/mistral/MistralConfig.java | 3 +- .../tjake/jlama/model/qwen2/Qwen2Config.java | 73 +++++--- .../tjake/jlama/model/qwen2/Qwen2Model.java | 84 +++++---- .../jlama/model/qwen2/Qwen2Tokenizer.java | 28 ++- .../tjake/jlama/safetensors/Config.java | 72 +++---- .../jlama/safetensors/SafeTensorSupport.java | 31 +-- .../operations/PanamaTensorOperations.java | 6 +- .../github/tjake/jlama/util/JsonSupport.java | 2 +- .../tjake/jlama/util/ProgressReporter.java | 15 ++ .../jlama/net/openai/OpenAIChatService.java | 10 +- .../github/tjake/jlama/model/TestModels.java | 2 +- pom.xml | 2 +- 23 files changed, 334 insertions(+), 255 deletions(-) diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java index e6a4039..d5019e7 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java @@ -28,10 +28,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; -import org.springframework.boot.web.server.ConfigurableWebServerFactory; -import org.springframework.boot.web.server.WebServerFactoryCustomizer; -import org.springframework.context.ApplicationContextInitializer; -import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.env.ConfigurableEnvironment; @@ -67,35 +63,32 @@ public void addResourceHandlers(ResourceHandlerRegistry registry) { public void run() { try { Path modelPath = SimpleBaseCommand.getModel( - modelName, - modelDirectory, - downloadSection.autoDownload, - downloadSection.branch, - downloadSection.authToken); + modelName, + modelDirectory, + downloadSection.autoDownload, + downloadSection.branch, + downloadSection.authToken + ); m = loadModel( - modelPath.toFile(), - workingDirectory, - advancedSection.workingMemoryType, - advancedSection.workingQuantizationType, - Optional.ofNullable(advancedSection.modelQuantization), - Optional.ofNullable(advancedSection.threadCount)); + modelPath.toFile(), + workingDirectory, + advancedSection.workingMemoryType, + advancedSection.workingQuantizationType, + Optional.ofNullable(advancedSection.modelQuantization), + Optional.ofNullable(advancedSection.threadCount) + ); System.out.println("Chat UI: http://localhost:" + port); System.out.println("OpenAI Chat API: http://localhost:" + port + "/chat/completions"); // Use SpringApplicationBuilder with 
ApplicationContextInitializer to set the port dynamically - new SpringApplicationBuilder(ApiServiceCommand.class) - .initializers(applicationContext -> { - ConfigurableEnvironment environment = applicationContext.getEnvironment(); - Map props = new HashMap<>(); - props.put("server.port", port); // Set the port here before the server starts - environment.getPropertySources().addFirst(new MapPropertySource("customProps", props)); - }) - .properties("logging.level.org.springframework.web", "info") - .lazyInitialization(true) - .build() - .run(); + new SpringApplicationBuilder(ApiServiceCommand.class).initializers(applicationContext -> { + ConfigurableEnvironment environment = applicationContext.getEnvironment(); + Map props = new HashMap<>(); + props.put("server.port", port); // Set the port here before the server starts + environment.getPropertySources().addFirst(new MapPropertySource("customProps", props)); + }).properties("logging.level.org.springframework.web", "info").lazyInitialization(true).build().run(); } catch (Exception e) { e.printStackTrace(); diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java index 1573a76..9156c20 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java @@ -29,10 +29,12 @@ public class QuantizeCommand extends SimpleBaseCommand { @CommandLine.Parameters(index = "1", arity = "0..1", description = "The output location") protected Path output; - @CommandLine.Option(names = { "--quantization" }, paramLabel = "ARG", description = "Model quantization type (default: ${DEFAULT-VALUE})", arity = "1", defaultValue = "Q4") + @CommandLine.Option(names = { + "--quantization" }, paramLabel = "ARG", description = "Model quantization type (default: ${DEFAULT-VALUE})", arity = "1", defaultValue = "Q4") protected DType modelQuantization = DType.Q4; - @CommandLine.Option(names = { "--skip-layer" }, paramLabel = "ARG", description = "Layer name prefix to not quantize (default: ${DEFAULT-VALUE})", defaultValue = "norm") + @CommandLine.Option(names = { + "--skip-layer" }, paramLabel = "ARG", description = "Layer name prefix to not quantize (default: ${DEFAULT-VALUE})", defaultValue = "norm") protected String[] skipLayerPrefixes; @CommandLine.Option(names = { "--drop-layer" }, paramLabel = "ARG", description = "Layer name prefix to drop") diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/SimpleBaseCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/SimpleBaseCommand.java index 8ef5f29..dcc3163 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/SimpleBaseCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/SimpleBaseCommand.java @@ -26,7 +26,6 @@ import com.github.tjake.jlama.safetensors.SafeTensorSupport; import com.github.tjake.jlama.util.ProgressReporter; -import com.github.tjake.jlama.util.TriConsumer; import com.google.common.util.concurrent.Uninterruptibles; import me.tongfei.progressbar.ProgressBar; import me.tongfei.progressbar.ProgressBarBuilder; @@ -82,7 +81,9 @@ static Optional getProgressConsumer() { return Optional.of((ProgressReporter) (filename, sizeDownloaded, totalSize) -> { if (progressRef.get() == null || !progressRef.get().getTaskName().equals(filename)) { - ProgressBarBuilder builder = new 
ProgressBarBuilder().setTaskName(filename).setInitialMax(totalSize).setStyle(ProgressBarStyle.ASCII); + ProgressBarBuilder builder = new ProgressBarBuilder().setTaskName(filename) + .setInitialMax(totalSize) + .setStyle(ProgressBarStyle.ASCII); if (totalSize > 1000000) { builder.setUnit("MB", 1000000); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/math/ActivationFunction.java b/jlama-core/src/main/java/com/github/tjake/jlama/math/ActivationFunction.java index 540e0f8..fd92cf1 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/math/ActivationFunction.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/math/ActivationFunction.java @@ -29,7 +29,9 @@ public enum Type { public static float eval(Type t, float x) { return switch (t) { case SILU -> (float) (x * (1.0f / (1.0f + FastMath.exp(-x)))); - case GELU, GELU_PYTORCH_TANH -> (float) (0.5 * x * (1 + FastMath.tanh(FastMath.sqrt(2 / Math.PI) * (x + 0.044715 * FastMath.pow(x, 3))))); + case GELU, GELU_PYTORCH_TANH -> (float) (0.5 * x * (1 + FastMath.tanh( + FastMath.sqrt(2 / Math.PI) * (x + 0.044715 * FastMath.pow(x, 3)) + ))); case TANH -> (float) FastMath.tanh(x); }; } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java index e53d078..5a4b71a 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java @@ -39,7 +39,7 @@ public static void pchunk(int offset, int length, BiIntConsumer action) { int splits = Math.min(length, TensorOperationsProvider.get().parallelSplitSize()); int chunkSize = length / splits; int remainder = 0; - + // Non optimal case, just run in parallel if (splits == 1) { splits = length; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 9684d7f..a585130 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -33,7 +33,6 @@ import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.DebugSupport; import com.github.tjake.jlama.util.JsonSupport; -import com.github.tjake.jlama.util.MachineSpec; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; @@ -125,16 +124,18 @@ protected AbstractModel( // Check to make sure the model is big enough to support Q4I8 computations // If not, fall back to F32 - if (modelDType == DType.Q4 && workingMemoryQType == DType.I8 && - ((c.embeddingLength/Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize()/Float.SIZE) != 0 || - (c.hiddenLength/Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize()/Float.SIZE) != 0)){ + if (modelDType == DType.Q4 + && workingMemoryQType == DType.I8 + && ((c.embeddingLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0 + || (c.hiddenLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0)) { workingMemoryQType = DType.F32; } // Check to make sure the model is big enough to support Q4I8 computations // If not, fall back to F32 - if (modelDType == DType.Q4 && workingMemoryQType == DType.I8 && - (c.embeddingLength/Q8ByteBufferTensor.BLOCK_SIZE) % 
(FloatVector.SPECIES_PREFERRED.vectorBitSize()/Float.SIZE) != 0){ + if (modelDType == DType.Q4 + && workingMemoryQType == DType.I8 + && (c.embeddingLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0) { workingMemoryQType = DType.F32; } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java index a964760..0618237 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java @@ -42,94 +42,103 @@ public class TransformerBlock { final Optional preResponseNorm; // After the residual connection public TransformerBlock( - AbstractModel model, - int layerIndex, - LayerNorm preAttentionNorm, - CausalSelfAttention attention, - LayerNorm postAttentionNorm, - FeedForward ffBlock) { + AbstractModel model, + int layerIndex, + LayerNorm preAttentionNorm, + CausalSelfAttention attention, + LayerNorm postAttentionNorm, + FeedForward ffBlock + ) { this( - model, - layerIndex, - Optional.of(preAttentionNorm), - attention, - Optional.empty(), - Optional.of(postAttentionNorm), - ffBlock, - Optional.empty(), - Optional.empty()); + model, + layerIndex, + Optional.of(preAttentionNorm), + attention, + Optional.empty(), + Optional.of(postAttentionNorm), + ffBlock, + Optional.empty(), + Optional.empty() + ); } public TransformerBlock( - AbstractModel model, - int layerIndex, - CausalSelfAttention attention, - LayerNorm postAttentionNorm, - FeedForward ffBlock, - LayerNorm postFFNorm) { + AbstractModel model, + int layerIndex, + CausalSelfAttention attention, + LayerNorm postAttentionNorm, + FeedForward ffBlock, + LayerNorm postFFNorm + ) { this( - model, - layerIndex, - Optional.empty(), - attention, - Optional.empty(), - Optional.of(postAttentionNorm), - ffBlock, - Optional.empty(), - Optional.of(postFFNorm)); + model, + layerIndex, + Optional.empty(), + attention, + Optional.empty(), + Optional.of(postAttentionNorm), + ffBlock, + Optional.empty(), + Optional.of(postFFNorm) + ); } public TransformerBlock( - AbstractModel model, - int layerIndex, - LayerNorm preAttentionNorm, - CausalSelfAttention attention, - LayerNorm postAttentionNorm, - FeedForward ffBlock, - LayerNorm postFFNorm) { + AbstractModel model, + int layerIndex, + LayerNorm preAttentionNorm, + CausalSelfAttention attention, + LayerNorm postAttentionNorm, + FeedForward ffBlock, + LayerNorm postFFNorm + ) { this( - model, - layerIndex, - Optional.of(preAttentionNorm), - attention, - Optional.empty(), - Optional.of(postAttentionNorm), - ffBlock, - Optional.empty(), - Optional.of(postFFNorm)); + model, + layerIndex, + Optional.of(preAttentionNorm), + attention, + Optional.empty(), + Optional.of(postAttentionNorm), + ffBlock, + Optional.empty(), + Optional.of(postFFNorm) + ); } public TransformerBlock( - AbstractModel model, - int layerIndex, - LayerNorm preAttentionNorm, - CausalSelfAttention attention, - LayerNorm postAttentionNorm, - LayerNorm preFFNorm, - FeedForward ffBlock, - LayerNorm postFFNorm) { + AbstractModel model, + int layerIndex, + LayerNorm preAttentionNorm, + CausalSelfAttention attention, + LayerNorm postAttentionNorm, + LayerNorm preFFNorm, + FeedForward ffBlock, + LayerNorm postFFNorm + ) { this( - model, - layerIndex, - Optional.of(preAttentionNorm), - attention, - Optional.of(postAttentionNorm), - Optional.of(preFFNorm), - ffBlock, - Optional.of(postFFNorm), - 
Optional.empty()); + model, + layerIndex, + Optional.of(preAttentionNorm), + attention, + Optional.of(postAttentionNorm), + Optional.of(preFFNorm), + ffBlock, + Optional.of(postFFNorm), + Optional.empty() + ); } protected TransformerBlock( - AbstractModel model, - int layerIndex, - Optional preAttentionNorm, - CausalSelfAttention attention, - Optional postAttentionNorm, - Optional preFFNorm, - FeedForward ffBlock, - Optional postFFNorm, - Optional preResponseNorm) { + AbstractModel model, + int layerIndex, + Optional preAttentionNorm, + CausalSelfAttention attention, + Optional postAttentionNorm, + Optional preFFNorm, + FeedForward ffBlock, + Optional postFFNorm, + Optional preResponseNorm + ) { this.model = model; this.layerIndex = layerIndex; @@ -147,10 +156,11 @@ public AbstractTensor forward(AbstractTensor embedding, int position, KvBufferCa } public AbstractTensor forward( - AbstractTensor embedding, - int position, - KvBufferCache.KvBuffer kvBuffer, - Optional>> tensorReducer) { + AbstractTensor embedding, + int position, + KvBufferCache.KvBuffer kvBuffer, + Optional>> tensorReducer + ) { debug("input_emb", embedding, layerIndex); @@ -190,18 +200,19 @@ public AbstractTensor forward( // Release any tmp buffers (embedding is released by caller) if (lnemb != embedding) lnemb.close(); - if (lnattn != postAttention) lnattn.close(); else postAttention.close(); - if (lnpreFF != lnattn) lnpreFF.close(); else lnattn.close(); + if (lnattn != postAttention) lnattn.close(); + else postAttention.close(); + if (lnpreFF != lnattn) lnpreFF.close(); + else lnattn.close(); return maybeApplyNorm(lnpostFF, preResponseNorm); } private AbstractTensor maybeApplyNorm(AbstractTensor tensor, Optional norm) { return norm.map(ln -> { - AbstractTensor o = ln.forward(tensor); - tensor.close(); - return o; - }).orElse(tensor); + AbstractTensor o = ln.forward(tensor); + tensor.close(); + return o; + }).orElse(tensor); } } - diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java index 106954f..20c075b 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java @@ -15,7 +15,6 @@ */ package com.github.tjake.jlama.model.gemma; -import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.math.FloatConversions; import com.github.tjake.jlama.model.*; import com.github.tjake.jlama.model.functions.EmbedInput; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Config.java index 378208f..eb4998d 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Config.java @@ -56,8 +56,9 @@ public Gemma2Config( activationFunction, ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta, ropeScaling == null ? 
1.0 : Double.parseDouble(ropeScaling.get("factor")), - headDim, - finalLogitSoftCapping, attnLogitSoftCapping + headDim, + finalLogitSoftCapping, + attnLogitSoftCapping ); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Model.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Model.java index 5bd3652..dad649e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Model.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Model.java @@ -15,7 +15,6 @@ */ package com.github.tjake.jlama.model.gemma2; -import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.math.FloatConversions; import com.github.tjake.jlama.model.*; import com.github.tjake.jlama.model.functions.EmbedInput; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java index 66582bd..4bc4954 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java @@ -71,7 +71,7 @@ protected EmbedInput loadInputWeights() { return (inputToken, position) -> { if (wte.dType() == DType.BF16) { - //Handle old style model with BF16 embeddings + // Handle old style model with BF16 embeddings AbstractTensor embedding = makeDenseTensor(1, c.embeddingLength); AbstractTensor at = wte.slice(true, inputToken); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java index ab7e807..903ae8c 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java @@ -54,7 +54,8 @@ public MistralConfig( 1.0, null, headSize == null ? embeddingLength / numberOfHeads : headSize, - null, null + null, + null ); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Config.java index 2917697..90ca679 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Config.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.model.qwen2; import com.fasterxml.jackson.annotation.JsonCreator; @@ -11,35 +26,37 @@ public class Qwen2Config extends Config { @JsonCreator public Qwen2Config( - @JsonProperty("max_position_embeddings") int contextLength, - @JsonProperty("hidden_size") int embeddingLength, - @JsonProperty("intermediate_size") int hiddenLength, - @JsonProperty("num_attention_heads") int numberOfHeads, - @JsonProperty("num_key_value_heads") int numberOfKeyValueHeads, - @JsonProperty("num_hidden_layers") int numberOfLayers, - @JsonProperty("rms_norm_eps") float layerNormEps, - @JsonProperty("vocab_size") int vocabularySize, - @JsonProperty("bos_token_id") int bosToken, - @JsonProperty("eos_token_id") int eosToken, - @JsonProperty("hidden_act") ActivationFunction.Type activationFunction, - @JsonProperty("rope_theta") Double ropeTheta) { + @JsonProperty("max_position_embeddings") int contextLength, + @JsonProperty("hidden_size") int embeddingLength, + @JsonProperty("intermediate_size") int hiddenLength, + @JsonProperty("num_attention_heads") int numberOfHeads, + @JsonProperty("num_key_value_heads") int numberOfKeyValueHeads, + @JsonProperty("num_hidden_layers") int numberOfLayers, + @JsonProperty("rms_norm_eps") float layerNormEps, + @JsonProperty("vocab_size") int vocabularySize, + @JsonProperty("bos_token_id") int bosToken, + @JsonProperty("eos_token_id") int eosToken, + @JsonProperty("hidden_act") ActivationFunction.Type activationFunction, + @JsonProperty("rope_theta") Double ropeTheta + ) { super( - contextLength, - embeddingLength, - hiddenLength, - numberOfHeads, - numberOfKeyValueHeads, - numberOfLayers, - layerNormEps, - vocabularySize, - bosToken, - List.of(eosToken), - activationFunction, - ropeTheta, - 1.0, - null, - embeddingLength / numberOfHeads, - null, null + contextLength, + embeddingLength, + hiddenLength, + numberOfHeads, + numberOfKeyValueHeads, + numberOfLayers, + layerNormEps, + vocabularySize, + bosToken, + List.of(eosToken), + activationFunction, + ropeTheta, + 1.0, + null, + embeddingLength / numberOfHeads, + null, + null ); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Model.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Model.java index 9e6e3ea..2521e95 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Model.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Model.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.model.qwen2; import com.github.tjake.jlama.model.*; @@ -17,24 +32,24 @@ public class Qwen2Model extends LlamaModel { private static final Logger logger = LoggerFactory.getLogger(Qwen2Model.class); public Qwen2Model( - Config config, - WeightLoader weights, - Tokenizer tokenizer, - DType workingDType, - DType workingQType, - Optional modelQType + Config config, + WeightLoader weights, + Tokenizer tokenizer, + DType workingDType, + DType workingQType, + Optional modelQType ) { super(config, weights, tokenizer, workingDType, workingQType, modelQType); } public Qwen2Model( - InferenceType inferenceType, - Config config, - WeightLoader weights, - Tokenizer tokenizer, - DType workingDType, - DType workingQType, - Optional modelQType + InferenceType inferenceType, + Config config, + WeightLoader weights, + Tokenizer tokenizer, + DType workingDType, + DType workingQType, + Optional modelQType ) { super(inferenceType, config, weights, tokenizer, workingDType, workingQType, modelQType); } @@ -55,42 +70,41 @@ protected TransformerBlock[] loadTransformerBlockWeights() { String base = "model.layers." + i + "."; String prefix = base + "self_attn."; CausalSelfAttention attention = new CausalSelfAttention( - this, - relativeLayer, - Optional.of(weights.load(prefix + "q_proj.bias").quantize(qType)), - Optional.of(weights.load(prefix + "k_proj.bias").quantize(qType)), - Optional.of(weights.load(prefix + "v_proj.bias").quantize(qType)), - weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType), - weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType), - weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType), - Optional.empty(), - weights.load(prefix + "o_proj.weight", c.dctx(), false, true).quantize(qType) + this, + relativeLayer, + Optional.of(weights.load(prefix + "q_proj.bias").quantize(qType)), + Optional.of(weights.load(prefix + "k_proj.bias").quantize(qType)), + Optional.of(weights.load(prefix + "v_proj.bias").quantize(qType)), + weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType), + weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType), + weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType), + Optional.empty(), + weights.load(prefix + "o_proj.weight", c.dctx(), false, true).quantize(qType) ); prefix = base + "mlp."; MLPBlock mlp = new MLPBlock( - this, - c.activationFunction, - weights.load(prefix + "gate_proj.weight", c.dctx(), true, false).quantize(qType), // w1 - weights.load(prefix + "down_proj.weight", c.dctx(), false, true).quantize(qType), // w2 - weights.load(prefix + "up_proj.weight", c.dctx(), true, false).quantize(qType) + this, + c.activationFunction, + weights.load(prefix + "gate_proj.weight", c.dctx(), true, false).quantize(qType), // w1 + weights.load(prefix + "down_proj.weight", c.dctx(), false, true).quantize(qType), // w2 + weights.load(prefix + "up_proj.weight", c.dctx(), true, false).quantize(qType) ); // w3 transformerBlocks[relativeLayer] = new TransformerBlock( - this, - relativeLayer, - new RMSNorm(this, weights.load(base + "input_layernorm.weight").quantize(qType)), - attention, - new RMSNorm(this, weights.load(base + "post_attention_layernorm.weight").quantize(qType)), - mlp + this, + relativeLayer, + new RMSNorm(this, weights.load(base + "input_layernorm.weight").quantize(qType)), + attention, + new RMSNorm(this, weights.load(base + "post_attention_layernorm.weight").quantize(qType)), + mlp ); }); 
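// At this point transformerBlocks holds one block per layer, wired pre-norm:
// input_layernorm feeds attention and post_attention_layernorm feeds the MLP
// (the six-argument TransformerBlock constructor maps it to the preFFNorm slot),
// with residual connections applied inside TransformerBlock.forward().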
return transformerBlocks; } - @Override public ModelSupport.ModelType getModelType() { return ModelSupport.ModelType.QWEN2; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Tokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Tokenizer.java index 76be12c..5a08cf0 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Tokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Tokenizer.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.model.qwen2; import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer; @@ -8,7 +23,6 @@ public class Qwen2Tokenizer extends BPETokenizer { - public Qwen2Tokenizer(Path modelRoot) { super(modelRoot); } @@ -19,9 +33,9 @@ protected String preProcess(String sentence) { if (model.isLegacy() && !model.byteFallback) { sentence = sentence.codePoints() - .map(c -> alteredBytes.getOrDefault(c, c)) - .mapToObj(Character::toString) - .collect(Collectors.joining()); + .map(c -> alteredBytes.getOrDefault(c, c)) + .mapToObj(Character::toString) + .collect(Collectors.joining()); } return sentence; @@ -43,9 +57,9 @@ protected String postProcessToken(String decoded) { if (model.isLegacy() && !model.byteFallback) { decoded = decoded.codePoints() - .map(c -> alteredBytes.inverse().getOrDefault(c, c)) - .mapToObj(Character::toString) - .collect(Collectors.joining()); + .map(c -> alteredBytes.inverse().getOrDefault(c, c)) + .mapToObj(Character::toString) + .collect(Collectors.joining()); } return decoded; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java index 3762fe3..1760edd 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java @@ -56,41 +56,41 @@ public class Config { public final TensorCache tensorCache; public Config( - int contextLength, - int embeddingLength, - int hiddenLength, - int numberOfHeads, - int numberOfKeyValueHeads, - int numberOfLayers, - float layerNormEps, - int vocabularySize, - int bosToken, - List eosToken, - ActivationFunction.Type activationFunction, - Double ropeFreqsTheta, - Double ropeScalingFactor, - Integer headSize, - Float attnLogitSoftCapping, - Float finalLogitSoftCapping + int contextLength, + int embeddingLength, + int hiddenLength, + int numberOfHeads, + int numberOfKeyValueHeads, + int numberOfLayers, + float layerNormEps, + int vocabularySize, + int bosToken, + List eosToken, + ActivationFunction.Type activationFunction, + Double ropeFreqsTheta, + Double ropeScalingFactor, + Integer headSize, + Float attnLogitSoftCapping, + Float finalLogitSoftCapping ) { this( - contextLength, - embeddingLength, - hiddenLength, - numberOfHeads, - numberOfKeyValueHeads, - numberOfLayers, - 
layerNormEps, - vocabularySize, - bosToken, - eosToken, - activationFunction, - ropeFreqsTheta, - ropeScalingFactor, - null, - headSize == null ? embeddingLength / numberOfHeads : headSize, - attnLogitSoftCapping, - finalLogitSoftCapping + contextLength, + embeddingLength, + hiddenLength, + numberOfHeads, + numberOfKeyValueHeads, + numberOfLayers, + layerNormEps, + vocabularySize, + bosToken, + eosToken, + activationFunction, + ropeFreqsTheta, + ropeScalingFactor, + null, + headSize == null ? embeddingLength / numberOfHeads : headSize, + attnLogitSoftCapping, + finalLogitSoftCapping ); } @@ -125,7 +125,8 @@ public Config( ropeScalingFactor, null, embeddingLength / numberOfHeads, - null, null + null, + null ); } @@ -161,7 +162,8 @@ public Config( ropeScalingFactor, classifcationLabels, embeddingLength / numberOfHeads, - null, null + null, + null ); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java index ddbc1ad..a9ad3c1 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java @@ -28,7 +28,6 @@ import com.github.tjake.jlama.tensor.Q8ByteBufferTensor; import com.github.tjake.jlama.util.HttpSupport; import com.github.tjake.jlama.util.ProgressReporter; -import com.github.tjake.jlama.util.TriConsumer; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import java.io.*; @@ -281,7 +280,6 @@ public static Path quantizeModel( // Copy README.md and add jlama header addJlamaHeader(baseDirName, qPath.resolve("README.md")); - if (Files.exists(modelRoot.resolve("tokenizer_config.json"))) Files.copy( modelRoot.resolve("tokenizer_config.json"), qPath.resolve("tokenizer_config.json") @@ -318,13 +316,14 @@ public void write(byte[] b, int off, int len) throws IOException { private static void addJlamaHeader(String modelName, Path readmePath) throws IOException { String cleanName = modelName.replaceAll("_", "/"); String header = String.format( - "# Quantized Version of %s \n\n" + - "This model is a quantized variant of the %s model, optimized for use with Jlama, a Java-based inference engine. " + - "The quantization process reduces the model's size and improves inference speed, while maintaining high accuracy " + - "for efficient deployment in production environments.\n\n" + - "For more information on Jlama, visit the [Jlama GitHub repository](https://github.com/tjake/jlama).\n\n" + - "---\n\n", - cleanName, cleanName + "# Quantized Version of %s \n\n" + + "This model is a quantized variant of the %s model, optimized for use with Jlama, a Java-based inference engine. 
" + + "The quantization process reduces the model's size and improves inference speed, while maintaining high accuracy " + + "for efficient deployment in production environments.\n\n" + + "For more information on Jlama, visit the [Jlama GitHub repository](https://github.com/tjake/jlama).\n\n" + + "---\n\n", + cleanName, + cleanName ); String readme = new String(Files.readAllBytes(readmePath)); boolean startMeta = false; @@ -333,7 +332,7 @@ private static void addJlamaHeader(String modelName, Path readmePath) throws IOE StringBuilder finalReadme = new StringBuilder(); for (String line : readme.split("\n")) { if (linenum++ == 0) { - if (line.startsWith("---")) { + if (line.startsWith("---")) { startMeta = true; } else { finalReadme.append(header); @@ -365,9 +364,15 @@ public static File maybeDownloadModel(String modelDir, String fullModelName, Pro name = parts[1]; } - return maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, - true, Optional.empty(), Optional.empty(), - Optional.ofNullable(progressReporter)); + return maybeDownloadModel( + modelDir, + Optional.ofNullable(owner), + name, + true, + Optional.empty(), + Optional.empty(), + Optional.ofNullable(progressReporter) + ); } public static File maybeDownloadModel(String modelDir, String fullModelName) throws IOException { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java index c265724..abb4181 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java @@ -2265,9 +2265,9 @@ void accumulateF32BF16_256(FloatBufferTensor a, BFloat16BufferTensor b, int offs // Convert BF16 to F32 var bf = b.getVector(ShortVector.SPECIES_128, 0, i) - .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0) - .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256) - .reinterpretAsFloats(); + .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0) + .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256) + .reinterpretAsFloats(); var res = af.add(bf); a.intoTensor(res, 0, i); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java index 533881d..e682883 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java @@ -31,7 +31,7 @@ */ public class JsonSupport { private static final String JSON_REGEX = - "(\\{\\s*(\"[^\"]+\"\\s*:\\s*(\"[^\"]*\"|\\d+|true|false|null|\\{[^{}]*\\}|\\[[^\\[\\]]*\\])\\s*,?\\s*)+\\})|(\\[\\s*(\\{\\s*(\"[^\"]+\"\\s*:\\s*(\"[^\"]*\"|\\d+|true|false|null|\\{[^{}]*\\}|\\[[^\\[\\]]*\\])\\s*,?\\s*)+\\}\\s*,?\\s*)+\\])"; + "(\\{\\s*(\"[^\"]+\"\\s*:\\s*(\"[^\"]*\"|\\d+|true|false|null|\\{[^{}]*\\}|\\[[^\\[\\]]*\\])\\s*,?\\s*)+\\})|(\\[\\s*(\\{\\s*(\"[^\"]+\"\\s*:\\s*(\"[^\"]*\"|\\d+|true|false|null|\\{[^{}]*\\}|\\[[^\\[\\]]*\\])\\s*,?\\s*)+\\}\\s*,?\\s*)+\\])"; private static final Pattern JSON_PATTERN = Pattern.compile(JSON_REGEX); public static final ObjectMapper om = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES, false) .configure(DeserializationFeature.FAIL_ON_TRAILING_TOKENS, false) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/ProgressReporter.java 
b/jlama-core/src/main/java/com/github/tjake/jlama/util/ProgressReporter.java index af2846a..af41107 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/ProgressReporter.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/ProgressReporter.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.util; public interface ProgressReporter { diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java index a669b42..f3d546a 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java @@ -101,7 +101,7 @@ Object createChatCompletion(@RequestHeader Map headers, @Valid @ } } - float temperature = 0.3f; + float temperature = 0.3f; int maxTokens = request.getMaxTokens() == null ? model.getConfig().contextLength : request.getMaxTokens(); logger.info("Generating completion for session {} with temperature {} and max tokens {}", sessionId, temperature, maxTokens); @@ -140,9 +140,11 @@ Object createChatCompletion(@RequestHeader Map headers, @Valid @ emitter.complete(); - logger.info("{} tokens/s (prompt), {} tokens/s (gen)", - Math.round(r.promptTokens / (double) (r.promptTimeMs / 1000f)), - Math.round(r.generatedTokens / (double) (r.generateTimeMs / 1000f))); + logger.info( + "{} tokens/s (prompt), {} tokens/s (gen)", + Math.round(r.promptTokens / (double) (r.promptTimeMs / 1000f)), + Math.round(r.generatedTokens / (double) (r.generateTimeMs / 1000f)) + ); } catch (IOException e) { emitter.completeWithError(e); diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index bf03ea8..5d9f682 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -54,7 +54,7 @@ public class TestModels { static { System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "2"); - //System.setProperty("jlama.force_panama_tensor_operations", "true"); + // System.setProperty("jlama.force_panama_tensor_operations", "true"); } private static final Logger logger = LoggerFactory.getLogger(TestModels.class); diff --git a/pom.xml b/pom.xml index 43ae3e1..6c5b726 100644 --- a/pom.xml +++ b/pom.xml @@ -42,7 +42,7 @@ UTF-8 - 0.6.0 + 0.7.0 2.0.7 1.5.6
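---

Two patterns in this patch are worth standalone sketches. First, the reworked ApiServiceCommand registers `server.port` through an `ApplicationContextInitializer` before the embedded server starts, so the port can be chosen at runtime. A minimal, self-contained version of the same pattern (`DemoApp` is a stand-in application class, not part of Jlama):

```java
import java.util.HashMap;
import java.util.Map;

import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.MapPropertySource;

@SpringBootApplication
class DemoApp {} // stand-in application class, not part of Jlama

public class DynamicPortLauncher {
    public static void main(String[] args) {
        int port = 8080; // decided at runtime, e.g. parsed from a CLI flag

        new SpringApplicationBuilder(DemoApp.class).initializers(applicationContext -> {
            ConfigurableEnvironment environment = applicationContext.getEnvironment();
            Map<String, Object> props = new HashMap<>();
            props.put("server.port", port); // must be registered before the server starts
            environment.getPropertySources().addFirst(new MapPropertySource("customProps", props));
        }).lazyInitialization(true).build().run(args);
    }
}
```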
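Second, the `ProgressReporter` callback (which gains its license header in this patch) can be handed to the `maybeDownloadModel(modelDir, fullModelName, progressReporter)` overload added to SafeTensorSupport above. A sketch, assuming `ProgressReporter` is a single-method interface with the `(filename, sizeDownloaded, totalSize)` shape used by the lambda in `SimpleBaseCommand.getProgressConsumer()`; long-valued byte counts and the model name are assumptions:

```java
import java.io.File;
import java.io.IOException;

import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.util.ProgressReporter;

public class DownloadWithProgress {
    public static void main(String[] args) throws IOException {
        // Same lambda shape as SimpleBaseCommand.getProgressConsumer();
        // long-valued byte counts are an assumption.
        ProgressReporter reporter = (filename, sizeDownloaded, totalSize) -> {
            long pct = totalSize > 0 ? (sizeDownloaded * 100) / totalSize : 0;
            System.out.printf("%s: %d%% (%d/%d bytes)%n", filename, pct, sizeDownloaded, totalSize);
        };

        // Overload added to SafeTensorSupport in this patch; "owner/model" is a placeholder.
        File model = SafeTensorSupport.maybeDownloadModel("models", "owner/model", reporter);
        System.out.println("Downloaded to: " + model.getAbsolutePath());
    }
}
```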