diff --git a/README.md b/README.md index 67ac26a..5f6631f 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Implements: * Flash Attention * Mixture of Experts * Huggingface [SafeTensors](https://github.com/huggingface/safetensors) model and tokenizer format - * Support for F32, F16, BF16 models + * Support for F32, F16, BF16 types * Support for Q8, Q4 model quantization * Fast GEMM operations * Distributed Inference! diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/math/FloatConversions.java b/jlama-core/src/main/java/com/github/tjake/jlama/math/FloatConversions.java index 7020da9..aaade07 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/math/FloatConversions.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/math/FloatConversions.java @@ -33,7 +33,7 @@ public static float bFloat16ToFloat32(short raw) { } public static short float32ToBFloat16(float n) { - //if (true) + // if (true) // return (short) ((Float.floatToRawIntBits(n) >> 16) & 0xffff); int nbits = Float.floatToRawIntBits(n); // 32 bits has 1 sign bit, 8 exponent bits, 23 mantissa bits diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java index ae9a038..1df41a2 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java @@ -19,10 +19,9 @@ import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.BiIntConsumer; import com.github.tjake.jlama.util.PhysicalCoreExecutor; +import com.google.common.base.Preconditions; import java.util.function.IntConsumer; import java.util.stream.IntStream; - -import com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,9 +32,7 @@ public class VectorMath { public static void pfor(int start, int end, IntConsumer action) { PhysicalCoreExecutor.instance .get() - .execute(() -> IntStream.range(start, end) - .parallel() - .forEach(action)); + .execute(() -> IntStream.range(start, end).parallel().forEach(action)); } public static void pchunk(int offset, int length, BiIntConsumer action) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 8364abe..76a4f05 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -32,7 +32,6 @@ import com.github.tjake.jlama.util.Pair; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; - import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -255,8 +254,6 @@ public int sample(AbstractTensor output, float temperature, float uniformSample, double maxv = Double.NEGATIVE_INFINITY; for (int i = 0; i < c.vocabularySize; i++) { float v = logits.get(0, i); - //v = (float) (30.0f * Math.tanh(v / 30.0f)); - //logits.set(v, 0, i); if (v > maxv) { maxi = i; maxv = v; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java index aa7cbec..6faa0da 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java @@ -18,7 +18,6 @@ import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.safetensors.Config; import com.github.tjake.jlama.tensor.AbstractTensor; -import com.github.tjake.jlama.tensor.operations.TensorOperations; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.google.common.base.Preconditions; import java.util.*; @@ -122,9 +121,9 @@ public AbstractTensor forward( int batchSize = input.shape().first(); try (AbstractTensor queryBatch = m.makeFullTensor(batchSize, c.embeddingLength); - AbstractTensor tmpKeyBatch = m.makeFullTensor(batchSize, c.kvLength); - AbstractTensor tmpValBatch = m.makeFullTensor(batchSize, c.kvLength); - AbstractTensor valueBatch = m.makeFullTensor(batchSize, c.embeddingLength)) { + AbstractTensor tmpKeyBatch = m.makeFullTensor(batchSize, c.kvLength); + AbstractTensor tmpValBatch = m.makeFullTensor(batchSize, c.kvLength); + AbstractTensor valueBatch = m.makeFullTensor(batchSize, c.embeddingLength)) { if (c.isGQA) { VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> { @@ -203,10 +202,10 @@ public AbstractTensor forward( AbstractTensor value = valueBatch.slice(bi); if (key.dType() != tmpKey.dType()) { - try (AbstractTensor tmpKey2 = TensorOperationsProvider.get() - .quantize(tmpKey, key.dType(), 0, c.kvLength); - AbstractTensor tmpVal2 = TensorOperationsProvider.get() - .quantize(tmpVal, val.dType(), 0, c.kvLength)) { + try (AbstractTensor tmpKey2 = + TensorOperationsProvider.get().quantize(tmpKey, key.dType(), 0, c.kvLength); + AbstractTensor tmpVal2 = + TensorOperationsProvider.get().quantize(tmpVal, val.dType(), 0, c.kvLength)) { key.copyFrom( tmpKey2, tmpKey2.getOffset(0, c.kvSegmentStart()), @@ -292,7 +291,7 @@ public AbstractTensor forward( } }); - //Attention + // Attention VectorMath.pfor(c.headStart(), c.headEnd(), h -> { try (AbstractTensor attn = m.makeFullTensor(1, kvp.shape().first())) { int xoffset = c.maybeMapToGroupHead(h) * c.headSize; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/RMSNorm.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/RMSNorm.java index 7fdedae..e38c396 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/RMSNorm.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/RMSNorm.java @@ -51,7 +51,7 @@ public AbstractTensor forward( } if (reducer.isPresent()) { - Pair p = reducer.get().apply((float)ss, 0f); + Pair p = reducer.get().apply((float) ss, 0f); ss = p.left; } @@ -60,7 +60,7 @@ public AbstractTensor forward( ss = (1.0 / StrictMath.sqrt(ss)); // normalize and scale for (int j = offset; j < limit; j++) { - output.set((weightAdjustment + weights.get(0, j)) * ((float)ss * input.get(b, j)), b, j); + output.set((weightAdjustment + weights.get(0, j)) * ((float) ss * input.get(b, j)), b, j); } } return output; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java index 5896669..c370421 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaModel.java @@ -30,7 +30,6 @@ import com.github.tjake.jlama.safetensors.WeightLoader; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.tensor.AbstractTensor; -import com.github.tjake.jlama.tensor.operations.PanamaTensorOperations; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import java.util.Optional; import java.util.stream.IntStream; @@ -63,7 +62,8 @@ public GemmaModel( super(inferenceType, config, weights, tokenizer, workingDType, workingQType, modelQType); // https://github.com/huggingface/transformers/blob/1082361a1978d30db5c3932d1ee08914d74d9697/src/transformers/models/gemma/modeling_gemma.py#L898 // This is the scaling factor for the embedding layer but google's implementation is a is rounded to 16 bits - this.embeddingScalingFactor = FloatConversions.bFloat16ToFloat32(FloatConversions.float32ToBFloat16((float) Math.pow(c.embeddingLength, 0.5))); + this.embeddingScalingFactor = FloatConversions.bFloat16ToFloat32( + FloatConversions.float32ToBFloat16((float) Math.pow(c.embeddingLength, 0.5))); } private AbstractTensor wte; @@ -126,7 +126,8 @@ protected EmbedInput loadInputWeights() { AbstractTensor embedding = makeTensor(c.embeddingLength); AbstractTensor at = wte.slice(true, inputToken); if (wte.dType() != embedding.dType()) - at = TensorOperationsProvider.get().quantize(at, embedding.dType(), c.embeddingSegmentStart(), c.embeddingSegmentLength()); + at = TensorOperationsProvider.get() + .quantize(at, embedding.dType(), c.embeddingSegmentStart(), c.embeddingSegmentLength()); embedding.copyFrom( at, @@ -135,7 +136,8 @@ protected EmbedInput loadInputWeights() { c.embeddingSegmentLength()); // This is important for Gemma, but not for Llama - TensorOperationsProvider.get().scale(embeddingScalingFactor, embedding, c.embeddingSegmentStart(), c.embeddingSegmentLength()); + TensorOperationsProvider.get() + .scale(embeddingScalingFactor, embedding, c.embeddingSegmentStart(), c.embeddingSegmentLength()); return embedding; }; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaTokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaTokenizer.java index e7f6994..00cf2fc 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaTokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaTokenizer.java @@ -48,7 +48,7 @@ protected Optional maybeDecodeTokenAsCharacter(long id) { @Override protected String preProcess(String sentence) { - sentence = sentence.replace(" ", SPIECE_UNDERLINE); + sentence = sentence.replace(" ", SPIECE_UNDERLINE); return sentence; } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java index 8baa7e7..6e9765e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java @@ -18,12 +18,11 @@ import com.hubspot.jinjava.Jinjava; import com.hubspot.jinjava.JinjavaConfig; import com.hubspot.jinjava.lib.fn.ELFunctionDefinition; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.ArrayList; import java.util.List; import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * This class also renders the prompt templates of the huggingface model format (using jinja templates) @@ -39,7 +38,9 @@ public class PromptSupport { .build()); static { - jinjava.getGlobalContext().registerFunction(new ELFunctionDefinition("", "raise_exception", PromptSupport.class, "raiseException", String.class)); + jinjava.getGlobalContext() + .registerFunction(new ELFunctionDefinition( + "", "raise_exception", PromptSupport.class, "raiseException", String.class)); } private final TokenizerModel m; @@ -150,8 +151,7 @@ public String build() { "eos_token", m.eosToken(), "bos_token", - m.bosToken() - )); + m.bosToken())); } } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java index e48243e..02b51e6 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java @@ -288,7 +288,8 @@ public AbstractTensor quantize(DType dType) { public AbstractTensor quantize(DType dType, boolean force) { - if (!force && (this.shape().first() == 1 || this.dType == dType || this.dType.size() < dType.size())) return this; + if (!force && (this.shape().first() == 1 || this.dType == dType || this.dType.size() < dType.size())) + return this; if (shape.isSparse()) { logger.info("Quantizing sparse tensor is not supported"); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java index 0450114..b6b73b1 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java @@ -91,7 +91,7 @@ private Pair makeKvBuffer(UUID session) { .order(ByteOrder.LITTLE_ENDIAN) .asFloatBuffer(); - t = new FloatBufferTensor(fb, s, true); + t = new FloatBufferTensor(fb, s, true); } else if (model.getWorkingDType() == DType.BF16) { ShortBuffer sb = raf.getChannel() .map(FileChannel.MapMode.READ_WRITE, 0, bytes) @@ -103,7 +103,6 @@ private Pair makeKvBuffer(UUID session) { throw new UnsupportedOperationException("Only F32/BF16 is supported for now"); } - return Pair.create(raf, t); } catch (IOException e) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/NaiveTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/NaiveTensorOperations.java index e15a4d8..74395ed 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/NaiveTensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/NaiveTensorOperations.java @@ -113,8 +113,7 @@ public void scale(float factor, AbstractTensor x, int offset, int length) { int limit = offset + length; for (int b = 0; b < x.shape().first(); b++) - for (int i = offset; i < limit; ++i) - x.set(x.get(b, i) * factor, b, i); + for (int i = offset; i < limit; ++i) x.set(x.get(b, i) * factor, b, i); } @Override diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java index 301b353..b61c01a 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java @@ -42,7 +42,6 @@ public final class PanamaTensorOperations implements TensorOperations { static final IntVector BF16_BYTE_SHIFT = IntVector.broadcast(IntVector.SPECIES_PREFERRED, 16); - static final IntVector BF16_BYTE_SHIFT_512 = IntVector.broadcast(IntVector.SPECIES_512, 16); static final FloatVector F32_ROUND_UP_512 = FloatVector.broadcast(FloatVector.SPECIES_512, 0.5f); @@ -100,7 +99,7 @@ public void batchDotProduct( int rowChunkSize) { Preconditions.checkArgument(a.dims() == 2 && b.dims() == 2 && result.dims() == 2); Preconditions.checkArgument(a.shape().dim(0) == result.shape().dim(0), "BAD M"); - //Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N"); + // Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N"); // Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K"); int M = a.shape().dim(0); @@ -132,9 +131,11 @@ public void batchDotProduct( }; case BF16 -> switch (b.dType()) { case BF16 -> new GemmerBF16(K, a, b, result, aColumnOffset, bColumnOffset); - default -> throw new UnsupportedOperationException(b.dType().name()); + default -> throw new UnsupportedOperationException( + b.dType().name()); }; - default -> throw new UnsupportedOperationException(a.dType().name() + " " + b.dType().name()); + default -> throw new UnsupportedOperationException( + a.dType().name() + " " + b.dType().name()); }; gemm.matmul(0, M, bRowOffset, bRowOffset + N); @@ -1466,10 +1467,10 @@ protected int pickKernel(int m0, int m, int n0, int n) { nc = 4; kernel(m0, m, 1, n0, n, 4, matmul1x4); } else {*/ - mc = 1; - nc = 1; - kernel(m0, m, 1, n0, n, 1, matmul1x1); - //} + mc = 1; + nc = 1; + kernel(m0, m, 1, n0, n, 1, matmul1x1); + // } return (mc << 4) | nc; } @@ -1492,7 +1493,6 @@ protected BiIntConsumer initMatmul1x1() { .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT) .reinterpretAsFloats(); - ShortVector sb = b.getVector(ShortVector.SPECIES_PREFERRED, j, boffset); FloatVector vb0 = sb.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0) .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT) @@ -1717,7 +1717,8 @@ protected BiIntConsumer initMatmul1x1() { int slen = ShortVector.SPECIES_PREFERRED.length(); for (; aoffset < alim && boffset < blim; aoffset += slen, boffset += slen) { FloatVector va0 = a.getVector(FloatVector.SPECIES_PREFERRED, i, aoffset); - FloatVector va1 = a.getVector(FloatVector.SPECIES_PREFERRED, i, aoffset + FloatVector.SPECIES_PREFERRED.length()); + FloatVector va1 = a.getVector( + FloatVector.SPECIES_PREFERRED, i, aoffset + FloatVector.SPECIES_PREFERRED.length()); ShortVector sb = b.getVector(ShortVector.SPECIES_PREFERRED, j, boffset); FloatVector vb0 = sb.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0) @@ -1835,7 +1836,8 @@ public BFloat16BufferTensor quantizeBF16(FloatBufferTensor ft, final int offset, .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT) .convertShape(VectorOperators.I2S, ShortVector.SPECIES_PREFERRED, -1); - VectorMask mask = VectorMask.fromLong(ShortVector.SPECIES_PREFERRED, (1L << FloatVector.SPECIES_PREFERRED.length()) - 1); + VectorMask mask = VectorMask.fromLong( + ShortVector.SPECIES_PREFERRED, (1L << FloatVector.SPECIES_PREFERRED.length()) - 1); mask = mask.not(); // Invert the mask to select the second half var r = r0.blend(r1, mask); @@ -2276,7 +2278,6 @@ public Q8ByteBufferTensor quantizeBF16_Q8_arm(BFloat16BufferTensor ft, int offse return qft; } - @Override public void maccumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset, int limit) { Preconditions.checkArgument(aBatch.dType() == bBatch.dType()); @@ -2616,12 +2617,19 @@ void saxpyF32(float alpha, FloatBufferTensor x, FloatBufferTensor y, int xoffset } @Override - public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int batchSize) { + public void saxpy( + AbstractTensor alpha, + AbstractTensor x, + AbstractTensor y, + int xoffset, + int yoffset, + int limit, + int batchSize) { Preconditions.checkArgument(limit % 2 == 0); switch (x.dType()) { case F32: - saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor)y, xoffset, yoffset, limit, batchSize); + saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit, batchSize); break; case BF16: switch (y.dType()) { @@ -2649,7 +2657,6 @@ public void saxpyF32( int limit, int batchSize) { - int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(limit); // Use Nearest multiple of 4 @@ -2724,8 +2731,7 @@ public void saxpyBF16F32( } } - void saxpyBF16( - float alpha, BFloat16BufferTensor a, BFloat16BufferTensor b, int aoffset, int boffset, int limit) { + void saxpyBF16(float alpha, BFloat16BufferTensor a, BFloat16BufferTensor b, int aoffset, int boffset, int limit) { int upperBound = ShortVector.SPECIES_PREFERRED.loopBound(limit); Preconditions.checkArgument(upperBound == limit); @@ -2774,9 +2780,7 @@ void saxpyBF16( } } - - void saxpyBF16F32( - float alpha, BFloat16BufferTensor a, FloatBufferTensor b, int aoffset, int boffset, int limit) { + void saxpyBF16F32(float alpha, BFloat16BufferTensor a, FloatBufferTensor b, int aoffset, int boffset, int limit) { int upperBound = ShortVector.SPECIES_PREFERRED.loopBound(limit); Preconditions.checkArgument(upperBound == limit); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java index 1265153..82c20ae 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperations.java @@ -122,7 +122,8 @@ default void saxpy( int yoffset, int limit, int batchSize) { - Preconditions.checkArgument(alpha.shape().last() == x.shape().first() && y.shape().first() == 1); + Preconditions.checkArgument( + alpha.shape().last() == x.shape().first() && y.shape().first() == 1); for (int i = 0; i < batchSize; i++) { saxpy(alpha.get(0, i), x.slice(i), y, xoffset, yoffset, limit); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/PhysicalCoreExecutor.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/PhysicalCoreExecutor.java index 0c611a8..0eed44c 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/PhysicalCoreExecutor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/PhysicalCoreExecutor.java @@ -25,7 +25,7 @@ */ public class PhysicalCoreExecutor { private static volatile int physicalCoreCount = - Math.max(1, Runtime.getRuntime().availableProcessors()/2); + Math.max(1, Runtime.getRuntime().availableProcessors() / 2); private static final AtomicBoolean started = new AtomicBoolean(false); /** diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java index 4883826..ad55867 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java @@ -1,13 +1,26 @@ -// Generated by jextract - +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.tensor.operations.cnative; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.VarHandle; -import java.nio.ByteOrder; -import java.lang.foreign.*; import static java.lang.foreign.ValueLayout.*; -public class NativeSimd { + +import java.lang.foreign.*; +import java.lang.invoke.MethodHandle; + +public class NativeSimd { public static final OfByte C_CHAR = JAVA_BYTE; public static final OfShort C_SHORT = JAVA_SHORT; @@ -23,7 +36,7 @@ public class NativeSimd { * } */ public static int HAS_F16C() { - return (int)2L; + return (int) 2L; } /** * {@snippet : @@ -31,7 +44,7 @@ public static int HAS_F16C() { * } */ public static int HAS_AVX2() { - return (int)4L; + return (int) 4L; } /** * {@snippet : @@ -39,7 +52,7 @@ public static int HAS_AVX2() { * } */ public static int IS_M_SERIES_MAC() { - return (int)8L; + return (int) 8L; } /** * {@snippet : @@ -47,7 +60,7 @@ public static int IS_M_SERIES_MAC() { * } */ public static int Q8_BLOCK_SIZE() { - return (int)32L; + return (int) 32L; } /** * {@snippet : @@ -55,17 +68,36 @@ public static int Q8_BLOCK_SIZE() { * } */ public static int Q4_BLOCK_SIZE() { - return (int)32L; + return (int) 32L; } + public static MethodHandle gemm_q8_q4$MH() { - return RuntimeHelper.requireNonNull(constants$0.const$1,"gemm_q8_q4"); + return RuntimeHelper.requireNonNull(constants$0.const$1, "gemm_q8_q4"); } /** * {@snippet : * void gemm_q8_q4(int flags, float* af, char* a, int aoffset, float* bf, char* b, int boffset, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldaf, int ldb, int ldbf, int ldc); * } */ - public static void gemm_q8_q4(int flags, MemorySegment af, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldaf, int ldb, int ldbf, int ldc) { + public static void gemm_q8_q4( + int flags, + MemorySegment af, + MemorySegment a, + int aoffset, + MemorySegment bf, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldaf, + int ldb, + int ldbf, + int ldc) { var mh$ = gemm_q8_q4$MH(); try { mh$.invokeExact(flags, af, a, aoffset, bf, b, boffset, r, roffset, m, n0, n, k, lda, ldaf, ldb, ldbf, ldc); @@ -73,31 +105,68 @@ public static void gemm_q8_q4(int flags, MemorySegment af, MemorySegment a, int throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_q8_q4_batch$MH() { - return RuntimeHelper.requireNonNull(constants$0.const$3,"gemm_q8_q4_batch"); + return RuntimeHelper.requireNonNull(constants$0.const$3, "gemm_q8_q4_batch"); } /** * {@snippet : * void gemm_q8_q4_batch(int flags, int batch_num, float* af, char* a, int aoffset, float** bf, char** b, int boffset, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldaf, int ldb, int ldbf, int ldc); * } */ - public static void gemm_q8_q4_batch(int flags, int batch_num, MemorySegment af, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldaf, int ldb, int ldbf, int ldc) { + public static void gemm_q8_q4_batch( + int flags, + int batch_num, + MemorySegment af, + MemorySegment a, + int aoffset, + MemorySegment bf, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldaf, + int ldb, + int ldbf, + int ldc) { var mh$ = gemm_q8_q4_batch$MH(); try { - mh$.invokeExact(flags, batch_num, af, a, aoffset, bf, b, boffset, r, roffset, m, n0, n, k, lda, ldaf, ldb, ldbf, ldc); + mh$.invokeExact( + flags, batch_num, af, a, aoffset, bf, b, boffset, r, roffset, m, n0, n, k, lda, ldaf, ldb, ldbf, + ldc); } catch (Throwable ex$) { throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_f32$MH() { - return RuntimeHelper.requireNonNull(constants$0.const$5,"gemm_f32"); + return RuntimeHelper.requireNonNull(constants$0.const$5, "gemm_f32"); } /** * {@snippet : * void gemm_f32(int flags, float* a, int aoffset, float* b, int boffset, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_f32(int flags, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_f32( + int flags, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { var mh$ = gemm_f32$MH(); try { mh$.invokeExact(flags, a, aoffset, b, boffset, r, roffset, m, n0, n, k, lda, ldb, ldc); @@ -105,15 +174,31 @@ public static void gemm_f32(int flags, MemorySegment a, int aoffset, MemorySegme throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_f32_batch$MH() { - return RuntimeHelper.requireNonNull(constants$1.const$1,"gemm_f32_batch"); + return RuntimeHelper.requireNonNull(constants$1.const$1, "gemm_f32_batch"); } /** * {@snippet : * void gemm_f32_batch(int flags, int batch_num, float* a, int aoffset, float** b, int boffset, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_f32_batch(int flags, int batch_num, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_f32_batch( + int flags, + int batch_num, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { var mh$ = gemm_f32_batch$MH(); try { mh$.invokeExact(flags, batch_num, a, aoffset, b, boffset, r, roffset, m, n0, n, k, lda, ldb, ldc); @@ -121,15 +206,32 @@ public static void gemm_f32_batch(int flags, int batch_num, MemorySegment a, int throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_f32_q4$MH() { - return RuntimeHelper.requireNonNull(constants$1.const$3,"gemm_f32_q4"); + return RuntimeHelper.requireNonNull(constants$1.const$3, "gemm_f32_q4"); } /** * {@snippet : * void gemm_f32_q4(int flags, float* a, int aoffset, float* bf, char* b, int boffset, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldbf, int ldc); * } */ - public static void gemm_f32_q4(int flags, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldbf, int ldc) { + public static void gemm_f32_q4( + int flags, + MemorySegment a, + int aoffset, + MemorySegment bf, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldbf, + int ldc) { var mh$ = gemm_f32_q4$MH(); try { mh$.invokeExact(flags, a, aoffset, bf, b, boffset, r, roffset, m, n0, n, k, lda, ldb, ldbf, ldc); @@ -137,15 +239,33 @@ public static void gemm_f32_q4(int flags, MemorySegment a, int aoffset, MemorySe throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_f32_q4_batch$MH() { - return RuntimeHelper.requireNonNull(constants$1.const$5,"gemm_f32_q4_batch"); + return RuntimeHelper.requireNonNull(constants$1.const$5, "gemm_f32_q4_batch"); } /** * {@snippet : * void gemm_f32_q4_batch(int flags, int batch_num, float* a, int aoffset, float** bf, char** b, int boffset, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldbf, int ldc); * } */ - public static void gemm_f32_q4_batch(int flags, int batch_num, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldbf, int ldc) { + public static void gemm_f32_q4_batch( + int flags, + int batch_num, + MemorySegment a, + int aoffset, + MemorySegment bf, + MemorySegment b, + int boffset, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldbf, + int ldc) { var mh$ = gemm_f32_q4_batch$MH(); try { mh$.invokeExact(flags, batch_num, a, aoffset, bf, b, boffset, r, roffset, m, n0, n, k, lda, ldb, ldbf, ldc); @@ -153,15 +273,31 @@ public static void gemm_f32_q4_batch(int flags, int batch_num, MemorySegment a, throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_bf16$MH() { - return RuntimeHelper.requireNonNull(constants$2.const$1,"gemm_bf16"); + return RuntimeHelper.requireNonNull(constants$2.const$1, "gemm_bf16"); } /** * {@snippet : * void gemm_bf16(int flags, short* a, int aoffset, short* b, int boffset, short* cr, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_bf16(int flags, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment cr, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_bf16( + int flags, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment cr, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { var mh$ = gemm_bf16$MH(); try { mh$.invokeExact(flags, a, aoffset, b, boffset, cr, r, roffset, m, n0, n, k, lda, ldb, ldc); @@ -169,15 +305,32 @@ public static void gemm_bf16(int flags, MemorySegment a, int aoffset, MemorySegm throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_bf16_batch$MH() { - return RuntimeHelper.requireNonNull(constants$2.const$3,"gemm_bf16_batch"); + return RuntimeHelper.requireNonNull(constants$2.const$3, "gemm_bf16_batch"); } /** * {@snippet : * void gemm_bf16_batch(int flags, int batch_num, short* a, int aoffset, short** b, int boffset, short** cr, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_bf16_batch(int flags, int batch_num, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment cr, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_bf16_batch( + int flags, + int batch_num, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment cr, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { var mh$ = gemm_bf16_batch$MH(); try { mh$.invokeExact(flags, batch_num, a, aoffset, b, boffset, cr, r, roffset, m, n0, n, k, lda, ldb, ldc); @@ -185,15 +338,31 @@ public static void gemm_bf16_batch(int flags, int batch_num, MemorySegment a, in throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_f32_bf16$MH() { - return RuntimeHelper.requireNonNull(constants$2.const$4,"gemm_f32_bf16"); + return RuntimeHelper.requireNonNull(constants$2.const$4, "gemm_f32_bf16"); } /** * {@snippet : * void gemm_f32_bf16(int flags, float* a, int aoffset, short* b, int boffset, short* cr, float* r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_f32_bf16(int flags, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment cr, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_f32_bf16( + int flags, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment cr, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { var mh$ = gemm_f32_bf16$MH(); try { mh$.invokeExact(flags, a, aoffset, b, boffset, cr, r, roffset, m, n0, n, k, lda, ldb, ldc); @@ -201,15 +370,32 @@ public static void gemm_f32_bf16(int flags, MemorySegment a, int aoffset, Memory throw new AssertionError("should not reach here", ex$); } } + public static MethodHandle gemm_f32_bf16_batch$MH() { - return RuntimeHelper.requireNonNull(constants$2.const$5,"gemm_f32_bf16_batch"); + return RuntimeHelper.requireNonNull(constants$2.const$5, "gemm_f32_bf16_batch"); } /** * {@snippet : * void gemm_f32_bf16_batch(int flags, int batch_num, float* a, int aoffset, short** b, int boffset, short** cr, float** r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc); * } */ - public static void gemm_f32_bf16_batch(int flags, int batch_num, MemorySegment a, int aoffset, MemorySegment b, int boffset, MemorySegment cr, MemorySegment r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) { + public static void gemm_f32_bf16_batch( + int flags, + int batch_num, + MemorySegment a, + int aoffset, + MemorySegment b, + int boffset, + MemorySegment cr, + MemorySegment r, + int roffset, + int m, + int n0, + int n, + int k, + int lda, + int ldb, + int ldc) { var mh$ = gemm_f32_bf16_batch$MH(); try { mh$.invokeExact(flags, batch_num, a, aoffset, b, boffset, cr, r, roffset, m, n0, n, k, lda, ldb, ldc); @@ -218,5 +404,3 @@ public static void gemm_f32_bf16_batch(int flags, int batch_num, MemorySegment a } } } - - diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java index 46b0ff4..6deffb7 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java @@ -1,30 +1,37 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.tensor.operations.cnative; // Generated by jextract -import java.lang.foreign.Linker; +import static java.lang.foreign.Linker.*; +import static java.lang.foreign.ValueLayout.*; + +import java.lang.foreign.AddressLayout; +import java.lang.foreign.Arena; import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.GroupLayout; -import java.lang.foreign.SymbolLookup; +import java.lang.foreign.Linker; import java.lang.foreign.MemoryLayout; import java.lang.foreign.MemorySegment; -import java.lang.foreign.Arena; import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.SymbolLookup; import java.lang.foreign.ValueLayout; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; -import java.io.File; -import java.nio.file.Path; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Optional; -import java.util.stream.Stream; - -import java.lang.foreign.AddressLayout; -import java.lang.foreign.MemoryLayout; - -import static java.lang.foreign.Linker.*; -import static java.lang.foreign.ValueLayout.*; final class RuntimeHelper { @@ -32,10 +39,12 @@ final class RuntimeHelper { private static final ClassLoader LOADER = RuntimeHelper.class.getClassLoader(); private static final MethodHandles.Lookup MH_LOOKUP = MethodHandles.lookup(); private static final SymbolLookup SYMBOL_LOOKUP; - private static final SegmentAllocator THROWING_ALLOCATOR = (x, y) -> { throw new AssertionError("should not reach here"); }; + private static final SegmentAllocator THROWING_ALLOCATOR = (x, y) -> { + throw new AssertionError("should not reach here"); + }; static final AddressLayout POINTER = ValueLayout.ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(JAVA_BYTE)); - final static SegmentAllocator CONSTANT_ALLOCATOR = + static final SegmentAllocator CONSTANT_ALLOCATOR = (size, align) -> Arena.ofAuto().allocate(size, align); static { @@ -43,7 +52,8 @@ final class RuntimeHelper { System.loadLibrary("jlama"); } SymbolLookup loaderLookup = SymbolLookup.loaderLookup(); - SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name)); + SYMBOL_LOOKUP = + name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name)); } // Suppresses default constructor, ensuring non-instantiability. @@ -57,15 +67,17 @@ static T requireNonNull(T obj, String symbolName) { } static MemorySegment lookupGlobalVariable(String name, MemoryLayout layout) { - return SYMBOL_LOOKUP.find(name) + return SYMBOL_LOOKUP + .find(name) .map(s -> s.reinterpret(layout.byteSize())) .orElse(null); } static MethodHandle downcallHandle(String name, FunctionDescriptor fdesc) { - return SYMBOL_LOOKUP.find(name). - map(addr -> LINKER.downcallHandle(addr, fdesc)). - orElse(null); + return SYMBOL_LOOKUP + .find(name) + .map(addr -> LINKER.downcallHandle(addr, fdesc)) + .orElse(null); } static MethodHandle downcallHandle(FunctionDescriptor fdesc) { @@ -73,9 +85,10 @@ static MethodHandle downcallHandle(FunctionDescriptor fdesc) { } static MethodHandle downcallHandleVariadic(String name, FunctionDescriptor fdesc) { - return SYMBOL_LOOKUP.find(name). - map(addr -> VarargsInvoker.make(addr, fdesc)). - orElse(null); + return SYMBOL_LOOKUP + .find(name) + .map(addr -> VarargsInvoker.make(addr, fdesc)) + .orElse(null); } static MethodHandle upcallHandle(Class fi, String name, FunctionDescriptor fdesc) { @@ -96,7 +109,7 @@ static MemorySegment upcallStub(MethodHandle fiHandle, Z z, FunctionDescript } static MemorySegment asArray(MemorySegment addr, MemoryLayout layout, int numElements, Arena arena) { - return addr.reinterpret(numElements * layout.byteSize(), arena, null); + return addr.reinterpret(numElements * layout.byteSize(), arena, null); } // Internals only below this point @@ -113,7 +126,11 @@ private VarargsInvoker(MemorySegment symbol, FunctionDescriptor function) { static { try { - INVOKE_MH = MethodHandles.lookup().findVirtual(VarargsInvoker.class, "invoke", MethodType.methodType(Object.class, SegmentAllocator.class, Object[].class)); + INVOKE_MH = MethodHandles.lookup() + .findVirtual( + VarargsInvoker.class, + "invoke", + MethodType.methodType(Object.class, SegmentAllocator.class, Object[].class)); } catch (ReflectiveOperationException e) { throw new RuntimeException(e); } @@ -121,14 +138,19 @@ private VarargsInvoker(MemorySegment symbol, FunctionDescriptor function) { static MethodHandle make(MemorySegment symbol, FunctionDescriptor function) { VarargsInvoker invoker = new VarargsInvoker(symbol, function); - MethodHandle handle = INVOKE_MH.bindTo(invoker).asCollector(Object[].class, function.argumentLayouts().size() + 1); - MethodType mtype = MethodType.methodType(function.returnLayout().isPresent() ? carrier(function.returnLayout().get(), true) : void.class); + MethodHandle handle = INVOKE_MH + .bindTo(invoker) + .asCollector(Object[].class, function.argumentLayouts().size() + 1); + MethodType mtype = MethodType.methodType( + function.returnLayout().isPresent() + ? carrier(function.returnLayout().get(), true) + : void.class); for (MemoryLayout layout : function.argumentLayouts()) { mtype = mtype.appendParameterTypes(carrier(layout, false)); } mtype = mtype.appendParameterTypes(Object[].class); - boolean needsAllocator = function.returnLayout().isPresent() && - function.returnLayout().get() instanceof GroupLayout; + boolean needsAllocator = function.returnLayout().isPresent() + && function.returnLayout().get() instanceof GroupLayout; if (needsAllocator) { mtype = mtype.insertParameterTypes(0, SegmentAllocator.class); } else { @@ -150,7 +172,7 @@ static Class carrier(MemoryLayout layout, boolean ret) { private Object invoke(SegmentAllocator allocator, Object[] args) throws Throwable { // one trailing Object[] int nNamedArgs = function.argumentLayouts().size(); - assert(args.length == nNamedArgs + 1); + assert (args.length == nNamedArgs + 1); // The last argument is the array of vararg collector Object[] unnamedArgs = (Object[]) args[args.length - 1]; @@ -164,18 +186,18 @@ private Object invoke(SegmentAllocator allocator, Object[] args) throws Throwabl } assert pos == nNamedArgs; - for (Object o: unnamedArgs) { + for (Object o : unnamedArgs) { argLayouts[pos] = variadicLayout(normalize(o.getClass())); pos++; } assert pos == argsCount; - FunctionDescriptor f = (function.returnLayout().isEmpty()) ? - FunctionDescriptor.ofVoid(argLayouts) : - FunctionDescriptor.of(function.returnLayout().get(), argLayouts); + FunctionDescriptor f = (function.returnLayout().isEmpty()) + ? FunctionDescriptor.ofVoid(argLayouts) + : FunctionDescriptor.of(function.returnLayout().get(), argLayouts); MethodHandle mh = LINKER.downcallHandle(symbol, f); - boolean needsAllocator = function.returnLayout().isPresent() && - function.returnLayout().get() instanceof GroupLayout; + boolean needsAllocator = function.returnLayout().isPresent() + && function.returnLayout().get() instanceof GroupLayout; if (needsAllocator) { mh = mh.bindTo(allocator); } diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$0.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$0.java index 1e1ab51..b6000a2 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$0.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$0.java @@ -1,85 +1,85 @@ -// Generated by jextract - +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.tensor.operations.cnative; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.VarHandle; -import java.nio.ByteOrder; -import java.lang.foreign.*; import static java.lang.foreign.ValueLayout.*; + +import java.lang.foreign.*; +import java.lang.invoke.MethodHandle; + final class constants$0 { // Suppresses default constructor, ensuring non-instantiability. private constants$0() {} + static final FunctionDescriptor const$0 = FunctionDescriptor.ofVoid( - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$1 = RuntimeHelper.downcallHandle( - "gemm_q8_q4", - constants$0.const$0 - ); + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT); + static final MethodHandle const$1 = RuntimeHelper.downcallHandle("gemm_q8_q4", constants$0.const$0); static final FunctionDescriptor const$2 = FunctionDescriptor.ofVoid( - JAVA_INT, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$3 = RuntimeHelper.downcallHandle( - "gemm_q8_q4_batch", - constants$0.const$2 - ); + JAVA_INT, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT); + static final MethodHandle const$3 = RuntimeHelper.downcallHandle("gemm_q8_q4_batch", constants$0.const$2); static final FunctionDescriptor const$4 = FunctionDescriptor.ofVoid( - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$5 = RuntimeHelper.downcallHandle( - "gemm_f32", - constants$0.const$4 - ); + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT); + static final MethodHandle const$5 = RuntimeHelper.downcallHandle("gemm_f32", constants$0.const$4); } - - diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$1.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$1.java index c1ffd24..e6ea355 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$1.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$1.java @@ -1,82 +1,82 @@ -// Generated by jextract - +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.tensor.operations.cnative; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.VarHandle; -import java.nio.ByteOrder; -import java.lang.foreign.*; import static java.lang.foreign.ValueLayout.*; + +import java.lang.foreign.*; +import java.lang.invoke.MethodHandle; + final class constants$1 { // Suppresses default constructor, ensuring non-instantiability. private constants$1() {} + static final FunctionDescriptor const$0 = FunctionDescriptor.ofVoid( - JAVA_INT, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$1 = RuntimeHelper.downcallHandle( - "gemm_f32_batch", - constants$1.const$0 - ); + JAVA_INT, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT); + static final MethodHandle const$1 = RuntimeHelper.downcallHandle("gemm_f32_batch", constants$1.const$0); static final FunctionDescriptor const$2 = FunctionDescriptor.ofVoid( - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$3 = RuntimeHelper.downcallHandle( - "gemm_f32_q4", - constants$1.const$2 - ); + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT); + static final MethodHandle const$3 = RuntimeHelper.downcallHandle("gemm_f32_q4", constants$1.const$2); static final FunctionDescriptor const$4 = FunctionDescriptor.ofVoid( - JAVA_INT, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$5 = RuntimeHelper.downcallHandle( - "gemm_f32_q4_batch", - constants$1.const$4 - ); + JAVA_INT, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT); + static final MethodHandle const$5 = RuntimeHelper.downcallHandle("gemm_f32_q4_batch", constants$1.const$4); } - - diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java index d644d5a..6c982d0 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java @@ -1,67 +1,65 @@ -// Generated by jextract - +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.tensor.operations.cnative; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.VarHandle; -import java.nio.ByteOrder; -import java.lang.foreign.*; import static java.lang.foreign.ValueLayout.*; + +import java.lang.foreign.*; +import java.lang.invoke.MethodHandle; + final class constants$2 { // Suppresses default constructor, ensuring non-instantiability. private constants$2() {} + static final FunctionDescriptor const$0 = FunctionDescriptor.ofVoid( - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$1 = RuntimeHelper.downcallHandle( - "gemm_bf16", - constants$2.const$0 - ); + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT); + static final MethodHandle const$1 = RuntimeHelper.downcallHandle("gemm_bf16", constants$2.const$0); static final FunctionDescriptor const$2 = FunctionDescriptor.ofVoid( - JAVA_INT, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$3 = RuntimeHelper.downcallHandle( - "gemm_bf16_batch", - constants$2.const$2 - ); - static final MethodHandle const$4 = RuntimeHelper.downcallHandle( - "gemm_f32_bf16", - constants$2.const$0 - ); - static final MethodHandle const$5 = RuntimeHelper.downcallHandle( - "gemm_f32_bf16_batch", - constants$2.const$2 - ); + JAVA_INT, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT); + static final MethodHandle const$3 = RuntimeHelper.downcallHandle("gemm_bf16_batch", constants$2.const$2); + static final MethodHandle const$4 = RuntimeHelper.downcallHandle("gemm_f32_bf16", constants$2.const$0); + static final MethodHandle const$5 = RuntimeHelper.downcallHandle("gemm_f32_bf16_batch", constants$2.const$2); } - - diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/BatchBench.java b/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/BatchBench.java index 5818446..4af234d 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/BatchBench.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/BatchBench.java @@ -19,7 +19,6 @@ import com.github.tjake.jlama.tensor.*; import com.github.tjake.jlama.tensor.operations.TensorOperations; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; - import java.util.Collection; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; @@ -36,7 +35,11 @@ @Fork( warmups = 1, value = 1, - jvmArgsPrepend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djlama.force_panama_tensor_operations=true"}) + jvmArgsPrepend = { + "--add-modules=jdk.incubator.vector", + "--enable-preview", + "-Djlama.force_panama_tensor_operations=true" + }) public class BatchBench { private static final TensorOperations ops = TensorOperationsProvider.get(); @@ -175,7 +178,10 @@ public static void main(String[] args) throws Exception { .warmupTime(TimeValue.seconds(5)) .measurementTime(TimeValue.seconds(5)) .threads(1) - .jvmArgs("--add-modules=jdk.incubator.vector", "--enable-preview", "-Djava.library.path=jlama-native/target/native-lib-only") + .jvmArgs( + "--add-modules=jdk.incubator.vector", + "--enable-preview", + "-Djava.library.path=jlama-native/target/native-lib-only") .build(); Collection results = new Runner(opt).run(); @@ -184,7 +190,8 @@ public static void main(String[] args) throws Exception { for (RunResult r : results) { for (var b : r.getBenchmarkResults()) { - double elapsedTime = TimeUnit.MILLISECONDS.toSeconds(b.getMetadata().getStopTime() - b.getMetadata().getMeasurementTime()); + double elapsedTime = TimeUnit.MILLISECONDS.toSeconds( + b.getMetadata().getStopTime() - b.getMetadata().getMeasurementTime()); // Calculate total number of floating-point operations double totalFlops = flopsPerIteration * b.getMetadata().getMeasurementOps(); diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index cfc7497..cab6787 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -66,7 +66,7 @@ public class TestModels { static { System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0"); - //System.setProperty("jlama.force_panama_tensor_operations", "true"); + // System.setProperty("jlama.force_panama_tensor_operations", "true"); } private static final Logger logger = LoggerFactory.getLogger(TestModels.class); @@ -153,8 +153,11 @@ public void MixtralRun() throws Exception { + "or sometimes administered intravenously. They are not effective against viral infections, and using them inappropriately can lead to antibiotic resistance. Explain the above in one sentence:"; String prompt = "Tell me a joke."; - String p = model.promptSupport().get().newBuilder() - .addUserMessage(prompt).build(); + String p = model.promptSupport() + .get() + .newBuilder() + .addUserMessage(prompt) + .build(); model.generate(UUID.randomUUID(), p, 0.7f, 256, true, makeOutHandler()); } @@ -170,7 +173,9 @@ public void GemmaRun() throws Exception { GemmaConfig c = om.readValue(new File(modelPrefix + "/config.json"), GemmaConfig.class); GemmaModel model = new GemmaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.empty()); String prompt = "Tell me a joke."; - String p = model.promptSupport().get().newBuilder() + String p = model.promptSupport() + .get() + .newBuilder() .addUserMessage(prompt) .build(); model.generate(UUID.randomUUID(), p, 0.7f, 256, false, makeOutHandler()); diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java b/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java index 86242c8..55f2861 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java @@ -299,7 +299,10 @@ public void testScale() { Assert.assertEquals( "AType " + aType.getKey() + " is outside of 1% error limit", control, dp, control * .01f); } catch (UnsupportedOperationException | IllegalArgumentException e) { - logger.debug("No support for AType {} {}", aType.getKey(), t.getClass().getSimpleName()); + logger.debug( + "No support for AType {} {}", + aType.getKey(), + t.getClass().getSimpleName()); } } Assert.assertTrue(supported > 0); @@ -359,16 +362,22 @@ public void testQ8Vectorized() { AbstractTensor qv = t.quantize(at, DType.I8, 0, SIZE); supported++; Assert.assertEquals( - "AType " + aType.getKey() + " is outside of 1% error limit " + t.getClass().getSimpleName(), control, controlOps.sum(qv), control * .01f); + "AType " + aType.getKey() + " is outside of 1% error limit " + + t.getClass().getSimpleName(), + control, + controlOps.sum(qv), + control * .01f); } catch (UnsupportedOperationException | IllegalArgumentException e) { - logger.debug("No support for AType {} {}", aType.getKey(), t.getClass().getSimpleName()); + logger.debug( + "No support for AType {} {}", + aType.getKey(), + t.getClass().getSimpleName()); } } Assert.assertTrue(supported > 0); } } - @Test public void testQBF16Vectorized() { AbstractTensor a = makeTensor(SIZE); @@ -383,9 +392,15 @@ public void testQBF16Vectorized() { AbstractTensor qv = t.quantize(at, DType.BF16, 0, (int) a.size()); supported++; Assert.assertEquals( - "AType " + aType.getKey() + " is outside of 1% error limit", control, controlOps.sum(qv), control * .01f); + "AType " + aType.getKey() + " is outside of 1% error limit", + control, + controlOps.sum(qv), + control * .01f); } catch (UnsupportedOperationException | IllegalArgumentException e) { - logger.debug("No support for AType {} {}", aType.getKey(), t.getClass().getSimpleName()); + logger.debug( + "No support for AType {} {}", + aType.getKey(), + t.getClass().getSimpleName()); } } Assert.assertTrue(supported > 0);