From 36d96403a513867477b7fed7896f760fb13f3734 Mon Sep 17 00:00:00 2001
From: Jake Luciani
Date: Sun, 25 Aug 2024 23:33:30 -0400
Subject: [PATCH] Run spotless

---
 .../jlama/cli/commands/ApiServiceCommand.java |  19 +-
 .../commands/ClusterCoordinatorCommand.java   |   2 -
 .../github/tjake/jlama/math/VectorMath.java   |   4 +-
 .../tjake/jlama/model/AbstractModel.java      |   4 +-
 .../jlama/model/CausalSelfAttention.java      |  18 +-
 .../tjake/jlama/model/TransformerBlock.java   |  11 +-
 .../tjake/jlama/model/bert/BertConfig.java    |   1 -
 .../tjake/jlama/model/gemma/GemmaConfig.java  |   1 -
 .../tjake/jlama/model/gemma/GemmaModel.java   |   3 +-
 .../tjake/jlama/model/gpt2/GPT2Config.java    |   1 -
 .../tjake/jlama/model/gpt2/GPT2Tokenizer.java |   3 +-
 .../tjake/jlama/model/llama/LlamaConfig.java  |   5 +-
 .../tjake/jlama/model/llama/LlamaModel.java   |   4 +-
 .../jlama/model/llama/LlamaTokenizer.java     |   3 +-
 .../jlama/model/mistral/MistralConfig.java    |   1 -
 .../jlama/model/mixtral/MixtralConfig.java    |   1 -
 .../jlama/model/mixtral/MixtralModel.java     |   3 +-
 .../tjake/jlama/safetensors/Config.java       |   6 +-
 .../jlama/safetensors/prompt/Function.java    |  24 +-
 .../jlama/safetensors/prompt/Parameters.java  |  31 ++-
 .../safetensors/prompt/PromptSupport.java     |  30 ++-
 .../jlama/safetensors/prompt/Result.java      |  15 ++
 .../tjake/jlama/safetensors/prompt/Tool.java  |  20 +-
 .../jlama/safetensors/prompt/ToolCall.java    |  16 +-
 .../safetensors/tokenizer/BPETokenizer.java   |   6 +-
 .../safetensors/tokenizer/Tokenizer.java      |   2 -
 .../safetensors/tokenizer/TokenizerModel.java |   7 +-
 .../tjake/jlama/tensor/AbstractTensor.java    |   1 -
 .../jlama/tensor/BFloat16BufferTensor.java    |   5 +-
 .../jlama/tensor/Float16BufferTensor.java     |   1 -
 .../tjake/jlama/tensor/FloatBufferTensor.java |  12 +-
 .../jlama/tensor/Q4ByteBufferTensor.java      |   1 -
 .../jlama/tensor/Q5ByteBufferTensor.java      |   1 -
 .../jlama/tensor/Q8ByteBufferTensor.java      |   1 -
 .../operations/PanamaTensorOperations.java    |  10 +-
 .../tensor/operations/TensorOperations.java   |   2 +-
 .../github/tjake/jlama/util/DebugSupport.java |  15 ++
 .../github/tjake/jlama/util/JsonSupport.java  |   3 -
 .../operations/NativeTensorOperations.java    |   9 +-
 .../tensor/operations/cnative/NativeSimd.java | 214 ++++++++++++++++--
 .../operations/util/MemorySegmentSupport.java |  24 +-
 .../github/tjake/jlama/net/Coordinator.java   |   3 +-
 .../com/github/tjake/jlama/net/Worker.java    |   2 -
 .../jlama/net/openai/OpenAIChatService.java   |  97 ++++----
 .../tjake/jlama/net/openai/ChatApiTest.java   |  56 +++--
 .../jlama/net/openai/MockedOpenAIServer.java  |  23 +-
 .../jlama/net/openai/OpenAIServiceTests.java  |  20 +-
 .../tjake/jlama/model/TestCorrectness.java    | 124 +++++----
 .../github/tjake/jlama/model/TestModels.java  |  71 +++---
 49 files changed, 613 insertions(+), 323 deletions(-)

diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java
index a4dab8d..320edc1 100644
--- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java
+++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java
@@ -15,12 +15,14 @@
  */
 package com.github.tjake.jlama.cli.commands;
 
+import static com.github.tjake.jlama.model.ModelSupport.loadModel;
+
 import com.github.tjake.jlama.model.AbstractModel;
+import java.util.Optional;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.boot.SpringBootConfiguration;
 import org.springframework.boot.autoconfigure.SpringBootApplication;
-
 import 
