Skip to content

Commit

Permalink
rm f16 native
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Oct 21, 2023
1 parent 8e1f24c commit 4a9b6b2
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 112 deletions.
7 changes: 1 addition & 6 deletions jlama-native/src/main/c/vector_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
#define Q8_BLOCK_SIZE 32
#define Q4_BLOCK_SIZE 32

//F16
float dot_product_f16(int flags, const short* a, int aoffset, const short* b, int boffset, int length);
float dot_product_f16_q8(int flags, const short* a, int aoffset, const float *bf, const char* b, int boffset, int length);
float dot_product_f16_q4(int flags, const short* a, int aoffset, const float *bf, const char* b, int boffset, int length);

//F32
float dot_product_f32(int flags, const float* a, int aoffset, const float* b, int boffset, int length);
void dot_product_f32_chunked(int flags, float *r, const float* a, int aoffset, const float* b, int boffset, int length, int bchunkstart, int bchunksize);
Expand All @@ -30,4 +25,4 @@ float dot_product_q8(int flags, const float *af, const char* a, int aoffset, con
float dot_product_q8_q4(int flags, const float *af, const char* a, int aoffset, const float *bf, const char* b, int boffset, int length);
void dot_product_q8_q4_chunked(int flags, float *r, const float* af, const char *a, int aoffset, const float *bf, const char* b, int boffset, int length, int bchunkstart, int bchunksize);

#endif
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ public float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int bof
case Q4 -> NativeSimd.dot_product_f32_q4(flags, a.getMemorySegment(), aoffset, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit);
default -> throw new UnsupportedOperationException();
};
case F16 -> switch (b.dType()) {
case F16 -> NativeSimd.dot_product_f16(flags, a.getMemorySegment(), aoffset, b.getMemorySegment(), boffset, limit);
case I8 -> NativeSimd.dot_product_f16_q8(flags, a.getMemorySegment(), aoffset, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit);
default -> throw new UnsupportedOperationException();
};
case I8 -> switch (b.dType()) {
case Q4 -> NativeSimd.dot_product_q8_q4(flags, ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), aoffset, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit);
//case I8 -> NativeSimd.dot_product_q8(flags, ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), aoffset, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,56 +57,8 @@ public static int Q8_BLOCK_SIZE() {
public static int Q4_BLOCK_SIZE() {
return (int)32L;
}
public static MethodHandle dot_product_f16$MH() {
return RuntimeHelper.requireNonNull(constants$0.const$1,"dot_product_f16");
}
/**
* {@snippet :
* float dot_product_f16(int flags, short* a, int aoffset, short* b, int boffset, int length);
* }
*/
public static float dot_product_f16(int flags, MemorySegment a, int aoffset, MemorySegment b, int boffset, int length) {
var mh$ = dot_product_f16$MH();
try {
return (float)mh$.invokeExact(flags, a, aoffset, b, boffset, length);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
public static MethodHandle dot_product_f16_q8$MH() {
return RuntimeHelper.requireNonNull(constants$0.const$3,"dot_product_f16_q8");
}
/**
* {@snippet :
* float dot_product_f16_q8(int flags, short* a, int aoffset, float* bf, char* b, int boffset, int length);
* }
*/
public static float dot_product_f16_q8(int flags, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, int length) {
var mh$ = dot_product_f16_q8$MH();
try {
return (float)mh$.invokeExact(flags, a, aoffset, bf, b, boffset, length);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
public static MethodHandle dot_product_f16_q4$MH() {
return RuntimeHelper.requireNonNull(constants$0.const$4,"dot_product_f16_q4");
}
/**
* {@snippet :
* float dot_product_f16_q4(int flags, short* a, int aoffset, float* bf, char* b, int boffset, int length);
* }
*/
public static float dot_product_f16_q4(int flags, MemorySegment a, int aoffset, MemorySegment bf, MemorySegment b, int boffset, int length) {
var mh$ = dot_product_f16_q4$MH();
try {
return (float)mh$.invokeExact(flags, a, aoffset, bf, b, boffset, length);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
public static MethodHandle dot_product_f32$MH() {
return RuntimeHelper.requireNonNull(constants$0.const$5,"dot_product_f32");
return RuntimeHelper.requireNonNull(constants$0.const$1,"dot_product_f32");
}
/**
* {@snippet :
Expand All @@ -122,7 +74,7 @@ public static float dot_product_f32(int flags, MemorySegment a, int aoffset, Mem
}
}
public static MethodHandle dot_product_f32_chunked$MH() {
return RuntimeHelper.requireNonNull(constants$1.const$1,"dot_product_f32_chunked");
return RuntimeHelper.requireNonNull(constants$0.const$3,"dot_product_f32_chunked");
}
/**
* {@snippet :
Expand All @@ -138,7 +90,7 @@ public static void dot_product_f32_chunked(int flags, MemorySegment r, MemorySeg
}
}
public static MethodHandle dot_product_f32_q8$MH() {
return RuntimeHelper.requireNonNull(constants$1.const$2,"dot_product_f32_q8");
return RuntimeHelper.requireNonNull(constants$0.const$5,"dot_product_f32_q8");
}
/**
* {@snippet :
Expand All @@ -154,7 +106,7 @@ public static float dot_product_f32_q8(int flags, MemorySegment a, int aoffset,
}
}
public static MethodHandle dot_product_f32_q8_chunked$MH() {
return RuntimeHelper.requireNonNull(constants$1.const$4,"dot_product_f32_q8_chunked");
return RuntimeHelper.requireNonNull(constants$1.const$1,"dot_product_f32_q8_chunked");
}
/**
* {@snippet :
Expand All @@ -170,7 +122,7 @@ public static void dot_product_f32_q8_chunked(int flags, MemorySegment r, Memory
}
}
public static MethodHandle dot_product_f32_q4$MH() {
return RuntimeHelper.requireNonNull(constants$1.const$5,"dot_product_f32_q4");
return RuntimeHelper.requireNonNull(constants$1.const$2,"dot_product_f32_q4");
}
/**
* {@snippet :
Expand All @@ -186,7 +138,7 @@ public static float dot_product_f32_q4(int flags, MemorySegment a, int aoffset,
}
}
public static MethodHandle dot_product_f32_q4_chunked$MH() {
return RuntimeHelper.requireNonNull(constants$2.const$0,"dot_product_f32_q4_chunked");
return RuntimeHelper.requireNonNull(constants$1.const$3,"dot_product_f32_q4_chunked");
}
/**
* {@snippet :
Expand All @@ -202,7 +154,7 @@ public static void dot_product_f32_q4_chunked(int flags, MemorySegment r, Memory
}
}
public static MethodHandle dot_product_q8$MH() {
return RuntimeHelper.requireNonNull(constants$2.const$2,"dot_product_q8");
return RuntimeHelper.requireNonNull(constants$1.const$5,"dot_product_q8");
}
/**
* {@snippet :
Expand All @@ -218,7 +170,7 @@ public static float dot_product_q8(int flags, MemorySegment af, MemorySegment a,
}
}
public static MethodHandle dot_product_q8_q4$MH() {
return RuntimeHelper.requireNonNull(constants$2.const$3,"dot_product_q8_q4");
return RuntimeHelper.requireNonNull(constants$2.const$0,"dot_product_q8_q4");
}
/**
* {@snippet :
Expand All @@ -234,7 +186,7 @@ public static float dot_product_q8_q4(int flags, MemorySegment af, MemorySegment
}
}
public static MethodHandle dot_product_q8_q4_chunked$MH() {
return RuntimeHelper.requireNonNull(constants$2.const$5,"dot_product_q8_q4_chunked");
return RuntimeHelper.requireNonNull(constants$2.const$2,"dot_product_q8_q4_chunked");
}
/**
* {@snippet :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ final class RuntimeHelper {
if (!JarSupport.maybeLoadLibrary()) {
System.loadLibrary("jlama");
}
SymbolLookup loaderLookup = SymbolLookup.loaderLookup();
SymbolLookup loaderLookup = SymbolLookup.loaderLookup();
SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,36 @@ final class constants$0 {
JAVA_INT
);
static final MethodHandle const$1 = RuntimeHelper.downcallHandle(
"dot_product_f16",
"dot_product_f32",
constants$0.const$0
);
static final FunctionDescriptor const$2 = FunctionDescriptor.of(JAVA_FLOAT,
static final FunctionDescriptor const$2 = FunctionDescriptor.ofVoid(
JAVA_INT,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT
);
static final MethodHandle const$3 = RuntimeHelper.downcallHandle(
"dot_product_f16_q8",
"dot_product_f32_chunked",
constants$0.const$2
);
static final MethodHandle const$4 = RuntimeHelper.downcallHandle(
"dot_product_f16_q4",
constants$0.const$2
static final FunctionDescriptor const$4 = FunctionDescriptor.of(JAVA_FLOAT,
JAVA_INT,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
JAVA_INT,
JAVA_INT
);
static final MethodHandle const$5 = RuntimeHelper.downcallHandle(
"dot_product_f32",
constants$0.const$0
"dot_product_f32_q8",
constants$0.const$4
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,37 @@ final class constants$1 {
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT
);
static final MethodHandle const$1 = RuntimeHelper.downcallHandle(
"dot_product_f32_chunked",
"dot_product_f32_q8_chunked",
constants$1.const$0
);
static final MethodHandle const$2 = RuntimeHelper.downcallHandle(
"dot_product_f32_q8",
constants$0.const$2
"dot_product_f32_q4",
constants$0.const$4
);
static final MethodHandle const$3 = RuntimeHelper.downcallHandle(
"dot_product_f32_q4_chunked",
constants$1.const$0
);
static final FunctionDescriptor const$3 = FunctionDescriptor.ofVoid(
static final FunctionDescriptor const$4 = FunctionDescriptor.of(JAVA_FLOAT,
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT
);
static final MethodHandle const$4 = RuntimeHelper.downcallHandle(
"dot_product_f32_q8_chunked",
constants$1.const$3
);
static final MethodHandle const$5 = RuntimeHelper.downcallHandle(
"dot_product_f32_q4",
constants$0.const$2
"dot_product_q8",
constants$1.const$4
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,10 @@ final class constants$2 {
// Suppresses default constructor, ensuring non-instantiability.
private constants$2() {}
static final MethodHandle const$0 = RuntimeHelper.downcallHandle(
"dot_product_f32_q4_chunked",
constants$1.const$3
);
static final FunctionDescriptor const$1 = FunctionDescriptor.of(JAVA_FLOAT,
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
JAVA_INT,
JAVA_INT
);
static final MethodHandle const$2 = RuntimeHelper.downcallHandle(
"dot_product_q8",
constants$2.const$1
);
static final MethodHandle const$3 = RuntimeHelper.downcallHandle(
"dot_product_q8_q4",
constants$2.const$1
constants$1.const$4
);
static final FunctionDescriptor const$4 = FunctionDescriptor.ofVoid(
static final FunctionDescriptor const$1 = FunctionDescriptor.ofVoid(
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
Expand All @@ -46,9 +28,9 @@ final class constants$2 {
JAVA_INT,
JAVA_INT
);
static final MethodHandle const$5 = RuntimeHelper.downcallHandle(
static final MethodHandle const$2 = RuntimeHelper.downcallHandle(
"dot_product_q8_q4_chunked",
constants$2.const$4
constants$2.const$1
);
}

Expand Down

0 comments on commit 4a9b6b2

Please sign in to comment.