diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java
index f6eea66..5738337 100644
--- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java
+++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java
@@ -9,6 +9,7 @@
 import com.github.tjake.jlama.cli.commands.ChatCommand;
 import com.github.tjake.jlama.cli.commands.CompleteCommand;
+import com.github.tjake.jlama.cli.commands.QuantizeCommand;
 import com.github.tjake.jlama.cli.commands.ServeCommand;
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.ModelSupport.ModelType;
@@ -30,6 +31,7 @@ public class JlamaCli implements Runnable {
     public static void main(String[] args) {
         CommandLine cli = new CommandLine(new JlamaCli());
+        cli.addSubcommand("quantize", new QuantizeCommand());
         cli.addSubcommand("chat", new ChatCommand());
         cli.addSubcommand("complete", new CompleteCommand());
         cli.addSubcommand("serve", new ServeCommand());
diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java
new file mode 100644
index 0000000..baa80cc
--- /dev/null
+++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java
@@ -0,0 +1,51 @@
+package com.github.tjake.jlama.cli.commands;
+
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Optional;
+
+import com.github.tjake.jlama.safetensors.DType;
+import com.github.tjake.jlama.safetensors.SafeTensorSupport;
+import picocli.CommandLine;
+
+@CommandLine.Command(name = "quantize", description = "Quantize the specified model")
+
+public class QuantizeCommand extends BaseCommand {
+
+    @CommandLine.Parameters(index = "1", arity = "0..1", description = "The output location")
+    protected Path output;
+
+    @CommandLine.Option(names = { "-q", "--quantization" }, description = "Model quantization type", arity = "1")
+    protected DType modelQuantization;
+
+    @CommandLine.Option(names = { "-s", "--skip-layer" }, description = "Layer name prefix to not quantize")
+    protected String[] skipLayerPrefixes;
+
+    @Override
+    public void run() {
+
+        if (!model.exists()) {
+            System.err.println("Model location does not exist: " + model);
+            System.exit(1);
+        }
+
+        File baseDir = model.isFile() ? model.getParentFile() : model;
+
+        if (!baseDir.isDirectory()) {
+            System.err.println("Model directory does not exist: " + baseDir);
+            System.exit(1);
+        }
+
+        try {
+            Path out = SafeTensorSupport.quantizeModel(baseDir.toPath(), modelQuantization, skipLayerPrefixes, Optional.ofNullable(output));
+
+            System.out.println("Quantized model written to: " + out);
+        } catch (IOException e) {
+            e.printStackTrace();
+            System.exit(2);
+        }
+    }
+}
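Note: a minimal sketch of driving the new quantize path programmatically, mirroring what the command does (the model directory below is hypothetical; layers matching a skip prefix are written at their original precision):

    Path modelRoot = Paths.get("models/Llama-2-7b-chat-hf");                 // hypothetical local checkout
    Path out = SafeTensorSupport.quantizeModel(
            modelRoot,
            DType.Q4,                                                        // what picocli binds from "-q Q4"
            new String[] { "model.embed_tokens.weight", "lm_head.weight" },  // "-s" prefixes left unquantized
            Optional.empty());                                               // default output dir: "<model>-jlama-Q4"
    System.out.println("Quantized model written to: " + out);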
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java
index 1a76ef2..7f1a076 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java
@@ -24,7 +24,7 @@ public class SafeTensorIndex implements WeightLoader, AutoCloseable {
     public static final String SINGLE_MODEL_NAME = "model.safetensors";
     public static final String MODEL_INDEX_JSON = "model.safetensors.index.json";
 
-    private final Map metadata;
+    private final Map<String, String> metadata;
 
     // Map from weight name to file name (this is what's in the JSON file)
     private final Map<String, String> weightFileMap;
@@ -136,6 +136,33 @@ private Map<List<Long>, List<String>> computeMmapSplits(Map<String, TensorInfo>
         return splits;
     }
 
+    @JsonCreator
+    SafeTensorIndex(@JsonProperty("metadata") Map<String, String> metadata,
+                    @JsonProperty("weight_map") Map<String, String> weightFileMap) {
+        this.metadata = ImmutableMap.copyOf(metadata);
+        this.weightFileMap = ImmutableMap.copyOf(weightFileMap);
+    }
+
+    @Override
+    public Map<String, String> metadata() {
+        return metadata;
+    }
+
+    @Override
+    public Map<String, TensorInfo> tensorInfoMap() {
+        Map<String, TensorInfo> tensorInfoMap = new HashMap<>();
+        for (String name : weightMap.keySet()) {
+            Weights w = weightMap.get(name);
+            if (w == null)
+                throw new NoSuchElementException(name);
+
+            tensorInfoMap.put(name, w.tensorInfoMap().get(name));
+        }
+
+        return tensorInfoMap;
+    }
+
+    @Override
     public AbstractTensor load(String name) {
         Weights w = weightMap.get(name);
         if (w == null)
@@ -150,13 +177,6 @@ public DType getModelDType() {
         return weightMap.values().iterator().next().getModelDType();
     }
 
-    @JsonCreator
-    SafeTensorIndex(@JsonProperty("metadata") Map<String, String> metadata,
-                    @JsonProperty("weight_map") Map<String, String> weightFileMap) {
-        this.metadata = ImmutableMap.copyOf(metadata);
-        this.weightFileMap = ImmutableMap.copyOf(weightFileMap);
-    }
-
     @Override
     public void close() throws Exception {
         weightMap.clear();
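Note: the relocated @JsonCreator constructor binds the two top-level fields of model.safetensors.index.json. A sketch of what Jackson deserializes (the file body below is hypothetical; Jackson can invoke the package-private creator through the annotation):

    ObjectMapper om = new ObjectMapper();
    String json = """
        {
          "metadata":   { "format": "pt" },
          "weight_map": { "model.embed_tokens.weight": "model-00001-of-00002.safetensors" }
        }
        """;
    SafeTensorIndex index = om.readValue(json, SafeTensorIndex.class);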
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java
index e1eb376..c87fedc 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java
@@ -6,25 +6,37 @@
 import com.fasterxml.jackson.databind.type.MapType;
 import com.github.tjake.jlama.model.ModelSupport.ModelType;
 import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel;
+import com.github.tjake.jlama.tensor.AbstractTensor;
+import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
+import com.github.tjake.jlama.tensor.Q5ByteBufferTensor;
+import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
 import com.google.common.base.Preconditions;
 import com.google.common.primitives.Ints;
 
 import java.io.File;
 import java.io.IOException;
+import java.io.OutputStream;
+import java.io.RandomAccessFile;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.*;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 public class SafeTensorSupport {
+    private static final Logger logger = LoggerFactory.getLogger(SafeTensorSupport.class);
     private static final ObjectMapper om = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
     private static final MapType metadataTypeReference = om.getTypeFactory().constructMapType(Map.class, String.class, String.class);
 
     public static Map<String, TensorInfo> readTensorInfoMap(ByteBuffer buf, Optional<Map<String, String>> saveMetadata) {
-        long headerLength = buf.order() == ByteOrder.BIG_ENDIAN ? Long.reverseBytes(buf.getLong()) : buf.getLong();
+        buf = buf.order(ByteOrder.LITTLE_ENDIAN);
+        long headerLength = buf.getLong();
         byte[] header = new byte[Ints.checkedCast(headerLength)];
         buf.get(header);
@@ -76,7 +88,7 @@ public static WeightLoader loadWeights(File baseDir) throws IOException {
         if (Files.exists(Paths.get(baseDir.getAbsolutePath(), SafeTensorIndex.SINGLE_MODEL_NAME)))
             return SafeTensorIndex.loadSingleFile(baseDir.toPath(), SafeTensorIndex.SINGLE_MODEL_NAME);
 
-        throw new IllegalArgumentException("No safetensors model found in: " + baseDir);
+        throw new IllegalArgumentException("No safetensor model found in: " + baseDir);
     }
 
     public static TokenizerModel loadTokenizer(Path modelRoot) throws IOException {
@@ -89,4 +101,100 @@ public static TokenizerModel loadTokenizer(Path modelRoot) throws IOException {
 
         return om.treeToValue(rootNode.get("model"), TokenizerModel.class);
     }
+
+    public static Path quantizeModel(Path modelRoot, DType modelQuantization, String[] skipLayerPrefixes, Optional<Path> outputRoot) throws IOException {
+        File tmp = File.createTempFile("safe", "tensor");
+        tmp.deleteOnExit();
+        WeightLoader wl = SafeTensorSupport.loadWeights(modelRoot.toFile());
+        Map<String, TensorInfo> writtenInfo = new HashMap<>();
+
+        try (RandomAccessFile raf = new RandomAccessFile(tmp, "rw")) {
+            Map<String, TensorInfo> tensors = wl.tensorInfoMap();
+
+            for (Map.Entry<String, TensorInfo> e : tensors.entrySet()) {
+                try (AbstractTensor tr = wl.load(e.getKey())) {
+
+                    boolean skipQ = false;
+                    if (skipLayerPrefixes != null) {
+                        for (String skipLayerPrefix : skipLayerPrefixes) {
+                            if (e.getKey().startsWith(skipLayerPrefix)) {
+                                skipQ = true;
+                                break;
+                            }
+                        }
+                    }
+
+                    AbstractTensor t = skipQ ? tr : tr.quantize(modelQuantization);
+
+                    switch (t.dType()) {
+                        case F32:
+                        case BF16:
+                        case F16:
+                            writtenInfo.put(e.getKey(), t.save(raf.getChannel()));
+                            break;
+                        case Q4:
+                            writtenInfo.put(e.getKey(), t.save(raf.getChannel()));
+                            writtenInfo.put(e.getKey() + ".qb", ((Q4ByteBufferTensor) t).getBlockF().save(raf.getChannel()));
+                            break;
+                        case Q5:
+                            writtenInfo.put(e.getKey(), t.save(raf.getChannel()));
+                            writtenInfo.put(e.getKey() + ".qb", ((Q5ByteBufferTensor) t).getBlockF().save(raf.getChannel()));
+                            //FIXME: Need to add b5 bits
+                            throw new UnsupportedOperationException("TODO");
+                            //break;
+                        case I8:
+                            writtenInfo.put(e.getKey(), t.save(raf.getChannel()));
+                            writtenInfo.put(e.getKey() + ".qb", ((Q8ByteBufferTensor) t).getBlockF().save(raf.getChannel()));
+                            break;
+                        default:
+                            throw new UnsupportedOperationException("" + t.dType() + " not implemented");
+                    }
+                }
+            }
+        }
+
+        //Now create the output file
+        String baseDirName = modelRoot.getName(modelRoot.getNameCount() - 1).toString();
+        Path parentPath = modelRoot.getParent();
+
+        Path qPath = outputRoot.orElseGet(() -> Paths.get(parentPath.toString(), baseDirName + "-jlama-" + modelQuantization.name()));
+        File qDir = qPath.toFile();
+        qDir.mkdirs();
+
+        //Copy config.json and tokenizer.json
+        Files.copy(modelRoot.resolve("config.json"), qPath.resolve("config.json"));
+        Files.copy(modelRoot.resolve("tokenizer.json"), qPath.resolve("tokenizer.json"));
+
+        try (RandomAccessFile raf = new RandomAccessFile(qPath.resolve("model.safetensors").toFile(), "rw")) {
+            FileChannel chan = raf.getChannel();
+
+            byte[] header = om.writeValueAsBytes(writtenInfo);
+            logger.debug("pos = {}", chan.position());
+            byte[] hsize = new byte[Long.BYTES];
+            ByteBuffer.wrap(hsize).order(ByteOrder.LITTLE_ENDIAN).putLong(header.length);
+            raf.write(hsize);
+            logger.debug("pos = {}", chan.position());
+            raf.write(header);
+            logger.debug("pos = {}", chan.position());
+
+            Files.copy(tmp.toPath(), new OutputStream() {
+                @Override
+                public void write(int b) throws IOException {
+                    raf.write(b);
+                }
+
+                @Override
+                public void write(byte[] b) throws IOException {
+                    raf.write(b);
+                }
+
+                @Override
+                public void write(byte[] b, int off, int len) throws IOException {
+                    raf.write(b, off, len);
+                }
+            });
+        }
+
+        return qPath;
+    }
 }
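Note: quantizeModel writes standard safetensors framing: an 8-byte little-endian header length, a JSON header mapping each tensor name to its TensorInfo, then the raw tensor bytes. A minimal sketch of reading that back with the patched readTensorInfoMap (the path is hypothetical):

    try (RandomAccessFile raf = new RandomAccessFile("model.safetensors", "r")) {
        ByteBuffer buf = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, raf.length());
        // consumes the length prefix and JSON header; the offsets recorded in
        // each TensorInfo are relative to the start of the data section
        Map<String, TensorInfo> info = SafeTensorSupport.readTensorInfoMap(buf, Optional.empty());
    }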
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/WeightLoader.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/WeightLoader.java
index eb24a78..96ccf80 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/WeightLoader.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/WeightLoader.java
@@ -1,8 +1,15 @@
 package com.github.tjake.jlama.safetensors;
 
+import java.util.Map;
+
 import com.github.tjake.jlama.tensor.AbstractTensor;
 
-public interface WeightLoader {
+public interface WeightLoader extends AutoCloseable {
+
+    Map<String, String> metadata();
+
+    Map<String, TensorInfo> tensorInfoMap();
+
     AbstractTensor load(String name);
 
     DType getModelDType();
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java
index 88b8195..f3fcffb 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java
@@ -4,6 +4,10 @@
 import com.github.tjake.jlama.tensor.AbstractTensor;
 import com.github.tjake.jlama.tensor.Float16BufferTensor;
 import com.github.tjake.jlama.tensor.FloatBufferTensor;
+import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
+import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
+
+import com.google.common.collect.ImmutableMap;
 import com.google.common.primitives.Ints;
 
 import java.nio.ByteBuffer;
@@ -16,20 +20,21 @@ public class Weights implements WeightLoader {
     private final Map<String, String> metadata;
     private final Map<String, TensorInfo> tensorInfoMap;
     private final ByteBuffer bytes;
-    private final DType dType;
 
+    private final DType majorityDType;
 
     Weights(Map<String, String> metadata, Map<String, TensorInfo> tensorInfoMap, ByteBuffer bytes) {
-        this.metadata = metadata;
-        this.tensorInfoMap = tensorInfoMap;
+        this.metadata = ImmutableMap.copyOf(metadata);
+        this.tensorInfoMap = ImmutableMap.copyOf(tensorInfoMap);
         this.bytes = bytes.duplicate();
-        this.dType = findDType();
+        this.majorityDType = findDType();
     }
 
     private DType findDType() {
         EnumMap<DType, Integer> counts = new EnumMap<>(DType.class);
-        for (TensorInfo info : tensorInfoMap.values()) {
-            counts.put(info.dType, counts.getOrDefault(info.dType, 0) + 1);
+        for (Map.Entry<String, TensorInfo> e : tensorInfoMap.entrySet()) {
+            if (!e.getKey().endsWith(".qb"))
+                counts.put(e.getValue().dType, counts.getOrDefault(e.getValue().dType, 0) + 1);
         }
 
         int max = 0;
@@ -45,6 +50,16 @@ private DType findDType() {
         return maxType == DType.BF16 || maxType == DType.F16 ? DType.F32 : maxType;
     }
 
+    @Override
+    public Map<String, String> metadata() {
+        return metadata;
+    }
+
+    @Override
+    public Map<String, TensorInfo> tensorInfoMap() {
+        return tensorInfoMap;
+    }
+
     @Override
     public AbstractTensor load(String name) throws NoSuchElementException {
         TensorInfo info = tensorInfoMap.get(name);
@@ -64,12 +79,12 @@ public AbstractTensor load(String name) throws NoSuchElementException {
         switch (info.dType) {
             case F32:
                 fb = b.asFloatBuffer().slice();
-                return new FloatBufferTensor(fb, info.shape, true);
+                return new FloatBufferTensor(name, fb, info.shape, true);
             case F16:
                 // If the majority of the weights are F32 then convert to F32
-                if (dType == DType.F32) {
+                if (majorityDType == DType.F32) {
                     len = b.remaining() / DType.F16.size();
-                    ByteBuffer bb = ByteBuffer.allocateDirect(len * DType.F32.size()).order(ByteOrder.LITTLE_ENDIAN);
+                    ByteBuffer bb = ByteBuffer.allocate(len * DType.F32.size()).order(ByteOrder.LITTLE_ENDIAN);
                     for (int i = 0; i < len * DType.F32.size(); i += DType.F32.size()) {
                         short s = b.getShort();
                         float v = Float.float16ToFloat(s);
@@ -78,7 +93,7 @@ public AbstractTensor load(String name) throws NoSuchElementException {
                     return new FloatBufferTensor(bb.asFloatBuffer(), info.shape, true);
                 } else {
                     sb = b.asShortBuffer().slice();
-                    return new Float16BufferTensor(sb, info.shape, true);
+                    return new Float16BufferTensor(name, sb, info.shape, true);
                 }
             case BF16:
                 //For now always convert to F32
@@ -89,7 +104,13 @@ public AbstractTensor load(String name) throws NoSuchElementException {
                     float v = FloatConversions.bFloat16ToFloat32(s);
                     fb.put(i, v);
                 }
-                return new FloatBufferTensor(fb, info.shape, true);
+                return new FloatBufferTensor(name, fb, info.shape, true);
+            case Q4:
+                FloatBufferTensor qb = (FloatBufferTensor) load(name + ".qb");
+                return new Q4ByteBufferTensor(name, b.slice(), qb, info.shape, true);
+            case I8:
+                FloatBufferTensor qb1 = (FloatBufferTensor) load(name + ".qb");
+                return new Q8ByteBufferTensor(name, b.slice(), qb1, info.shape, true);
             default:
                 throw new IllegalArgumentException("Unsupported Tensor type: " + info.dType.name() + " for " + name);
         }
@@ -97,7 +118,7 @@
 
     @Override
     public DType getModelDType() {
-        return dType;
+        return majorityDType;
     }
 
     @Override
@@ -122,4 +143,8 @@ public int hashCode() {
         return Objects.hash(metadata, tensorInfoMap);
     }
 
+    @Override
+    public void close() throws Exception {
+
+    }
 }
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java
index d086f51..a8f828e 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java
@@ -2,12 +2,20 @@
 import com.github.tjake.jlama.safetensors.DType;
 import com.google.common.base.Preconditions;
+
+import com.github.tjake.jlama.safetensors.TensorInfo;
 import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.Vector;
 import jdk.incubator.vector.VectorMask;
 import jdk.incubator.vector.VectorSpecies;
 
+import java.io.DataOutput;
+import java.io.IOException;
+import java.io.RandomAccessFile;
 import java.lang.foreign.MemorySegment;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
 import java.util.Arrays;
 
 /** A Tensor is a multi-dimensional array of data.
@@ -225,7 +233,7 @@ void setOwnerCache(TensorCache cache) {
 
     public AbstractTensor quantize(DType dType) {
-        if (this.dims() != 2)
+        if (this.dims() != 2 || this.dType == dType)
             return this;
 
         return switch (dType) {
@@ -237,6 +245,20 @@ public AbstractTensor quantize(DType dType) {
         };
     }
 
+    public TensorInfo save(FileChannel out) throws IOException {
+        ByteBuffer bb = getMemorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
+
+        long startOffset = out.position();
+
+        out.write(bb);
+
+        long[] lshape = new long[shape.length];
+        for (int i = 0; i < shape.length; i++)
+            lshape[i] = shape[i];
+
+        return new TensorInfo(dType, lshape, new long[]{startOffset, out.position()});
+    }
+
     public void debug(String id) {
         if (false) {
             double tmp = 0.0;
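Note: AbstractTensor.save() streams the tensor bytes to a channel and returns a TensorInfo recording the [start, end) offsets that end up in the JSON header. Quantized tensors persist their per-block scales as a companion entry under "<name>.qb", which findDType() now skips and Weights.load() resolves when rebuilding Q4/I8 tensors. A sketch of the write side (file name and input tensor hypothetical; assumes quantize(DType.I8) yields a Q8ByteBufferTensor, as quantizeModel's I8 branch expects):

    try (RandomAccessFile raf = new RandomAccessFile("weights.bin", "rw")) {
        FileChannel out = raf.getChannel();
        Q8ByteBufferTensor q = (Q8ByteBufferTensor) someF32Weight.quantize(DType.I8); // someF32Weight: any 2-D F32 tensor
        Map<String, TensorInfo> written = new HashMap<>();
        written.put("w", q.save(out));                 // raw int8 values
        written.put("w.qb", q.getBlockF().save(out));  // per-block scale factors
    }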
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java
index f778771..21ea740 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/BFloat16BufferTensor.java
@@ -13,6 +13,7 @@
 import com.github.tjake.jlama.math.FloatConversions;
 import com.github.tjake.jlama.safetensors.DType;
+import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
 import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.ShortVector;
 import jdk.incubator.vector.VectorSpecies;
@@ -37,12 +38,11 @@ public BFloat16BufferTensor(int ...shape) {
         super(DType.BF16, shape, true);
         this.name = "tmp";
         if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
-            this.segment = Arena.global().allocate(MemoryLayout.sequenceLayout(capacity, ValueLayout.JAVA_SHORT));
-            this.b = this.segment.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asShortBuffer();
+            this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(capacity * dType().size(), UnsafeDirectByteBuffer.CACHE_LINE_SIZE).asShortBuffer();
         } else {
             this.b = ShortBuffer.allocate(capacity);
-            this.segment = MemorySegment.ofBuffer(b);
         }
+        this.segment = MemorySegment.ofBuffer(b);
     }
 
     public BFloat16BufferTensor(ShortBuffer b, int[] shape, boolean cacheSlices, boolean mmapped) {
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Float16BufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Float16BufferTensor.java
index 88e055f..0cfc7fc 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Float16BufferTensor.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Float16BufferTensor.java
@@ -3,6 +3,8 @@
 import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
 import com.google.common.base.Preconditions;
+
+import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
 import jdk.incubator.vector.ShortVector;
 import jdk.incubator.vector.VectorSpecies;
@@ -33,19 +35,18 @@ public Float16BufferTensor(int ...shape) {
         super(DType.F16, shape, true);
         this.name = "tmp";
         if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
-            this.segment = Arena.global().allocate(MemoryLayout.sequenceLayout(capacity, ValueLayout.JAVA_SHORT));
-            this.b = this.segment.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asShortBuffer();
+            this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(capacity * dType().size(), UnsafeDirectByteBuffer.CACHE_LINE_SIZE).asShortBuffer();
         } else {
             this.b = ShortBuffer.allocate(capacity);
-            this.segment = MemorySegment.ofBuffer(b);
         }
+        this.segment = MemorySegment.ofBuffer(b);
     }
 
     public Float16BufferTensor(ShortBuffer b, int[] shape, boolean cacheSlices) {
         this("none", b, shape, cacheSlices);
     }
 
-    private Float16BufferTensor(String name, ShortBuffer b, int[] shape, boolean cacheSlices) {
+    public Float16BufferTensor(String name, ShortBuffer b, int[] shape, boolean cacheSlices) {
         super(DType.F16, shape, cacheSlices);
         Preconditions.checkArgument(b.isDirect(), "Must use direct buffers");
         this.name = name;
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java
index 15b34d4..a7df107 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java
@@ -6,16 +6,19 @@
 import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
 import com.google.common.base.Preconditions;
 
+import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
 import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.VectorSpecies;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import sun.nio.ch.DirectBuffer;
 
+import java.io.DataOutput;
 import java.lang.foreign.Arena;
 import java.lang.foreign.MemoryLayout;
 import java.lang.foreign.MemorySegment;
 import java.lang.foreign.ValueLayout;
+import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.FloatBuffer;
 import java.util.Arrays;
@@ -31,8 +34,7 @@
  *
  * The Tensor is thread safe for read operations, but not for write operations.
  */
-public final class FloatBufferTensor extends AbstractTensor
-{
+public final class FloatBufferTensor extends AbstractTensor {
     private static final Logger logger = LoggerFactory.getLogger(FloatBufferTensor.class);
     private final FloatBuffer b;
     private final String name;
@@ -52,23 +54,37 @@ public FloatBufferTensor(int ...shape) {
         super(DType.F32, shape, true);
         this.name = "tmp";
         if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
-            this.segment = Arena.global().allocate(MemoryLayout.sequenceLayout(capacity, ValueLayout.JAVA_FLOAT));
-            this.b = segment.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
+            this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(capacity * dType().size(), UnsafeDirectByteBuffer.CACHE_LINE_SIZE).asFloatBuffer();
         } else {
             this.b = FloatBuffer.allocate(capacity);
-            this.segment = MemorySegment.ofBuffer(b);
         }
+        this.segment = MemorySegment.ofBuffer(b);
     }
 
     public FloatBufferTensor(FloatBuffer b, int[] shape, boolean cacheSlices) {
         this("none", b, shape, cacheSlices);
     }
 
-    private FloatBufferTensor(String name, FloatBuffer b, int[] shape, boolean cacheSlices) {
+    public FloatBufferTensor(String name, FloatBuffer b, int[] shape, boolean cacheSlices) {
         super(DType.F32, shape, cacheSlices);
         this.name = name;
-        this.b = b;
-        this.segment = MemorySegment.ofBuffer(b);
+        if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
+            if (b.isDirect()) {
+                this.b = b;
+            } else {
+                this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(size() * dType().size(), UnsafeDirectByteBuffer.CACHE_LINE_SIZE).asFloatBuffer();
+                this.b.duplicate().put(b);
+            }
+        } else {
+            if (!b.isDirect()) {
+                this.b = b;
+            } else {
+                this.b = FloatBuffer.allocate(size());
+                this.b.duplicate().put(b);
+            }
+        }
+
+        this.segment = MemorySegment.ofBuffer(this.b);
     }
 
     @Override
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q4ByteBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q4ByteBufferTensor.java
index 69935d6..89c262f 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q4ByteBufferTensor.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q4ByteBufferTensor.java
@@ -4,6 +4,8 @@
 import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
 import com.google.common.base.Preconditions;
+
+import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
 import jdk.incubator.vector.ByteVector;
 import jdk.incubator.vector.ShortVector;
 import jdk.incubator.vector.VectorSpecies;
@@ -46,12 +48,10 @@ public Q4ByteBufferTensor(AbstractTensor ft)
         } while (ft.iterate(cursor));
 
         //Process each block in parallel
-        // ForkJoinPool.commonPool().submit(() ->
-            IntStream.range(0, startBlockCursors.size()).parallel().forEach((i) -> {
-                int[] blockStartCursor = startBlockCursors.get(i);
-                processBlock(ft, blockStartCursor);
-            });
-        //);
+        IntStream.range(0, startBlockCursors.size()).parallel().forEach((i) -> {
+            int[] blockStartCursor = startBlockCursors.get(i);
+            processBlock(ft, blockStartCursor);
+        });
     }
 
     void processBlock(AbstractTensor ft, int[] blockStartCursor) {
@@ -131,20 +131,34 @@ protected Q4ByteBufferTensor(int[] shape) {
         this.blockF = new FloatBufferTensor(makeBlockShape(shape));
         this.name = "tmp";
         if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
-            this.b = ByteBuffer.allocateDirect(this.size() / 2).order(ByteOrder.LITTLE_ENDIAN);
-            this.segment = MemorySegment.ofBuffer(b);
+            this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(this.size() / 2, UnsafeDirectByteBuffer.CACHE_LINE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
         } else {
             this.b = ByteBuffer.allocate(this.size() / 2).order(ByteOrder.LITTLE_ENDIAN);
-            this.segment = MemorySegment.ofBuffer(b);
         }
+
+        this.segment = MemorySegment.ofBuffer(b);
     }
 
     public Q4ByteBufferTensor(String name, ByteBuffer b, FloatBufferTensor blockF, int[] shape, boolean cacheSlices) {
         super(DType.Q4, shape, cacheSlices);
-        this.name = name;
-        this.b = b;
         this.blockF = blockF;
-        this.segment = MemorySegment.ofBuffer(b);
+        this.name = name;
+        if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
+            if (b.isDirect()) {
+                this.b = b;
+            } else {
+                this.b = ByteBuffer.allocateDirect(b.remaining()).order(ByteOrder.LITTLE_ENDIAN);
+                this.b.duplicate().put(b);
+            }
+        } else {
+            if (!b.isDirect()) {
+                this.b = b;
+            } else {
+                this.b = ByteBuffer.allocate(b.remaining()).order(ByteOrder.LITTLE_ENDIAN);
+                this.b.duplicate().put(b);
+            }
+        }
+        this.segment = MemorySegment.ofBuffer(this.b);
     }
 
     @Override
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q5ByteBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q5ByteBufferTensor.java
index 00f744b..cdc7d74 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q5ByteBufferTensor.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q5ByteBufferTensor.java
@@ -4,6 +4,8 @@
 import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
 import com.google.common.base.Preconditions;
+
+import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
 import jdk.incubator.vector.ByteVector;
 import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.VectorSpecies;
@@ -116,12 +118,12 @@ protected Q5ByteBufferTensor(int[] shape) {
         this.name = "tmp";
 
         if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
-            this.b = ByteBuffer.allocateDirect(this.size() / 2).order(ByteOrder.LITTLE_ENDIAN);
-            this.segment = MemorySegment.ofBuffer(b);
+            this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(capacity, UnsafeDirectByteBuffer.CACHE_LINE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
         } else {
             this.b = ByteBuffer.allocate(this.size() / 2).order(ByteOrder.LITTLE_ENDIAN);
-            this.segment = MemorySegment.ofBuffer(b);
         }
+
+        this.segment = MemorySegment.ofBuffer(b);
     }
 
     public Q5ByteBufferTensor(String name, ByteBuffer b, FloatBufferTensor blockF, int[] b5, int[] shape, boolean cacheSlices) {
@@ -139,6 +141,10 @@ protected AbstractTensor make(int... shape) {
         return new Q5ByteBufferTensor(shape);
     }
 
+    public FloatBufferTensor getBlockF() {
+        return blockF;
+    }
+
     @Override
     protected AbstractTensor make(int offset, int length, int[] shape, boolean cacheSlices) {
         FloatBufferTensor newBlockF = (FloatBufferTensor) this.blockF.make((int)(offset * I_BLOCK_SIZE), (int)(length * I_BLOCK_SIZE), makeBlockShape(shape), cacheSlices);
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q8ByteBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q8ByteBufferTensor.java
index 8ffa70d..fe55d59 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q8ByteBufferTensor.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/Q8ByteBufferTensor.java
@@ -4,6 +4,8 @@
 import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
 import com.google.common.base.Preconditions;
+
+import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
 import jdk.incubator.vector.ByteVector;
 import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.VectorMask;
@@ -96,20 +98,34 @@ public Q8ByteBufferTensor(int[] shape) {
         this.name = "tmp";
 
         if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
-            this.segment = Arena.global().allocate(MemoryLayout.sequenceLayout(capacity, ValueLayout.JAVA_BYTE));
-            this.b = segment.asByteBuffer();
+            this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(capacity, UnsafeDirectByteBuffer.CACHE_LINE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
         } else {
-            this.b = ByteBuffer.allocate(capacity);
-            this.segment = MemorySegment.ofBuffer(b);
+            this.b = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN);
         }
+
+        this.segment = MemorySegment.ofBuffer(b);
     }
 
-    private Q8ByteBufferTensor(String name, ByteBuffer b, FloatBufferTensor blockF, int[] shape, boolean cacheSlices) {
+    public Q8ByteBufferTensor(String name, ByteBuffer b, FloatBufferTensor blockF, int[] shape, boolean cacheSlices) {
         super(DType.I8, shape, cacheSlices);
         this.name = name;
-        this.b = b;
         this.blockF = blockF;
-        this.segment = MemorySegment.ofBuffer(b);
+        if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
+            if (b.isDirect()) {
+                this.b = b;
+            } else {
+                this.b = ByteBuffer.allocateDirect(b.remaining()).order(ByteOrder.LITTLE_ENDIAN);
+                this.b.duplicate().put(b);
+            }
+        } else {
+            if (!b.isDirect()) {
+                this.b = b;
+            } else {
+                this.b = ByteBuffer.allocate(b.remaining()).order(ByteOrder.LITTLE_ENDIAN);
+                this.b.duplicate().put(b);
+            }
+        }
+        this.segment = MemorySegment.ofBuffer(this.b);
     }
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
index 4cb94b0..79d74cc 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java
@@ -45,7 +45,7 @@ public String name() {
 
     @Override
     public boolean requiresOffHeapTensor() {
-        return true;
+        return false;
     }
 
     @Override
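Note: with the Panama backend now reporting requiresOffHeapTensor() == false, its tensors stay on the Java heap; backends that still want off-heap memory get cache-line-aligned direct buffers instead of Arena allocations. The constructors changed above all share this pattern (simplified sketch; sizing and byte order elided):

    int capacity = 1024; // example element count
    FloatBuffer b;
    if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
        b = UnsafeDirectByteBuffer
                .allocateAlignedByteBuffer(capacity * Float.BYTES, UnsafeDirectByteBuffer.CACHE_LINE_SIZE)
                .asFloatBuffer();
    } else {
        b = FloatBuffer.allocate(capacity); // heap-backed, no alignment guarantee
    }
    MemorySegment segment = MemorySegment.ofBuffer(b); // valid for both heap and direct buffers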
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperationsProvider.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperationsProvider.java
index 16e7c43..045c3de 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperationsProvider.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperationsProvider.java
@@ -28,10 +28,10 @@ public static TensorOperations get() {
     private final TensorOperations provider;
 
     private TensorOperationsProvider() {
-        this.provider = pickFastestImplementaion();
+        this.provider = pickFastestImplementation();
     }
 
-    private TensorOperations pickFastestImplementaion() {
+    private TensorOperations pickFastestImplementation() {
 
         TensorOperations pick = null;
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/UnsafeDirectByteBuffer.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/UnsafeDirectByteBuffer.java
new file mode 100644
index 0000000..c7a8091
--- /dev/null
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/UnsafeDirectByteBuffer.java
@@ -0,0 +1,97 @@
+package com.github.tjake.jlama.util;
+
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+import org.jctools.util.UnsafeAccess;
+
+public class UnsafeDirectByteBuffer
+{
+    private static final long addressOffset;
+    public static final int CACHE_LINE_SIZE = 64;
+
+    public static final int PAGE_SIZE = UnsafeAccess.UNSAFE.pageSize();
+
+    static {
+        try {
+            addressOffset = UnsafeAccess.UNSAFE.objectFieldOffset(Buffer.class
+                    .getDeclaredField("address"));
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public static long getAddress(ByteBuffer buffy) {
+        return UnsafeAccess.UNSAFE.getLong(buffy, addressOffset);
+    }
+
+    /**
+     * put byte and skip position update and boundary checks
+     */
+    public static void putByte(long address, int position, byte b) {
+        UnsafeAccess.UNSAFE.putByte(address + (position << 0), b);
+    }
+
+    public static void putByte(long address, byte b) {
+        UnsafeAccess.UNSAFE.putByte(address, b);
+    }
+
+    public static ByteBuffer allocateAlignedByteBuffer(int capacity, long align)
+    {
+        if (Long.bitCount(align) != 1) {
+            throw new IllegalArgumentException("Alignment must be a power of 2");
+        }
+        // We over allocate by the alignment so we know we can have a large
+        // enough aligned block of memory to use.
+        ByteBuffer buffy = ByteBuffer.allocateDirect((int) (capacity + align));
+        long address = getAddress(buffy);
+        if ((address & (align - 1)) == 0) {
+            // limit to the capacity specified
+            buffy.limit(capacity);
+            // set order to native while we are here.
+            ByteBuffer slice = buffy.slice().order(ByteOrder.nativeOrder());
+            // the slice is now an aligned buffer of the required capacity
+            return slice;
+        } else {
+            int newPosition = (int) (align - (address & (align - 1)));
+            buffy.position(newPosition);
+            int newLimit = newPosition + capacity;
+            // limit to the capacity specified
+            buffy.limit(newLimit);
+            // set order to native while we are here.
+            ByteBuffer slice = buffy.slice().order(ByteOrder.nativeOrder());
+            // the slice is now an aligned buffer of the required capacity
+            return slice;
+        }
+    }
+
+    public static boolean isPageAligned(ByteBuffer buffy) {
+        return isPageAligned(getAddress(buffy));
+    }
+
+    /**
+     * This assumes cache line is 64b
+     */
+    public static boolean isCacheAligned(ByteBuffer buffy) {
+        return isCacheAligned(getAddress(buffy));
+    }
+
+    public static boolean isPageAligned(long address) {
+        return (address & (PAGE_SIZE - 1)) == 0;
+    }
+
+    /**
+     * This assumes cache line is 64b
+     */
+    public static boolean isCacheAligned(long address) {
+        return (address & (CACHE_LINE_SIZE - 1)) == 0;
+    }
+
+    public static boolean isAligned(long address, long align) {
+        if (Long.bitCount(align) != 1) {
+            throw new IllegalArgumentException("Alignment must be a power of 2");
+        }
+        return (address & (align - 1)) == 0;
+    }
+}
\ No newline at end of file
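Note: a hypothetical usage of the new helper, checking the two properties the tensor code relies on:

    ByteBuffer buf = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(1 << 20, UnsafeDirectByteBuffer.CACHE_LINE_SIZE);
    assert UnsafeDirectByteBuffer.isCacheAligned(buf);
    assert buf.order() == ByteOrder.nativeOrder(); // the aligned slice is returned in native order

Because tensor data can now live on the heap or sit at arbitrary offsets within a buffer, the vector_simd.c hunks below also stop assuming aligned pointers, replacing the aligned _mm*_load_ps/_mm_load_si128 intrinsics with their unaligned _mm*_loadu_* variants.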
diff --git a/jlama-native/src/main/c/vector_simd.c b/jlama-native/src/main/c/vector_simd.c
index dfc2e5f..a146c24 100644
--- a/jlama-native/src/main/c/vector_simd.c
+++ b/jlama-native/src/main/c/vector_simd.c
@@ -28,7 +28,7 @@ float dot_product_f32_q8_256(const float* a, int aoffset, const float *bf, const
         __m256 vb_f32 = _mm256_set1_ps(*(bf + bf_idx));
 
         // Load float32
-        __m256 va = _mm256_load_ps(a + ao);
+        __m256 va = _mm256_loadu_ps(a + ao);
 
         // Load 8 bytes into a 128-bit integer register
         __m128i int_vb = _mm_loadu_si128((__m128i const*)(b + bo));
@@ -74,7 +74,7 @@ float dot_product_f32_q8_512(const float* a, int aoffset, const float *bf, const
         __m512 vb_f32 = _mm512_set1_ps(*(bf + bf_idx));
 
         // Load float32
-        __m512 va = _mm512_load_ps(a + ao);
+        __m512 va = _mm512_loadu_ps(a + ao);
 
         // Load 16 bytes into a 256-bit integer register
         __m128i int_vb = _mm_loadu_si128((__m128i const*)(b + bo));
@@ -129,8 +129,8 @@ float dot_product_f32_q8_128(const float* a, int aoffset, const float *bf, const
         __m128 vb_f32 = _mm_set1_ps(*(bf + bf_idx));
 
         // Load float32
-        __m128 va0 = _mm_load_ps(a + ao);
-        __m128 va1 = _mm_load_ps(a + ao + 4);
+        __m128 va0 = _mm_loadu_ps(a + ao);
+        __m128 va1 = _mm_loadu_ps(a + ao + 4);
 
         // Load 8 bytes into a 128-bit integer register
         __m128i int_vb0 = _mm_loadu_si32((__m128i const*)(b + bo));
@@ -309,10 +309,10 @@ float dot_product_f32_q4_256(const float* a, int aoffset, const float *bf, const
         __m256 vb_f32 = _mm256_set1_ps(*(bf + bf_idx));
 
         // Load float32
-        __m256 va0 = _mm256_load_ps(a + ao);
-        __m256 va1 = _mm256_load_ps(a + ao + 8);
-        __m256 va2 = _mm256_load_ps(a + ao + 8 + 8);
-        __m256 va3 = _mm256_load_ps(a + ao + 8 + 8 + 8);
+        __m256 va0 = _mm256_loadu_ps(a + ao);
+        __m256 va1 = _mm256_loadu_ps(a + ao + 8);
+        __m256 va2 = _mm256_loadu_ps(a + ao + 8 + 8);
+        __m256 va3 = _mm256_loadu_ps(a + ao + 8 + 8 + 8);
 
         // Load 8 bytes into a 128-bit integer register
         __m128i int_vb0 = _mm_loadl_epi64((__m128i const*)(b + bo)); // Load lower 64 bits
@@ -395,11 +395,11 @@ float dot_product_f32_q4_512(const float* a, int aoffset, const float *bf, const
         __m512 vb_f32 = _mm512_set1_ps(*(bf + bf_idx));
 
         // Load float32
-        __m512 va0 = _mm512_load_ps(a + ao);
-        __m512 va1 = _mm512_load_ps(a + ao + 16);
+        __m512 va0 = _mm512_loadu_ps(a + ao);
+        __m512 va1 = _mm512_loadu_ps(a + ao + 16);
 
         // Load 8 bytes into a 128-bit integer register
-        __m128i int_vb0 = _mm_load_si128((__m128i const*)(b + bo)); // Load 128 bits
+        __m128i int_vb0 = _mm_loadu_si128((__m128i const*)(b + bo)); // Load 128 bits
 
         // Mask to keep the first 4 bits of each byte
         __m128i mask_first_4bits = _mm_set1_epi8(0xF);
@@ -473,14 +473,14 @@ float dot_product_f32_q4_128(const float* a, int aoffset, const float *bf, const
         __m128 vb_f32 = _mm_set1_ps(*(bf + bf_idx));
 
         // Load float32
-        __m128 va0 = _mm_load_ps(a + ao);
-        __m128 va1 = _mm_load_ps(a + ao + 4);
-        __m128 va2 = _mm_load_ps(a + ao + 4 + 4);
-        __m128 va3 = _mm_load_ps(a + ao + 4 + 4 + 4);
-        __m128 va4 = _mm_load_ps(a + ao + 4 + 4 + 4 + 4);
-        __m128 va5 = _mm_load_ps(a + ao + 4 + 4 + 4 + 4 + 4);
-        __m128 va6 = _mm_load_ps(a + ao + 4 + 4 + 4 + 4 + 4 + 4);
-        __m128 va7 = _mm_load_ps(a + ao + 4 + 4 + 4 + 4 + 4 + 4 + 4);
+        __m128 va0 = _mm_loadu_ps(a + ao);
+        __m128 va1 = _mm_loadu_ps(a + ao + 4);
+        __m128 va2 = _mm_loadu_ps(a + ao + 4 + 4);
+        __m128 va3 = _mm_loadu_ps(a + ao + 4 + 4 + 4);
+        __m128 va4 = _mm_loadu_ps(a + ao + 4 + 4 + 4 + 4);
+        __m128 va5 = _mm_loadu_ps(a + ao + 4 + 4 + 4 + 4 + 4);
+        __m128 va6 = _mm_loadu_ps(a + ao + 4 + 4 + 4 + 4 + 4 + 4);
+        __m128 va7 = _mm_loadu_ps(a + ao + 4 + 4 + 4 + 4 + 4 + 4 + 4);
 
         // Load 8 bytes into a 128-bit integer register
         __m128i int_vb0 = _mm_loadu_si32((__m128i const*)(b + bo));
@@ -712,7 +712,7 @@ float dot_product_q8_q4_512(const float *af, const char* a, int aoffset, const f
             __m512 scale_f32 = _mm512_set1_ps(scalef[j]);
 
             // Load 8 bytes into a 128-bit integer register
-            __m128i int_vb0 = _mm_load_si128((__m128i const*)(b + bo)); // Load 128 bits
+            __m128i int_vb0 = _mm_loadu_si128((__m128i const*)(b + bo)); // Load 128 bits
 
             // Masked values
             __m128i first_4bits0 = _mm_and_si128(int_vb0, mask_first_4bits);
@@ -730,8 +730,8 @@ float dot_product_q8_q4_512(const float *af, const char* a, int aoffset, const f
             __m512i int_vb_ext_hi0 = _mm512_cvtepi8_epi32(last_4bits0);
 
             // Load 16 bytes into 2 128-bit integer registers
-            __m128i int_va0 = _mm_load_si128((__m128i const*)(a + ao));
-            __m128i int_va1 = _mm_load_si128((__m128i const*)(a + ao + 16));
+            __m128i int_va0 = _mm_loadu_si128((__m128i const*)(a + ao));
+            __m128i int_va1 = _mm_loadu_si128((__m128i const*)(a + ao + 16));
 
             //Extend to 32-bit ints
             __m512i int_va0_ext = _mm512_cvtepi8_epi32(int_va0);
diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java
index 98928a0..31e4b75 100644
--- a/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java
+++ b/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java
@@ -15,6 +15,7 @@
 import com.github.tjake.jlama.model.llama.LlamaTokenizer;
 import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
+import org.junit.Assert;
 import org.junit.Assume;
 import org.junit.Test;
 import org.slf4j.Logger;
@@ -26,6 +27,8 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiConsumer;
@@ -66,15 +69,52 @@ public void GPT2Run() throws IOException {
     public void LlamaRun() throws Exception {
         String modelPrefix = "models/Llama-2-7b-chat-hf";
         Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
-        try (SafeTensorIndex weights = SafeTensorIndex.loadWithWeights(Path.of(modelPrefix))) {
+        try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) {
             LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
             Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
-            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.of(DType.Q4));
+            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.empty());
 
             String prompt = "Simply put, the theory of relativity states that";
             model.generate(prompt, 0.7f, 256, false, makeOutHandler());
         }
     }
+
+    @Test
+    public void testQuantize() throws Exception {
+        String modelPrefix = "models/Llama-2-7b-chat-hf";
+        Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
+
+        Path tmpOut = Files.createTempDirectory("jltest");
+        try
+        {
+            Path out = SafeTensorSupport.quantizeModel(Paths.get(modelPrefix), DType.Q4, new String[]{
+                    "model.embed_tokens.weight", "lm_head.weight",
+            }, Optional.of(tmpOut));
+
+            Assert.assertEquals(tmpOut, out);
+
+            WeightLoader weights = SafeTensorSupport.loadWeights(tmpOut.toFile());
+            LlamaTokenizer tokenizer = new LlamaTokenizer(tmpOut);
+            Config c = om.readValue(new File(tmpOut + "/config.json"), LlamaConfig.class);
+            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.empty());
+
+            String prompt = "Lily picked up a flower and gave it to";
+            model.generate(prompt, 0.7f, 128, false, makeOutHandler());
+        }
+        finally
+        {
+            Arrays.stream(Objects.requireNonNull(tmpOut.toFile().listFiles())).forEach(f -> {
+                try {
+                    Files.delete(f.toPath());
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+            });
+
+            Files.deleteIfExists(tmpOut);
+        }
+    }
+
     @Test
     public void TinyLlamaRun() throws Exception {
         String modelPrefix = "models/TinyLLama";