org.springframework.boot.builder.SpringApplicationBuilder; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -28,11 +30,9 @@ import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import picocli.CommandLine; -import java.util.Optional; - -import static com.github.tjake.jlama.model.ModelSupport.loadModel; - -@CommandLine.Command(name = "restapi", description = "Starts a openai compatible rest api for interacting with this model") +@CommandLine.Command( + name = "restapi", + description = "Starts a openai compatible rest api for interacting with this model") @SpringBootApplication(scanBasePackages = {"com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands"}) @SpringBootConfiguration @Configuration @@ -54,8 +54,7 @@ public AbstractModel getModelBean() { @Override public void addResourceHandlers(ResourceHandlerRegistry registry) { - registry.addResourceHandler("/ui/**") - .addResourceLocations("classpath:/static/ui/"); + registry.addResourceHandler("/ui/**").addResourceLocations("classpath:/static/ui/"); } @Override @@ -74,9 +73,7 @@ public void run() { new SpringApplicationBuilder(ApiServiceCommand.class) .lazyInitialization(true) - .properties( - "server.port", ""+port, - "logging.level.org.springframework.web", "debug") + .properties("server.port", "" + port, "logging.level.org.springframework.web", "debug") .build() .run(); } catch (Exception e) { diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java index dc58458..27d3d2f 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java @@ -15,8 +15,6 @@ */ package com.github.tjake.jlama.cli.commands; - - import com.github.tjake.jlama.net.Coordinator; import picocli.CommandLine; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java index 2991132..72d59a7 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java @@ -32,9 +32,7 @@ public class VectorMath { public static void pfor(int start, int end, IntConsumer action) { PhysicalCoreExecutor.instance .get() - .execute(() -> IntStream.range(start, end) - .parallel() - .forEach(action)); + .execute(() -> IntStream.range(start, end).parallel().forEach(action)); } public static void pchunk(int offset, int length, BiIntConsumer action) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index e88c682..5c30311 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -15,6 +15,8 @@ */ package com.github.tjake.jlama.model; +import static com.github.tjake.jlama.util.DebugSupport.debug; + import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.model.functions.EmbedInput; import com.github.tjake.jlama.model.functions.Generator; @@ -45,8 +47,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static com.github.tjake.jlama.util.DebugSupport.debug; - public abstract class 
AbstractModel implements Generator { private static final Logger logger = LoggerFactory.getLogger(AbstractModel.class); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java index a5a966f..0077776 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java @@ -15,20 +15,16 @@ */ package com.github.tjake.jlama.model; +import static com.github.tjake.jlama.util.DebugSupport.debug; + import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.safetensors.Config; -import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.tensor.AbstractTensor; -import com.github.tjake.jlama.tensor.BFloat16BufferTensor; -import com.github.tjake.jlama.tensor.FloatBufferTensor; -import com.github.tjake.jlama.tensor.operations.TensorOperations; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.google.common.base.Preconditions; import java.util.*; import java.util.function.Consumer; -import static com.github.tjake.jlama.util.DebugSupport.debug; - public class CausalSelfAttention { private final AbstractModel m; private final Config c; @@ -257,8 +253,7 @@ public AbstractTensor forward( int offset = h * c.headSize; // skip if we are out of bounds - if (offset >= query.shape().last()) - break; + if (offset >= query.shape().last()) break; int goffset = c.maybeMapToGroupHead(h) * c.headSize; // rotate q by the freq theta and freq r @@ -276,8 +271,7 @@ public AbstractTensor forward( for (int h = c.groupHeadStart(); h < c.groupHeadEnd(); h++) { // get the k vectors for this head int offset = h * c.headSize; - if (offset >= key.shape().last()) - break; + if (offset >= key.shape().last()) break; // rotate k by the freq theta and freq r for (int i = offset; i < (offset + headPiece); i++) { float k00 = key.get(0, i); @@ -314,15 +308,13 @@ public AbstractTensor forward( debug("key+rope", key, finalPostion); }); - // Attention VectorMath.pfor(c.headStart(), c.headEnd(), h -> { try (AbstractTensor attn = m.makeFullTensor(1, kvp.shape().first())) { int xoffset = c.maybeMapToGroupHead(h) * c.headSize; int yoffset = h * c.headSize; - if (yoffset >= query.shape().last()) - return; + if (yoffset >= query.shape().last()) return; // compute attention scores by multiplying query and key for every position TensorOperationsProvider.get() diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java index 9ca040d..7e20ee8 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java @@ -15,20 +15,18 @@ */ package com.github.tjake.jlama.model; +import static com.github.tjake.jlama.util.DebugSupport.debug; + import com.github.tjake.jlama.model.functions.FeedForward; import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; -import com.github.tjake.jlama.util.DebugSupport; import com.github.tjake.jlama.util.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.List; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Consumer; - -import static com.github.tjake.jlama.util.DebugSupport.debug; 
+import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class TransformerBlock { @@ -94,7 +92,6 @@ public AbstractTensor forward( AbstractTensor lnemb = preAttentionNorm.map(ln -> ln.forward(embedding, normReducer)).orElse(embedding); - debug("ln_emb", lnemb, layerIndex); AbstractTensor postAttention; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java index d98b9bb..a619c92 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; - import java.util.List; public class BertConfig extends Config { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaConfig.java index 4cecc94..58f7039 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaConfig.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; - import java.util.List; import java.util.Map; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java index 66dca35..a213513 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java @@ -97,7 +97,8 @@ protected TransformerBlock[] loadTransformerBlockWeights() { weights.load(prefix + "up_proj.weight", c.offset()).quantize(qType)); // w3 transformerBlocks[i] = new TransformerBlock( - this, i, + this, + i, new RMSNorm( this, weights.load(base + "input_layernorm.weight", c.offset()) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java index d65fdb9..42ac437 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; - import java.util.List; public class GPT2Config extends Config { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Tokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Tokenizer.java index 7d3d708..b248b7e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Tokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Tokenizer.java @@ -18,13 +18,12 @@ import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; -import net.fellbaum.jemoji.EmojiManager; - import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import net.fellbaum.jemoji.EmojiManager; public 
class GPT2Tokenizer extends BPETokenizer { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java index 377feaa..90ab919 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; - import java.util.List; import java.util.Map; @@ -53,6 +52,8 @@ public LlamaConfig( eosToken instanceof List ? (List) eosToken : List.of((Integer) eosToken), activationFunction, ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta, - ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor"))); + ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) + ? 1.0 + : Double.parseDouble(ropeScaling.get("factor"))); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java index d05edfa..059bd92 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java @@ -70,7 +70,6 @@ protected EmbedInput loadInputWeights() { at = TensorOperationsProvider.get() .quantize(at, embedding.dType(), c.embeddingSegmentStart(), c.embeddingSegmentLength()); - embedding.copyFrom( at, at.getOffset(0, c.embeddingSegmentStart()), @@ -110,7 +109,8 @@ protected TransformerBlock[] loadTransformerBlockWeights() { weights.load(prefix + "up_proj.weight", c.offset()).quantize(qType)); // w3 transformerBlocks[i] = new TransformerBlock( - this, i, + this, + i, new RMSNorm( this, weights.load(base + "input_layernorm.weight", c.offset()) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaTokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaTokenizer.java index d3a4e85..b426cce 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaTokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaTokenizer.java @@ -27,7 +27,8 @@ public class LlamaTokenizer extends BPETokenizer { public LlamaTokenizer(Path modelRoot) { super(modelRoot); - this.byteFallbackEncodingOffset = this.getModel().vocabLookup.getOrDefault("<0x00>", 0L).intValue(); + this.byteFallbackEncodingOffset = + this.getModel().vocabLookup.getOrDefault("<0x00>", 0L).intValue(); } @Override diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java index cd6d9f3..3c04da5 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; - import java.util.List; public class MistralConfig extends Config { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralConfig.java index fa49faf..9fe59b6 
100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralConfig.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; - import java.util.List; public class MixtralConfig extends Config { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java index 3b2a060..4b3a9ab 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralModel.java @@ -98,7 +98,8 @@ protected TransformerBlock[] loadTransformerBlockWeights() { expertUpWeights); // w3 transformerBlocks[i] = new TransformerBlock( - this, i, + this, + i, new RMSNorm( this, weights.load(base + "input_layernorm.weight", c.offset()) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java index 5ce1193..4c011a5 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java @@ -59,7 +59,6 @@ public class Config { public final TensorCache tensorCache; - public Config( int contextLength, int embeddingLength, @@ -125,10 +124,7 @@ public Config( this.ropeFreqs = ropeFreqsTheta == null ? Optional.empty() : Optional.of(VectorMath.precomputeFreqsCis( - headSize, - contextLength, - ropeFreqsTheta, - ropeScalingFactor == null ? 1.0 : ropeScalingFactor)); + headSize, contextLength, ropeFreqsTheta, ropeScalingFactor == null ? 1.0 : ropeScalingFactor)); // Set default values setOffset(null); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Function.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Function.java index 0106b89..610e590 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Function.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Function.java @@ -1,12 +1,24 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.safetensors.prompt; import com.fasterxml.jackson.annotation.JsonPropertyOrder; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.hubspot.jinjava.objects.collections.PyMap; - -import javax.annotation.concurrent.Immutable; -import java.util.HashMap; import java.util.Map; /** @@ -61,7 +73,11 @@ public Function build() { } private Function(String name, String description, Parameters parameters) { - super(ImmutableMap.builder().put(JSON_PROPERTY_NAME, name).put(JSON_PROPERTY_DESCRIPTION, description).put(JSON_PROPERTY_PARAMETERS, parameters).build()); + super(ImmutableMap.builder() + .put(JSON_PROPERTY_NAME, name) + .put(JSON_PROPERTY_DESCRIPTION, description) + .put(JSON_PROPERTY_PARAMETERS, parameters) + .build()); this.name = name; this.description = description; this.parameters = parameters; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Parameters.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Parameters.java index 9946530..f9f12b5 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Parameters.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Parameters.java @@ -1,18 +1,33 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.safetensors.prompt; import com.fasterxml.jackson.annotation.JsonPropertyOrder; import com.google.common.collect.ImmutableMap; import com.hubspot.jinjava.objects.collections.PyMap; - import java.util.*; /** * Parameters */ @JsonPropertyOrder({ - Parameters.JSON_PROPERTY_TYPE, - Parameters.JSON_PROPERTY_PROPERTIES, - Parameters.JSON_PROPERTY_REQUIRED}) + Parameters.JSON_PROPERTY_TYPE, + Parameters.JSON_PROPERTY_PROPERTIES, + Parameters.JSON_PROPERTY_REQUIRED +}) public class Parameters extends PyMap { public static final String JSON_PROPERTY_TYPE = "type"; @@ -27,14 +42,16 @@ public Parameters(Map> properties, List requ super(ImmutableMap.builder() .put(JSON_PROPERTY_TYPE, "object") .put(JSON_PROPERTY_PROPERTIES, properties) - .put(JSON_PROPERTY_REQUIRED, required).build()); + .put(JSON_PROPERTY_REQUIRED, required) + .build()); this.required = required; } public Parameters(Map> properties) { super(ImmutableMap.builder() .put(JSON_PROPERTY_TYPE, "object") - .put(JSON_PROPERTY_PROPERTIES, properties).build()); + .put(JSON_PROPERTY_PROPERTIES, properties) + .build()); this.required = null; } @@ -77,4 +94,4 @@ public Map getProperties() { public List getRequired() { return required; } -} \ No newline at end of file +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java index e55942c..97f60dd 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java @@ -15,8 +15,6 @@ */ package com.github.tjake.jlama.safetensors.prompt; -import com.fasterxml.jackson.core.util.DefaultIndenter; -import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializationFeature; import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel; @@ -26,10 +24,7 @@ import com.hubspot.jinjava.LegacyOverrides; import com.hubspot.jinjava.interpret.RenderResult; import com.hubspot.jinjava.lib.fn.ELFunctionDefinition; - import java.util.*; - -import com.hubspot.jinjava.objects.serialization.PyishPrettyPrinter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,7 +45,9 @@ public class PromptSupport { .withUseSnakeCasePropertyNaming(true) .withKeepNullableLoopValues(true) .build()) - .withObjectMapper(new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT).setDefaultPrettyPrinter(JsonSupport.JlamaPrettyPrinter.INSTANCE)) + .withObjectMapper(new ObjectMapper() + .enable(SerializationFeature.INDENT_OUTPUT) + .setDefaultPrettyPrinter(JsonSupport.JlamaPrettyPrinter.INSTANCE)) .build()); static { @@ -267,23 +264,22 @@ public String build() { Map args = new HashMap(); args.putAll(Map.of( - "messages", - messages.stream().map(Message::toMap).toList(), - "add_generation_prompt", - addGenerationPrompt, - "eos_token", - m.eosToken(), - "bos_token", - "")); // We add the BOS ourselves + "messages", + messages.stream().map(Message::toMap).toList(), + "add_generation_prompt", + addGenerationPrompt, + "eos_token", + m.eosToken(), + "bos_token", + "")); // We add the BOS ourselves if (tools != null) { args.put("tools", tools); } - RenderResult r = jinjava.renderForResult(template, args); + RenderResult r = jinjava.renderForResult(template, args); - if (r.hasErrors()) - logger.warn("Prompt template errors: " + r.getErrors()); + if 
(r.hasErrors()) logger.warn("Prompt template errors: " + r.getErrors()); return r.getOutput(); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Result.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Result.java index a8acfe3..3682739 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Result.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Result.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.safetensors.prompt; import com.fasterxml.jackson.annotation.JsonPropertyOrder; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Tool.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Tool.java index 6535063..feff579 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Tool.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Tool.java @@ -1,14 +1,26 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.safetensors.prompt; import com.fasterxml.jackson.annotation.JsonPropertyOrder; - /** * Tool */ -@JsonPropertyOrder({ - Tool.JSON_PROPERTY_TYPE, - Tool.JSON_PROPERTY_FUNCTION}) +@JsonPropertyOrder({Tool.JSON_PROPERTY_TYPE, Tool.JSON_PROPERTY_FUNCTION}) public class Tool { public static final String JSON_PROPERTY_TYPE = "type"; private final String type = "function"; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java index 2cc0f77..cd666ab 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java @@ -1,8 +1,22 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.safetensors.prompt; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - import java.util.Map; public class ToolCall { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.java index 7faa326..bf6c183 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.java @@ -26,8 +26,6 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.util.*; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -126,7 +124,7 @@ public long[] encode(String rawSentence) { Long id = model.vocabLookup.get(c); if (id != null) { // we found this codepoint in vocab, add it as a token - //logger.debug("{} -> {}", c, id); + // logger.debug("{} -> {}", c, id); tokens.add(id); } else { if (model.byteFallback) { @@ -135,7 +133,7 @@ public long[] encode(String rawSentence) { byte[] chars = code.getBytes(StandardCharsets.UTF_8); for (int k = 0; k < chars.length; k++) { long token = encodeCharacterAsToken(chars[k]); - //logger.debug("byte {} -> {}", Byte.toUnsignedInt(chars[k]), token); + // logger.debug("byte {} -> {}", Byte.toUnsignedInt(chars[k]), token); tokens.add(token); } } else { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/Tokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/Tokenizer.java index 5fe9e14..35da6ea 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/Tokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/Tokenizer.java @@ -16,7 +16,6 @@ package com.github.tjake.jlama.safetensors.tokenizer; import com.github.tjake.jlama.safetensors.prompt.PromptSupport; - import java.util.List; import java.util.Optional; @@ -59,7 +58,6 @@ public interface Tokenizer { */ Optional promptSupport(); - /** * Get the model for this tokenizer (expert mode) * @return tokenizer model diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java index dd13978..73f34db 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java @@ -190,7 +190,7 @@ static String[] split(java.util.regex.Pattern p, CharSequence input, int limit, Matcher m = p.matcher(input); // Add segments before each match found - while(m.find()) { + while (m.find()) { if (!matchLimited || matchCount < limit - 1) { if (index == 0 && index == m.start() && m.start() == m.end()) { // no empty leading substring included for zero-width match @@ -213,8 +213,7 @@ static 
String[] split(java.util.regex.Pattern p, CharSequence input, int limit, } // If no match was found, return this - if (index == 0) - return new String[] {input.toString()}; + if (index == 0) return new String[] {input.toString()}; // Add remaining segment if (!matchLimited || matchCount < limit) @@ -223,7 +222,7 @@ static String[] split(java.util.regex.Pattern p, CharSequence input, int limit, // Construct result int resultSize = matchList.size(); if (limit == 0) { - while (resultSize > 0 && matchList.get(resultSize-1).isEmpty()) { + while (resultSize > 0 && matchList.get(resultSize - 1).isEmpty()) { resultSize--; } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java index ca87977..68ee2fd 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java @@ -17,7 +17,6 @@ import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.TensorInfo; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import java.io.IOException; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java index a162964..fe104a7 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java @@ -17,14 +17,12 @@ import com.github.tjake.jlama.math.FloatConversions; import com.github.tjake.jlama.safetensors.DType; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.UnsafeDirectByteBuffer; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import java.lang.foreign.MemorySegment; import java.nio.ByteOrder; import java.nio.ShortBuffer; -import java.util.Arrays; import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.VectorSpecies; @@ -143,7 +141,7 @@ public String toString() { } for (int i = 0; i < sample.length; i++) { - sample[i] = FloatConversions.bFloat16ToFloat32(b.get(i + (shape.first()/2))); + sample[i] = FloatConversions.bFloat16ToFloat32(b.get(i + (shape.first() / 2))); } StringBuffer sb2 = new StringBuffer(); @@ -154,7 +152,6 @@ public String toString() { } } - return "BFloat16BufferTensor{" + "name='" + name + '\'' + ", shape=" + shape + ",\n b=" diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Float16BufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Float16BufferTensor.java index 340344b..c72b4e7 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Float16BufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Float16BufferTensor.java @@ -16,7 +16,6 @@ package com.github.tjake.jlama.tensor; import com.github.tjake.jlama.safetensors.DType; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.UnsafeDirectByteBuffer; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java index 
9bec6cc..ee45564 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java @@ -16,7 +16,6 @@ package com.github.tjake.jlama.tensor; import com.github.tjake.jlama.safetensors.DType; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.DebugSupport; import com.github.tjake.jlama.util.UnsafeDirectByteBuffer; import com.google.common.base.Preconditions; @@ -24,9 +23,6 @@ import java.lang.foreign.MemorySegment; import java.nio.ByteOrder; import java.nio.FloatBuffer; -import java.util.Arrays; -import java.util.List; - import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorSpecies; import org.slf4j.Logger; @@ -138,8 +134,7 @@ public int getMemorySegmentOffset(int offset) { @Override public FloatVector getVector(VectorSpecies species, int... voffset) { int offset = getOffset(voffset); - return FloatVector.fromMemorySegment( - species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN); + return FloatVector.fromMemorySegment(species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN); } @Override @@ -166,9 +161,6 @@ public String toString() { } } - return "FloatBufferTensor{" + "name='" - + name + '\'' + " shape=" - + shape + ",\nb={" - + sb + "...}"; + return "FloatBufferTensor{" + "name='" + name + '\'' + " shape=" + shape + ",\nb={" + sb + "...}"; } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q4ByteBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q4ByteBufferTensor.java index dbe2c01..c77c6a6 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q4ByteBufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q4ByteBufferTensor.java @@ -16,7 +16,6 @@ package com.github.tjake.jlama.tensor; import com.github.tjake.jlama.safetensors.DType; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.UnsafeDirectByteBuffer; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q5ByteBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q5ByteBufferTensor.java index 2d937c6..75b1134 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q5ByteBufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q5ByteBufferTensor.java @@ -19,7 +19,6 @@ import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.safetensors.DType; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.UnsafeDirectByteBuffer; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q8ByteBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q8ByteBufferTensor.java index 4cd63bc..5b5ef3d 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q8ByteBufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q8ByteBufferTensor.java @@ -19,7 +19,6 @@ import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.safetensors.DType; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.UnsafeDirectByteBuffer; import 
com.google.common.base.Preconditions; import com.google.common.primitives.Ints; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java index e1336ae..addf1ee 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java @@ -96,7 +96,8 @@ public void batchDotProduct( Preconditions.checkArgument(a.shape().dim(0) == result.shape().dim(0), "BAD M"); // Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N"); // This check breaks for GQA - // Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K" + a.shape() + " " + b.shape() + " " + columnLength); + // Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K" + a.shape() + " " + b.shape() + " " + // + columnLength); int M = a.shape().dim(0); int N = rowChunkSize; // b.shape().dim(0); @@ -1817,10 +1818,9 @@ public AbstractTensor quantize(AbstractTensor t, DType qtype, int offset, int le public BFloat16BufferTensor quantizeBF16(FloatBufferTensor ft, final int offset, int length) { - //Need this till we have a proper quantization - https://github.com/pytorch/pytorch/blob/7c1fbc7fe9cb8ddd5c913b4b3a9e94d00cb055ee/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h#L47 - if (true) - return new BFloat16BufferTensor(ft); + // Need this till we have a proper quantization + https: // github.com/pytorch/pytorch/blob/7c1fbc7fe9cb8ddd5c913b4b3a9e94d00cb055ee/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h#L47 + if (true) return new BFloat16BufferTensor(ft); // Up to caller to release BFloat16BufferTensor qft = (BFloat16BufferTensor) TensorCache.instance.get(DType.BF16, ft.shape()); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java index 53612a5..57163af 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java @@ -26,7 +26,7 @@ public interface TensorOperations { ThreadLocal scratch = ThreadLocal.withInitial(() -> new FloatBufferTensor(TensorShape.one)); String name(); - + default int parallelSplitSize() { return 1; } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/DebugSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/DebugSupport.java index 5653301..f997f5a 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/DebugSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/DebugSupport.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.util; import com.github.tjake.jlama.tensor.AbstractTensor; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java index 3fa3d45..f88aba8 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java @@ -20,8 +20,6 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.MapperFeature; import com.fasterxml.jackson.databind.ObjectMapper; -import com.hubspot.jinjava.objects.serialization.PyishPrettyPrinter; - import java.io.IOException; /** @@ -43,7 +41,6 @@ public static String toJson(Object o) { } } - public static class JlamaPrettyPrinter extends DefaultPrettyPrinter { public static final JlamaPrettyPrinter INSTANCE = new JlamaPrettyPrinter(); diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java index ecc68b9..97f730c 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java @@ -24,9 +24,7 @@ import com.github.tjake.jlama.tensor.operations.util.MemorySegmentSupport; import com.github.tjake.jlama.util.MachineSpec; import com.github.tjake.jlama.util.RuntimeSupport; - import java.lang.foreign.MemorySegment; -import java.lang.foreign.ValueLayout; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,8 +32,7 @@ public class NativeTensorOperations implements TensorOperations { private static final Logger logger = LoggerFactory.getLogger(NativeTensorOperations.class); static { - if (!JarSupport.maybeLoadLibrary()) - System.loadLibrary("jlama"); + if (!JarSupport.maybeLoadLibrary()) System.loadLibrary("jlama"); } public static final int HAS_F16C = NativeSimd.HAS_F16C(); @@ -239,7 +236,9 @@ public void dotProductBatchChunk( MemorySegment[] tmp = MemorySegmentSupport.setupBatch( i -> r[i].getMemorySegment(), i -> b[i].getMemorySegment(), - i -> b[i] instanceof Q4ByteBufferTensor ? ((Q4ByteBufferTensor) b[i]).getBlockF().getMemorySegment() : MemorySegment.NULL, + i -> b[i] instanceof Q4ByteBufferTensor + ? ((Q4ByteBufferTensor) b[i]).getBlockF().getMemorySegment() + : MemorySegment.NULL, r.length); MemorySegment ra = tmp[0]; MemorySegment rb = tmp[1]; diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java index 50a0440..7781c6e 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java @@ -1,10 +1,23 @@ -// Generated by jextract - +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.tensor.operations.cnative; import java.lang.foreign.*; -public class NativeSimd { +public class NativeSimd { /** * {@snippet : @@ -12,7 +25,7 @@ public class NativeSimd { * } */ public static int HAS_F16C() { - return (int)2L; + return (int) 2L; } /** * {@snippet : @@ -20,7 +33,7 @@ public static int HAS_F16C() { * } */ public static int HAS_AVX2() { - return (int)4L; + return (int) 4L; } /** * {@snippet : @@ -28,7 +41,7 @@ public static int HAS_AVX2() { * } */ public static int IS_M_SERIES_MAC() { - return (int)8L; + return (int) 8L; } /** * {@snippet : @@ -36,7 +49,7 @@ public static int IS_M_SERIES_MAC() { * } */ public static int Q8_BLOCK_SIZE() { - return (int)32L; + return (int) 32L; } /** * {@snippet : @@ -44,7 +57,7 @@ public static int Q8_BLOCK_SIZE() { * } */ public static int Q4_BLOCK_SIZE() { - return (int)32L; + return (int) 32L; } /** @@ -52,8 +65,26 @@ public static int Q4_BLOCK_SIZE() { * void gemm_q8_q4(int flags, float* af, char* a, int aoffset, float* bf, char* b, int boffset, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldaf, int ldb, int ldbf, int ldc); * } */ - public static void gemm_q8_q4(int flags, MemorySegment af, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldaf, int ldb, int ldbf, int ldc) { - throw new UnsupportedOperationException("Not implemented for this JDK version"); + public static void gemm_q8_q4( + int flags, + MemorySegment af, + MemorySegment a, + int aoffset, + MemorySegment bf, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldaf, + int ldb, + int ldbf, + int ldc) { + throw new UnsupportedOperationException("Not implemented for this JDK version"); } /** @@ -61,7 +92,26 @@ public static void gemm_q8_q4(int flags, MemorySegment af, MemorySegment a, int * void gemm_q8_q4_batch(int flags, int batch_num, float* af, char* a, int aoffset, float** bf, char** b, int boffset, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldaf, int ldb, int ldbf, int ldc); * } */ - public static void gemm_q8_q4_batch(int flags, int batch_num, MemorySegment af, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldaf, int ldb, int ldbf, int ldc) { + public static void gemm_q8_q4_batch( + int flags, + int batch_num, + MemorySegment af, + MemorySegment a, + int aoffset, + MemorySegment bf, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldaf, + int ldb, + int ldbf, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } @@ -70,7 +120,21 @@ public static void gemm_q8_q4_batch(int flags, int batch_num, MemorySegment af, * void gemm_f32(int flags, float* a, int aoffset, float* b, int boffset, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void 
gemm_f32(int flags, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_f32( + int flags, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } @@ -79,7 +143,22 @@ public static void gemm_f32(int flags, MemorySegment a, int aoffset, MemorySegme * void gemm_f32_batch(int flags, int batch_num, float* a, int aoffset, float** b, int boffset, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_f32_batch(int flags, int batch_num, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_f32_batch( + int flags, + int batch_num, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } @@ -88,7 +167,23 @@ public static void gemm_f32_batch(int flags, int batch_num, MemorySegment a, int * void gemm_f32_q4(int flags, float* a, int aoffset, float* bf, char* b, int boffset, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldbf, int ldc); * } */ - public static void gemm_f32_q4(int flags, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldbf, int ldc) { + public static void gemm_f32_q4( + int flags, + MemorySegment a, + int aoffset, + MemorySegment bf, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldbf, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } @@ -97,7 +192,24 @@ public static void gemm_f32_q4(int flags, MemorySegment a, int aoffset, MemorySe * void gemm_f32_q4_batch(int flags, int batch_num, float* a, int aoffset, float** bf, char** b, int boffset, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldbf, int ldc); * } */ - public static void gemm_f32_q4_batch(int flags, int batch_num, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldbf, int ldc) { + public static void gemm_f32_q4_batch( + int flags, + int batch_num, + MemorySegment a, + int aoffset, + MemorySegment bf, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldbf, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } @@ -106,7 +218,22 @@ public static void gemm_f32_q4_batch(int flags, int batch_num, MemorySegment a, * void gemm_bf16(int flags, short* a, int aoffset, short* b, int boffset, short* cr, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_bf16(int flags, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment cr, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_bf16( + int 
flags, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment cr, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } @@ -115,7 +242,23 @@ public static void gemm_bf16(int flags, MemorySegment a, int aoffset, MemorySegm * void gemm_bf16_batch(int flags, int batch_num, short* a, int aoffset, short** b, int boffset, short** cr, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_bf16_batch(int flags, int batch_num, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment cr, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_bf16_batch( + int flags, + int batch_num, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment cr, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } @@ -124,7 +267,22 @@ public static void gemm_bf16_batch(int flags, int batch_num, MemorySegment a, in * void gemm_f32_bf16(int flags, float* a, int aoffset, short* b, int boffset, short* cr, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_f32_bf16(int flags, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment cr, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_f32_bf16( + int flags, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment cr, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } @@ -133,9 +291,23 @@ public static void gemm_f32_bf16(int flags, MemorySegment a, int aoffset, Memory * void gemm_f32_bf16_batch(int flags, int batch_num, float* a, int aoffset, short** b, int boffset, short** cr, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_f32_bf16_batch(int flags, int batch_num, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment cr, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_f32_bf16_batch( + int flags, + int batch_num, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment cr, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { throw new UnsupportedOperationException("Not implemented for this JDK version"); } } - - diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/util/MemorySegmentSupport.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/util/MemorySegmentSupport.java index 0ac0883..e0358cd 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/util/MemorySegmentSupport.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/util/MemorySegmentSupport.java @@ -1,10 +1,30 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.tensor.operations.util; import java.lang.foreign.*; import java.util.function.Function; public class MemorySegmentSupport { - public static MemorySegment[] setupBatch(Function r, Function b, Function c, int limit) { - throw new UnsupportedOperationException("Not implemented for this JDK version: " + Runtime.version().toString()); + public static MemorySegment[] setupBatch( + Function r, + Function b, + Function c, + int limit) { + throw new UnsupportedOperationException( + "Not implemented for this JDK version: " + Runtime.version().toString()); } } diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java index 402c311..76efb5f 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java @@ -114,7 +114,8 @@ public Generator.Response generate( int promptLength = encoded.length; if (useEOS) { - promptTokens[promptTokens.length - 1] = model.getConfig().eosTokens.getLast(); // Add EOS + promptTokens[promptTokens.length - 1] = + model.getConfig().eosTokens.getLast(); // Add EOS promptLength++; } diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java index 97e3455..e9f18f7 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java @@ -21,7 +21,6 @@ import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.tensor.KvBufferCache; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.Pair; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.Uninterruptibles; @@ -33,7 +32,6 @@ import io.grpc.stub.StreamObserver; import java.io.*; import java.lang.foreign.MemorySegment; -import java.lang.foreign.ValueLayout; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Optional; diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java index b00afab..b617223 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java @@ -1,24 +1,37 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.net.openai; import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.model.functions.Generator; import com.github.tjake.jlama.net.openai.model.*; - import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import jakarta.validation.Valid; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.http.HttpStatus; -import org.springframework.http.ResponseEntity; -import org.springframework.validation.annotation.Validated; -import org.springframework.web.bind.annotation.*; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; - import java.io.IOException; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @RestController @Validated @@ -38,13 +51,10 @@ public class OpenAIChatService { @RequestMapping( method = RequestMethod.POST, value = "/chat/completions", - produces = { "application/json", "text/event-stream" }, - consumes = { "application/json" } - ) + produces = {"application/json", "text/event-stream"}, + consumes = {"application/json"}) Object createChatCompletion( - @RequestHeader Map headers, - @Valid @RequestBody CreateChatCompletionRequest request - ) { + @RequestHeader Map headers, @Valid @RequestBody CreateChatCompletionRequest request) { List messages = request.getMessages(); @@ -69,57 +79,62 @@ Object createChatCompletion( for (ChatCompletionRequestMessage m : messages) { if (m.getActualInstance() instanceof ChatCompletionRequestUserMessage) { - ChatCompletionRequestUserMessageContent content = m.getChatCompletionRequestUserMessage().getContent(); + ChatCompletionRequestUserMessageContent content = + m.getChatCompletionRequestUserMessage().getContent(); if (content.getActualInstance() instanceof String) { builder.addUserMessage(content.getString()); } else { - for (ChatCompletionRequestMessageContentPart p : content.getListChatCompletionRequestMessageContentPart()) { + for (ChatCompletionRequestMessageContentPart p : + content.getListChatCompletionRequestMessageContentPart()) { if (p.getActualInstance() instanceof ChatCompletionRequestMessageContentPartText) { - builder.addUserMessage(p.getChatCompletionRequestMessageContentPartText().getText()); + builder.addUserMessage(p.getChatCompletionRequestMessageContentPartText() + .getText()); } else { - //We don't support other types of content... yet... + // We don't support other types of content... yet... 
return new ResponseEntity<>(HttpStatus.NOT_IMPLEMENTED); } } } } else if (m.getActualInstance() instanceof ChatCompletionRequestSystemMessage) { - builder.addSystemMessage(m.getChatCompletionRequestSystemMessage().getContent()); + builder.addSystemMessage( + m.getChatCompletionRequestSystemMessage().getContent()); } else if (m.getActualInstance() instanceof ChatCompletionRequestAssistantMessage) { - builder.addAssistantMessage(m.getChatCompletionRequestAssistantMessage().getContent()); + builder.addAssistantMessage( + m.getChatCompletionRequestAssistantMessage().getContent()); } else { return new ResponseEntity<>(HttpStatus.NOT_IMPLEMENTED); } } - float temperature = request.getTemperature() == null ? 0.3f : request.getTemperature().floatValue(); + float temperature = request.getTemperature() == null + ? 0.3f + : request.getTemperature().floatValue(); int maxTokens = request.getMaxTokens() == null ? 1024 : request.getMaxTokens(); AtomicInteger index = new AtomicInteger(0); if (request.getStream() != null && request.getStream()) { SseEmitter emitter = new SseEmitter(); - CompletableFuture.supplyAsync( () -> model.generate(sessionId, builder.build(), temperature, maxTokens, false, - (t, f) -> { - try { - emitter.send( - new CreateChatCompletionStreamResponse() + CompletableFuture.supplyAsync( + () -> model.generate(sessionId, builder.build(), temperature, maxTokens, false, (t, f) -> { + try { + emitter.send(new CreateChatCompletionStreamResponse() .id(sessionId.toString()) .choices(List.of(new CreateChatCompletionStreamResponseChoicesInner() - .index(index.getAndIncrement()) - .delta(new ChatCompletionStreamResponseDelta() - .content(t)))) - ); - } catch (IOException e) { - emitter.completeWithError(e); - } - })) + .index(index.getAndIncrement()) + .delta(new ChatCompletionStreamResponseDelta().content(t))))); + } catch (IOException e) { + emitter.completeWithError(e); + } + })) .handle((r, ex) -> { try { emitter.send(new CreateChatCompletionStreamResponse() .id(sessionId.toString()) .choices(List.of(new CreateChatCompletionStreamResponseChoicesInner() - .finishReason(CreateChatCompletionStreamResponseChoicesInner.FinishReasonEnum.STOP))) - ); + .finishReason( + CreateChatCompletionStreamResponseChoicesInner.FinishReasonEnum + .STOP)))); emitter.complete(); } catch (IOException e) { @@ -130,10 +145,9 @@ Object createChatCompletion( }); return emitter; - } - else - { - Generator.Response r = model.generate(sessionId, builder.build(), temperature, maxTokens, false, (s, f) -> {}); + } else { + Generator.Response r = + model.generate(sessionId, builder.build(), temperature, maxTokens, false, (s, f) -> {}); CreateChatCompletionResponse out = new CreateChatCompletionResponse() .id(sessionId.toString()) @@ -144,5 +158,4 @@ Object createChatCompletion( return new ResponseEntity<>(out, HttpStatus.OK); } } - -} \ No newline at end of file +} diff --git a/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/ChatApiTest.java b/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/ChatApiTest.java index 66f0f6d..ea60b14 100644 --- a/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/ChatApiTest.java +++ b/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/ChatApiTest.java @@ -1,26 +1,35 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.net.openai; - import io.github.stefanbratanov.jvm.openai.*; import org.json.JSONException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.skyscreamer.jsonassert.JSONAssert; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.web.client.TestRestTemplate; import org.springframework.boot.test.web.server.LocalServerPort; -import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.ResponseEntity; -import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.junit.jupiter.SpringExtension; -import java.math.BigDecimal; -import java.util.List; - @ExtendWith(SpringExtension.class) -@SpringBootTest(classes = MockedOpenAIServer.class, webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, useMainMethod = SpringBootTest.UseMainMethod.ALWAYS) +@SpringBootTest( + classes = MockedOpenAIServer.class, + webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, + useMainMethod = SpringBootTest.UseMainMethod.ALWAYS) public class ChatApiTest { @LocalServerPort private int port; @@ -38,18 +47,17 @@ public void testChatCompletion() throws JSONException { ChatClient client = openAI.chatClient(); - CreateChatCompletionRequest request = CreateChatCompletionRequest.newBuilder() - .model(OpenAIModel.GPT_3_5_TURBO) - .stream(false) - .temperature(0.0f) - .message(ChatMessage.userMessage("Who won the world series in 2020?")) - .build(); + CreateChatCompletionRequest request = + CreateChatCompletionRequest.newBuilder().model(OpenAIModel.GPT_3_5_TURBO).stream(false) + .temperature(0.0f) + .message(ChatMessage.userMessage("Who won the world series in 2020?")) + .build(); ChatCompletion response = client.createChatCompletion(request); System.err.println(response); - //JSONAssert.assertEquals(null, response.getBody(), false); + // JSONAssert.assertEquals(null, response.getBody(), false); } @Test @@ -61,17 +69,15 @@ public void testStreamingChatCompletion() throws JSONException { ChatClient client = openAI.chatClient(); - CreateChatCompletionRequest request = CreateChatCompletionRequest.newBuilder() - .model(OpenAIModel.GPT_3_5_TURBO) - .stream(true) - .temperature(0.0f) - .message(ChatMessage.userMessage("Who won the world series in 2020?")) - .build(); + CreateChatCompletionRequest request = + CreateChatCompletionRequest.newBuilder().model(OpenAIModel.GPT_3_5_TURBO).stream(true) + .temperature(0.0f) + .message(ChatMessage.userMessage("Who won the world series in 2020?")) + .build(); client.streamChatCompletion(request).forEach(System.err::println); - - //JSONAssert.assertEquals(null, response.getBody(), false); + // JSONAssert.assertEquals(null, response.getBody(), false); } private String createURLWithPort(String uri) { diff --git a/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/MockedOpenAIServer.java b/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/MockedOpenAIServer.java index 3cd283d..34f2bb1 100644 --- 
a/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/MockedOpenAIServer.java +++ b/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/MockedOpenAIServer.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.net.openai; import com.github.tjake.jlama.model.AbstractModel; @@ -5,20 +20,20 @@ import com.github.tjake.jlama.net.JlamaServiceTest; import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.SafeTensorSupport; +import java.io.File; +import java.io.IOException; import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import java.io.File; -import java.io.IOException; - @SpringBootApplication @SpringBootConfiguration @Configuration public class MockedOpenAIServer { - private final JlamaServiceTest.MockConfig modelConfig = new JlamaServiceTest.MockConfig(128, 4096, 8192, 16, 12, 1e5f); + private final JlamaServiceTest.MockConfig modelConfig = + new JlamaServiceTest.MockConfig(128, 4096, 8192, 16, 12, 1e5f); public static void main(String[] args) { SpringApplication.run(MockedOpenAIServer.class, args); diff --git a/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/OpenAIServiceTests.java b/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/OpenAIServiceTests.java index f17c417..520b924 100644 --- a/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/OpenAIServiceTests.java +++ b/jlama-net/src/test/java/com/github/tjake/jlama/net/openai/OpenAIServiceTests.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.net.openai; import org.junit.jupiter.api.Test; @@ -10,6 +25,5 @@ public class OpenAIServiceTests { @Test - public void contextLoads() { - } -} \ No newline at end of file + public void contextLoads() {} +} diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java index e9bc40b..b2fd762 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java @@ -32,9 +32,7 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.concurrent.ThreadLocalRandom; - import org.junit.Assert; import org.junit.Assume; import org.junit.Test; @@ -374,17 +372,25 @@ public void testPromptSupportWithTools() { builder.addGenerationPrompt(true); Tool t = Tool.from(Function.builder() - .name("get_temperature") - .description("Simulates getting the current temperature at a location.") - .addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true) - .addParameter("unit", "string", "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", true) - .build()); + .name("get_temperature") + .description("Simulates getting the current temperature at a location.") + .addParameter( + "location", + "string", + "The location to get the temperature for, in the format \"City, Country\".", + true) + .addParameter( + "unit", + "string", + "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", + true) + .build()); builder.addTools(t); - //builder.addToolCall(new ToolCall("get_temperature", Map.of("location", "paris, france", "unit", "celsius"))); + // builder.addToolCall(new ToolCall("get_temperature", Map.of("location", "paris, france", "unit", "celsius"))); - //builder.addToolResult(Result.from("get_temperature", "", Map.of("temperature", 25.0, "unit", "celsius"))); + // builder.addToolResult(Result.from("get_temperature", "", Map.of("temperature", 25.0, "unit", "celsius"))); String prompt = builder.build(); @@ -392,20 +398,18 @@ public void testPromptSupportWithTools() { long[] encoded = tokenizer.encode(prompt); long[] expected = new long[] { - 128006, 9125, 128007, 271, 13013, 25, 6125, 27993, 198, 38766, 1303, 33025, 2696, 25, 6790, - 220, 2366, 18, 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 271, 2675, 2744, 6013, - 439, 264, 55066, 128009, 128006, 882, 128007, 271, 22818, 279, 2768, 5865, 11, 4587, 6013, - 449, 264, 4823, 369, 264, 734, 1650, 449, 1202, 6300, 6105, 430, 1888, 11503, 279, 2728, 10137, - 382, 66454, 304, 279, 3645, 5324, 609, 794, 734, 836, 11, 330, 14105, 794, 11240, 315, 5811, - 836, 323, 1202, 907, 7966, 5519, 539, 1005, 7482, 382, 5018, 1337, 794, 330, 1723, 498, 330, - 1723, 794, 5324, 609, 794, 330, 456, 54625, 498, 330, 4789, 794, 330, 14354, 24031, 3794, 279, - 1510, 9499, 520, 264, 3813, 10684, 330, 14105, 794, 5324, 1337, 794, 330, 1735, 498, 330, 13495, - 794, 5324, 2588, 794, 5324, 1337, 794, 330, 928, 498, 330, 4789, 794, 330, 791, 3813, 311, 636, - 279, 9499, 369, 11, 304, 279, 3645, 330, 13020, 11, 14438, 66820, 2186, 330, 3928, 794, 5324, - 1337, 794, 330, 928, 498, 330, 4789, 794, 330, 791, 5089, 311, 471, 279, 9499, 304, 320, 68, 1326, - 2637, 330, 66, 41347, 498, 330, 69, 49010, 1865, 9388, 2186, 330, 6413, 794, 4482, 2588, 498, 330, - 3928, 1365, 3500, 633, 3923, 374, 
279, 9282, 304, 41958, 1314, 1457, 30, 128009, 128006, 78191, - 128007, 271 + 128006, 9125, 128007, 271, 13013, 25, 6125, 27993, 198, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, + 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 271, 2675, 2744, 6013, 439, 264, 55066, 128009, + 128006, 882, 128007, 271, 22818, 279, 2768, 5865, 11, 4587, 6013, 449, 264, 4823, 369, 264, 734, 1650, 449, + 1202, 6300, 6105, 430, 1888, 11503, 279, 2728, 10137, 382, 66454, 304, 279, 3645, 5324, 609, 794, 734, 836, + 11, 330, 14105, 794, 11240, 315, 5811, 836, 323, 1202, 907, 7966, 5519, 539, 1005, 7482, 382, 5018, 1337, + 794, 330, 1723, 498, 330, 1723, 794, 5324, 609, 794, 330, 456, 54625, 498, 330, 4789, 794, 330, 14354, + 24031, 3794, 279, 1510, 9499, 520, 264, 3813, 10684, 330, 14105, 794, 5324, 1337, 794, 330, 1735, 498, 330, + 13495, 794, 5324, 2588, 794, 5324, 1337, 794, 330, 928, 498, 330, 4789, 794, 330, 791, 3813, 311, 636, 279, + 9499, 369, 11, 304, 279, 3645, 330, 13020, 11, 14438, 66820, 2186, 330, 3928, 794, 5324, 1337, 794, 330, + 928, 498, 330, 4789, 794, 330, 791, 5089, 311, 471, 279, 9499, 304, 320, 68, 1326, 2637, 330, 66, 41347, + 498, 330, 69, 49010, 1865, 9388, 2186, 330, 6413, 794, 4482, 2588, 498, 330, 3928, 1365, 3500, 633, 3923, + 374, 279, 9282, 304, 41958, 1314, 1457, 30, 128009, 128006, 78191, 128007, 271 }; String out = tokenizer.decode(encoded); @@ -416,59 +420,59 @@ public void testPromptSupportWithTools() { @Test public void testMistralTools() { - String modelPrefix = "../models/Mistral-7B-Instruct-v0.3"; - Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); + String modelPrefix = "../models/Mistral-7B-Instruct-v0.3"; + Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); - Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); - PromptSupport.Builder builder = tokenizer.promptSupport().get().builder(); - builder.addSystemMessage("You always respond as a pirate"); - builder.addUserMessage("What is the weather in paris right now?"); - builder.addGenerationPrompt(true); + Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); + PromptSupport.Builder builder = tokenizer.promptSupport().get().builder(); + builder.addSystemMessage("You always respond as a pirate"); + builder.addUserMessage("What is the weather in paris right now?"); + builder.addGenerationPrompt(true); - Tool t = Tool.from(Function.builder() - .name("get_current_temperature") - .description("Simulates getting the current temperature at a location.") - .addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true) - .addParameter("unit", "string", "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", true) - .build()); + Tool t = Tool.from(Function.builder() + .name("get_current_temperature") + .description("Simulates getting the current temperature at a location.") + .addParameter( + "location", + "string", + "The location to get the temperature for, in the format \"City, Country\".", + true) + .addParameter( + "unit", + "string", + "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", + true) + .build()); - builder.addTools(t); + builder.addTools(t); String prompt = builder.build(); long[] encoded = tokenizer.encode(prompt); - long[] official = new long[]{ - 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, - 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, - 29475, 17329, 1316, 1113, 7286, 2032, 1113, 8322, 27653, 3487, - 1040, 2636, 8409, 1206, 1032, 
5491, 9959, 1113, 12206, 2032, - 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, - 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, - 2032, 1113, 1782, 5491, 1066, 1393, 1040, 8409, 1122, 29493, - 1065, 1040, 5800, 12547, 22781, 29493, 13776, 5651, 1379, 1649, - 1113, 6074, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, - 7286, 2032, 1113, 1782, 5796, 1066, 1372, 1040, 8409, 1065, - 1093, 29474, 29491, 29489, 2831, 12547, 29485, 1958, 3938, 23102, - 12547, 29490, 19425, 13075, 29524, 4913, 29507, 11549, 1113, 11661, - 2032, 8135, 3501, 1316, 1113, 6074, 3010, 1743, 10925, 7, - 3, 1763, 2511, 10189, 1158, 1032, 18136, 1148, 781, 781, - 3963, 1117, 1040, 8854, 1065, 1708, 1046, 1871, 1823, 29572, - 4 + long[] official = new long[] { + 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, + 29498, 29475, 17329, 1316, 1113, 7286, 2032, 1113, 8322, 27653, 3487, 1040, 2636, 8409, 1206, 1032, 5491, + 9959, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, + 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 5491, 1066, 1393, 1040, 8409, 1122, 29493, 1065, + 1040, 5800, 12547, 22781, 29493, 13776, 5651, 1379, 1649, 1113, 6074, 2032, 10598, 1891, 2032, 1113, 2195, + 1316, 1113, 7286, 2032, 1113, 1782, 5796, 1066, 1372, 1040, 8409, 1065, 1093, 29474, 29491, 29489, 2831, + 12547, 29485, 1958, 3938, 23102, 12547, 29490, 19425, 13075, 29524, 4913, 29507, 11549, 1113, 11661, 2032, + 8135, 3501, 1316, 1113, 6074, 3010, 1743, 10925, 7, 3, 1763, 2511, 10189, 1158, 1032, 18136, 1148, 781, 781, + 3963, 1117, 1040, 8854, 1065, 1708, 1046, 1871, 1823, 29572, 4 }; for (int i = 0; i < official.length; i++) { if (official[i] != encoded[i]) { System.out.println(i + " " + official[i] + " " + encoded[i]); - } - else - System.out.println(i + " " + official[i]); + } else System.out.println(i + " " + official[i]); } Assert.assertArrayEquals(official, encoded); - String expected = "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"name\": \"get_current_temperature\", \"description\": \"Simulates getting the current temperature at a location.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\".\"}, \"unit\": {\"type\": \"string\", \"description\": \"The unit to return the temperature in (e.g., \\\"celsius\\\", \\\"fahrenheit\\\").\"}}, \"required\": [\"location\", \"unit\"]}}}][/AVAILABLE_TOOLS][INST] You always respond as a pirate\n" + - "\n" + - "What is the weather in paris right now?[/INST]"; + String expected = + "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"name\": \"get_current_temperature\", \"description\": \"Simulates getting the current temperature at a location.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\".\"}, \"unit\": {\"type\": \"string\", \"description\": \"The unit to return the temperature in (e.g., \\\"celsius\\\", \\\"fahrenheit\\\").\"}}, \"required\": [\"location\", \"unit\"]}}}][/AVAILABLE_TOOLS][INST] You always respond as a pirate\n" + + "\n" + + "What is the weather in paris right now?[/INST]"; Assert.assertEquals(expected, prompt); } diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java 
b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index 1a6da05..fef5a34 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -114,17 +114,24 @@ public void LlamaRun() throws Exception { builder.addGenerationPrompt(true); Tool t = Tool.from(Function.builder() - .name("get_current_temperature") - .description("Simulates getting the current temperature at a location.") - .addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true) - .addParameter("unit", "string", "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", true) - .build()); + .name("get_current_temperature") + .description("Simulates getting the current temperature at a location.") + .addParameter( + "location", + "string", + "The location to get the temperature for, in the format \"City, Country\".", + true) + .addParameter( + "unit", + "string", + "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", + true) + .build()); builder.addTools(t); logger.info("First prompt \n{}", builder.build()); - Generator.Response r = - model.generate(UUID.randomUUID(), builder.build(), 0.0f, 1024, false, (l,f) ->{}); + Generator.Response r = model.generate(UUID.randomUUID(), builder.build(), 0.0f, 1024, false, (l, f) -> {}); logger.info("Response: {}", r.text); if (r.finishReason == Generator.FinishReason.STOP_TOKEN) { @@ -144,9 +151,10 @@ public void LlamaRun() throws Exception { builder.addToolCall(f); builder.addToolResult(Result.from(f.getName(), null, 20f)); logger.info("Second prompt {}", builder.build()); - Generator.Response r2 = model.generate(UUID.randomUUID(), builder.build(), 0.0f, 1024, false, (l,p) ->{}); + Generator.Response r2 = + model.generate(UUID.randomUUID(), builder.build(), 0.0f, 1024, false, (l, p) -> {}); - Assert.assertTrue( r2.text, r2.text.contains("20")); + Assert.assertTrue(r2.text, r2.text.contains("20")); logger.info("Response: {}", r2.text); } else { Assert.fail(); @@ -190,15 +198,24 @@ public void MistralRun() throws Exception { Tool t = Tool.from(Function.builder() .name("get_current_temperature") .description("Simulates getting the current temperature at a location.") - .addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true) - .addParameter("unit", "string", "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", true) + .addParameter( + "location", + "string", + "The location to get the temperature for, in the format \"City, Country\".", + true) + .addParameter( + "unit", + "string", + "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", + true) .build()); builder.addTools(t); logger.info("First prompt \n{}", builder.build()); - Generator.Response r = model.generate(UUID.randomUUID(), builder.build(), 0.0f, 1024, false, makeOutHandler()); + Generator.Response r = + model.generate(UUID.randomUUID(), builder.build(), 0.0f, 1024, false, makeOutHandler()); logger.info("Response: {}", r.text); } @@ -219,11 +236,8 @@ public void MixtralRun() throws Exception { + "or sometimes administered intravenously. They are not effective against viral infections, and using them inappropriately can lead to antibiotic resistance. 
Explain the above in one sentence:"; String prompt = "Tell me a joke."; - String p = model.promptSupport() - .get() - .builder() - .addUserMessage(prompt) - .build(); + String p = + model.promptSupport().get().builder().addUserMessage(prompt).build(); model.generate(UUID.randomUUID(), p, 0.7f, 256, true, makeOutHandler()); } @@ -239,11 +253,8 @@ public void GemmaRun() throws Exception { GemmaConfig c = om.readValue(new File(modelPrefix + "/config.json"), GemmaConfig.class); GemmaModel model = new GemmaModel(c, weights, tokenizer, DType.F32, DType.BF16, Optional.empty()); String prompt = "Tell me a joke."; - String p = model.promptSupport() - .get() - .builder() - .addUserMessage(prompt) - .build(); + String p = + model.promptSupport().get().builder().addUserMessage(prompt).build(); model.generate(UUID.randomUUID(), p, 0.3f, 256, false, makeOutHandler()); } } @@ -355,14 +366,14 @@ private BiConsumer makeOutHandler() { PrintWriter out; BiConsumer outCallback; - AtomicInteger i = new AtomicInteger(0); - StringBuilder b = new StringBuilder(); - out = new PrintWriter(System.out); - outCallback = (w, t) -> { - b.append(w); - out.println(String.format("%d: %s [took %.2fms])", i.getAndIncrement(), b, t)); - out.flush(); - }; + AtomicInteger i = new AtomicInteger(0); + StringBuilder b = new StringBuilder(); + out = new PrintWriter(System.out); + outCallback = (w, t) -> { + b.append(w); + out.println(String.format("%d: %s [took %.2fms])", i.getAndIncrement(), b, t)); + out.flush(); + }; return outCallback; }