From 12f7c1449948a2d83561e64c5bd756a3fd23473c Mon Sep 17 00:00:00 2001
From: Jake Luciani
Date: Wed, 16 Oct 2024 22:56:34 -0400
Subject: [PATCH] Set the max tokens based on the model and fix temp for now

---
 .../jlama/model/functions/Generator.java      |  3 ++
 .../operations/PanamaTensorOperations.java    | 37 ++++++++++++++++++-
 .../jlama/net/openai/OpenAIChatService.java   | 13 +++----
 3 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java
index f785015..bf9efeb 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java
@@ -15,6 +15,7 @@
  */
 package com.github.tjake.jlama.model.functions;
 
+import com.github.tjake.jlama.safetensors.Config;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
 import com.github.tjake.jlama.safetensors.prompt.ToolCall;
@@ -163,6 +164,8 @@ default Map<String, Float> classify(String input, PoolingType poolingType) {
         throw new UnsupportedOperationException("Classification not supported by this model");
     }
 
+    Config getConfig();
+
     Tokenizer getTokenizer();
 
     Optional<PromptSupport> promptSupport();
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
index 2faf5ce..c265724 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
@@ -2170,8 +2170,18 @@ public void accumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset,
                             throw new UnsupportedOperationException();
                     }
                     break;
+                case BF16:
+                    switch (vectorType) {
+                        case AVX_512:
+                        case AVX_256:
+                            accumulateF32BF16_256((FloatBufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
+                            break;
+                        default:
+                            throw new UnsupportedOperationException();
+                    }
+                    break;
                 default:
-                    throw new UnsupportedOperationException();
+                    throw new UnsupportedOperationException("F32 => " + b.dType());
             }
             break;
         case BF16:
@@ -2244,6 +2254,31 @@ void accumulateF32Q4_256(FloatBufferTensor a, Q4ByteBufferTensor b, int offset,
         }
     }
 
+    void accumulateF32BF16_256(FloatBufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
+        int upperBound = offset + FloatVector.SPECIES_256.loopBound(limit);
+
+        int i = offset;
+        for (; i < upperBound; i += FloatVector.SPECIES_256.length()) {
+
+            // F32
+            var af = a.getVector(FloatVector.SPECIES_256, 0, i);
+
+            // Convert BF16 to F32
+            var bf = b.getVector(ShortVector.SPECIES_128, 0, i)
+                .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
+                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
+                .reinterpretAsFloats();
+
+            var res = af.add(bf);
+            a.intoTensor(res, 0, i);
+        }
+
+        // tail
+        for (; i < offset + limit; i++) {
+            a.set(a.get(0, i) + b.get(0, i), 0, i);
+        }
+    }
+
     void accumulateBF16_256(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
         int upperBound = offset + FloatVector.SPECIES_256.loopBound(limit);
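Note on the BF16 path above: bfloat16 is the upper 16 bits of an IEEE-754 float32, so widening a BF16 lane to F32 amounts to a 16-bit left shift. That is what the vector pipeline does: load 8 shorts (ShortVector.SPECIES_128), widen them to 8 ints (S2I into IntVector.SPECIES_256), shift left, and reinterpret the lanes as floats. A minimal scalar sketch of the same accumulation, assuming plain Java arrays (the class and helper names here are hypothetical, not part of the patch):

    public final class Bf16AccumulateSketch {
        // Widen one bfloat16 (stored in a short) to float: bf16 is the top
        // 16 bits of a float32, so shift left by 16 and reinterpret the bits.
        static float bf16ToFloat(short bits) {
            return Float.intBitsToFloat((bits & 0xFFFF) << 16);
        }

        // Scalar equivalent of accumulateF32BF16_256: a[i] += widen(b[i]).
        static void accumulateF32BF16(float[] a, short[] b, int offset, int limit) {
            for (int i = offset; i < offset + limit; i++) {
                a[i] += bf16ToFloat(b[i]);
            }
        }

        public static void main(String[] args) {
            float[] a = { 1.0f, 2.0f };
            short[] b = { (short) 0x3F80, (short) 0x4000 }; // 1.0f and 2.0f in bf16
            accumulateF32BF16(a, b, 0, 2);
            System.out.println(a[0] + " " + a[1]); // prints 2.0 4.0
        }
    }

The vectorized version does this same arithmetic eight lanes at a time up to loopBound(limit), then falls back to the scalar form for the tail.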
diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java
index 9343d58..a669b42 100644
--- a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java
+++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java
@@ -101,9 +101,10 @@ Object createChatCompletion(@RequestHeader Map<String, String> headers, @Valid @
             }
         }
 
-        float temperature = request.getTemperature() == null ? 0.3f : request.getTemperature().floatValue();
-        int maxTokens = request.getMaxTokens() == null ? 1024 : request.getMaxTokens();
+        float temperature = 0.3f;
+        int maxTokens = request.getMaxTokens() == null ? model.getConfig().contextLength : request.getMaxTokens();
 
+        logger.info("Generating completion for session {} with temperature {} and max tokens {}", sessionId, temperature, maxTokens);
         AtomicInteger index = new AtomicInteger(0);
         if (request.getStream() != null && request.getStream()) {
             SseEmitter emitter = new SseEmitter(-1L);
@@ -139,11 +140,9 @@ Object createChatCompletion(@RequestHeader Map<String, String> headers, @Valid @
 
                     emitter.complete();
 
-                    logger.info(
-                        "Stats: {} ms/tok (prompt), {} ms/tok (gen)",
-                        Math.round(r.promptTimeMs / (double) r.promptTokens),
-                        Math.round(r.generateTimeMs / (double) r.generatedTokens)
-                    );
+                    logger.info("{} tokens/s (prompt), {} tokens/s (gen)",
+                        Math.round(r.promptTokens / (double) (r.promptTimeMs / 1000f)),
+                        Math.round(r.generatedTokens / (double) (r.generateTimeMs / 1000f)));
                 } catch (IOException e) {
                     emitter.completeWithError(e);
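Two behavioral notes on the service change: maxTokens now falls back to the model's configured context length via the new Generator.getConfig() accessor instead of a hard-coded 1024, and the completion stats are logged as throughput rather than latency. The new math inverts the old metric using the same counters; a standalone check with made-up numbers (hypothetical values, not from the patch):

    public class StatsSketch {
        public static void main(String[] args) {
            // Hypothetical run: 512 tokens generated in 8000 ms.
            long generatedTokens = 512;
            long generateTimeMs = 8000;

            // Old log line: milliseconds per generated token.
            long msPerTok = Math.round(generateTimeMs / (double) generatedTokens);

            // New log line: generated tokens per second.
            long toksPerSec = Math.round(generatedTokens / (double) (generateTimeMs / 1000f));

            System.out.println(msPerTok + " ms/tok -> " + toksPerSec + " tokens/s"); // 16 ms/tok -> 64 tokens/s
        }
    }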