diff --git a/jlama-native/src/main/c/vector_simd.h b/jlama-native/src/main/c/vector_simd.h index ae80f27..46a14ff 100644 --- a/jlama-native/src/main/c/vector_simd.h +++ b/jlama-native/src/main/c/vector_simd.h @@ -10,11 +10,6 @@ #define Q8_BLOCK_SIZE 32 #define Q4_BLOCK_SIZE 32 -//F16 -float dot_product_f16(int flags, const short* a, int aoffset, const short* b, int boffset, int length); -float dot_product_f16_q8(int flags, const short* a, int aoffset, const float *bf, const char* b, int boffset, int length); -float dot_product_f16_q4(int flags, const short* a, int aoffset, const float *bf, const char* b, int boffset, int length); - //F32 float dot_product_f32(int flags, const float* a, int aoffset, const float* b, int boffset, int length); void dot_product_f32_chunked(int flags, float *r, const float* a, int aoffset, const float* b, int boffset, int length, int bchunkstart, int bchunksize); @@ -30,4 +25,4 @@ float dot_product_q8(int flags, const float *af, const char* a, int aoffset, con float dot_product_q8_q4(int flags, const float *af, const char* a, int aoffset, const float *bf, const char* b, int boffset, int length); void dot_product_q8_q4_chunked(int flags, float *r, const float* af, const char *a, int aoffset, const float *bf, const char* b, int boffset, int length, int bchunkstart, int bchunksize); -#endif \ No newline at end of file +#endif diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java index 31e0778..037b9ad 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java @@ -72,11 +72,6 @@ public float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int bof case Q4 -> NativeSimd.dot_product_f32_q4(flags, a.getMemorySegment(), aoffset, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit); default -> throw new UnsupportedOperationException(); }; - case F16 -> switch (b.dType()) { - case F16 -> NativeSimd.dot_product_f16(flags, a.getMemorySegment(), aoffset, b.getMemorySegment(), boffset, limit); - case I8 -> NativeSimd.dot_product_f16_q8(flags, a.getMemorySegment(), aoffset, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit); - default -> throw new UnsupportedOperationException(); - }; case I8 -> switch (b.dType()) { case Q4 -> NativeSimd.dot_product_q8_q4(flags, ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), aoffset, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit); //case I8 -> NativeSimd.dot_product_q8(flags, ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), aoffset, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit); diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java index 9358499..e6dbea2 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/NativeSimd.java @@ -57,56 +57,8 @@ public static int Q8_BLOCK_SIZE() { public static int Q4_BLOCK_SIZE() { return (int)32L; } - public static MethodHandle dot_product_f16$MH() { - return RuntimeHelper.requireNonNull(constants$0.const$1,"dot_product_f16"); - } - /** - * {@snippet : - * float dot_product_f16(int flags, short* a, int aoffset, short* b, int boffset, int length); - * } - */ - public static float dot_product_f16(int flags, MemorySegment a, int aoffset, MemorySegment b, int boffset, int length) { - var mh$ = dot_product_f16$MH(); - try { - return (float)mh$.invokeExact(flags, a, aoffset, b, boffset, length); - } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); - } - } - public static MethodHandle dot_product_f16_q8$MH() { - return RuntimeHelper.requireNonNull(constants$0.const$3,"dot_product_f16_q8"); - } - /** - * {@snippet : - * float dot_product_f16_q8(int flags, short* a, int aoffset, float* bf, char* b, int boffset, int length); - * } - */ - public static float dot_product_f16_q8(int flags, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, int length) { - var mh$ = dot_product_f16_q8$MH(); - try { - return (float)mh$.invokeExact(flags, a, aoffset, bf, b, boffset, length); - } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); - } - } - public static MethodHandle dot_product_f16_q4$MH() { - return RuntimeHelper.requireNonNull(constants$0.const$4,"dot_product_f16_q4"); - } - /** - * {@snippet : - * float dot_product_f16_q4(int flags, short* a, int aoffset, float* bf, char* b, int boffset, int length); - * } - */ - public static float dot_product_f16_q4(int flags, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, int length) { - var mh$ = dot_product_f16_q4$MH(); - try { - return (float)mh$.invokeExact(flags, a, aoffset, bf, b, boffset, length); - } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); - } - } public static MethodHandle dot_product_f32$MH() { - return RuntimeHelper.requireNonNull(constants$0.const$5,"dot_product_f32"); + return RuntimeHelper.requireNonNull(constants$0.const$1,"dot_product_f32"); } /** * {@snippet : @@ -122,7 +74,7 @@ public static float dot_product_f32(int flags, MemorySegment a, int aoffset, Mem } } public static MethodHandle dot_product_f32_chunked$MH() { - return RuntimeHelper.requireNonNull(constants$1.const$1,"dot_product_f32_chunked"); + return RuntimeHelper.requireNonNull(constants$0.const$3,"dot_product_f32_chunked"); } /** * {@snippet : @@ -138,7 +90,7 @@ public static void dot_product_f32_chunked(int flags, MemorySegment r, MemorySeg } } public static MethodHandle dot_product_f32_q8$MH() { - return RuntimeHelper.requireNonNull(constants$1.const$2,"dot_product_f32_q8"); + return RuntimeHelper.requireNonNull(constants$0.const$5,"dot_product_f32_q8"); } /** * {@snippet : @@ -154,7 +106,7 @@ public static float dot_product_f32_q8(int flags, MemorySegment a, int aoffset, } } public static MethodHandle dot_product_f32_q8_chunked$MH() { - return RuntimeHelper.requireNonNull(constants$1.const$4,"dot_product_f32_q8_chunked"); + return RuntimeHelper.requireNonNull(constants$1.const$1,"dot_product_f32_q8_chunked"); } /** * {@snippet : @@ -170,7 +122,7 @@ public static void dot_product_f32_q8_chunked(int flags, MemorySegment r, Memory } } public static MethodHandle dot_product_f32_q4$MH() { - return RuntimeHelper.requireNonNull(constants$1.const$5,"dot_product_f32_q4"); + return RuntimeHelper.requireNonNull(constants$1.const$2,"dot_product_f32_q4"); } /** * {@snippet : @@ -186,7 +138,7 @@ public static float dot_product_f32_q4(int flags, MemorySegment a, int aoffset, } } public static MethodHandle dot_product_f32_q4_chunked$MH() { - return RuntimeHelper.requireNonNull(constants$2.const$0,"dot_product_f32_q4_chunked"); + return RuntimeHelper.requireNonNull(constants$1.const$3,"dot_product_f32_q4_chunked"); } /** * {@snippet : @@ -202,7 +154,7 @@ public static void dot_product_f32_q4_chunked(int flags, MemorySegment r, Memory } } public static MethodHandle dot_product_q8$MH() { - return RuntimeHelper.requireNonNull(constants$2.const$2,"dot_product_q8"); + return RuntimeHelper.requireNonNull(constants$1.const$5,"dot_product_q8"); } /** * {@snippet : @@ -218,7 +170,7 @@ public static float dot_product_q8(int flags, MemorySegment af, MemorySegment a, } } public static MethodHandle dot_product_q8_q4$MH() { - return RuntimeHelper.requireNonNull(constants$2.const$3,"dot_product_q8_q4"); + return RuntimeHelper.requireNonNull(constants$2.const$0,"dot_product_q8_q4"); } /** * {@snippet : @@ -234,7 +186,7 @@ public static float dot_product_q8_q4(int flags, MemorySegment af, MemorySegment } } public static MethodHandle dot_product_q8_q4_chunked$MH() { - return RuntimeHelper.requireNonNull(constants$2.const$5,"dot_product_q8_q4_chunked"); + return RuntimeHelper.requireNonNull(constants$2.const$2,"dot_product_q8_q4_chunked"); } /** * {@snippet : diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java index 46b0ff4..cc99930 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/RuntimeHelper.java @@ -42,7 +42,7 @@ final class RuntimeHelper { if (!JarSupport.maybeLoadLibrary()) { System.loadLibrary("jlama"); } - SymbolLookup loaderLookup = SymbolLookup.loaderLookup(); + SymbolLookup loaderLookup = SymbolLookup.loaderLookup(); SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name)); } diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$0.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$0.java index 11956a7..71a1f5d 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$0.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$0.java @@ -20,29 +20,36 @@ final class constants$0 { JAVA_INT ); static final MethodHandle const$1 = RuntimeHelper.downcallHandle( - "dot_product_f16", + "dot_product_f32", constants$0.const$0 ); - static final FunctionDescriptor const$2 = FunctionDescriptor.of(JAVA_FLOAT, + static final FunctionDescriptor const$2 = FunctionDescriptor.ofVoid( JAVA_INT, RuntimeHelper.POINTER, - JAVA_INT, RuntimeHelper.POINTER, + JAVA_INT, RuntimeHelper.POINTER, JAVA_INT, + JAVA_INT, + JAVA_INT, JAVA_INT ); static final MethodHandle const$3 = RuntimeHelper.downcallHandle( - "dot_product_f16_q8", + "dot_product_f32_chunked", constants$0.const$2 ); - static final MethodHandle const$4 = RuntimeHelper.downcallHandle( - "dot_product_f16_q4", - constants$0.const$2 + static final FunctionDescriptor const$4 = FunctionDescriptor.of(JAVA_FLOAT, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT ); static final MethodHandle const$5 = RuntimeHelper.downcallHandle( - "dot_product_f32", - constants$0.const$0 + "dot_product_f32_q8", + constants$0.const$4 ); } diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$1.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$1.java index 6e4564a..81ca980 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$1.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$1.java @@ -17,20 +17,25 @@ final class constants$1 { RuntimeHelper.POINTER, JAVA_INT, RuntimeHelper.POINTER, + RuntimeHelper.POINTER, JAVA_INT, JAVA_INT, JAVA_INT, JAVA_INT ); static final MethodHandle const$1 = RuntimeHelper.downcallHandle( - "dot_product_f32_chunked", + "dot_product_f32_q8_chunked", constants$1.const$0 ); static final MethodHandle const$2 = RuntimeHelper.downcallHandle( - "dot_product_f32_q8", - constants$0.const$2 + "dot_product_f32_q4", + constants$0.const$4 + ); + static final MethodHandle const$3 = RuntimeHelper.downcallHandle( + "dot_product_f32_q4_chunked", + constants$1.const$0 ); - static final FunctionDescriptor const$3 = FunctionDescriptor.ofVoid( + static final FunctionDescriptor const$4 = FunctionDescriptor.of(JAVA_FLOAT, JAVA_INT, RuntimeHelper.POINTER, RuntimeHelper.POINTER, @@ -38,17 +43,11 @@ final class constants$1 { RuntimeHelper.POINTER, RuntimeHelper.POINTER, JAVA_INT, - JAVA_INT, - JAVA_INT, JAVA_INT ); - static final MethodHandle const$4 = RuntimeHelper.downcallHandle( - "dot_product_f32_q8_chunked", - constants$1.const$3 - ); static final MethodHandle const$5 = RuntimeHelper.downcallHandle( - "dot_product_f32_q4", - constants$0.const$2 + "dot_product_q8", + constants$1.const$4 ); } diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java index a8a25a2..8b4270b 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java @@ -12,28 +12,10 @@ final class constants$2 { // Suppresses default constructor, ensuring non-instantiability. private constants$2() {} static final MethodHandle const$0 = RuntimeHelper.downcallHandle( - "dot_product_f32_q4_chunked", - constants$1.const$3 - ); - static final FunctionDescriptor const$1 = FunctionDescriptor.of(JAVA_FLOAT, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - RuntimeHelper.POINTER, - RuntimeHelper.POINTER, - JAVA_INT, - JAVA_INT - ); - static final MethodHandle const$2 = RuntimeHelper.downcallHandle( - "dot_product_q8", - constants$2.const$1 - ); - static final MethodHandle const$3 = RuntimeHelper.downcallHandle( "dot_product_q8_q4", - constants$2.const$1 + constants$1.const$4 ); - static final FunctionDescriptor const$4 = FunctionDescriptor.ofVoid( + static final FunctionDescriptor const$1 = FunctionDescriptor.ofVoid( JAVA_INT, RuntimeHelper.POINTER, RuntimeHelper.POINTER, @@ -46,9 +28,9 @@ final class constants$2 { JAVA_INT, JAVA_INT ); - static final MethodHandle const$5 = RuntimeHelper.downcallHandle( + static final MethodHandle const$2 = RuntimeHelper.downcallHandle( "dot_product_q8_q4_chunked", - constants$2.const$4 + constants$2.const$1 ); }