From f7219800af2480064ab707119e3099f3f4ac790d Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Sat, 21 Oct 2023 21:31:36 -0400 Subject: [PATCH] Fix cli as constructor changed --- .../tjake/jlama/cli/commands/ModelBaseCommand.java | 5 +++-- .../com/github/tjake/jlama/model/AbstractModel.java | 4 ---- .../com/github/tjake/jlama/model/bert/BertModel.java | 4 ++-- .../com/github/tjake/jlama/model/gpt2/GPT2Model.java | 4 ++-- .../github/tjake/jlama/model/llama/LlamaModel.java | 8 ++++---- .../com/github/tjake/jlama/safetensors/Weights.java | 2 +- .../tensor/operations/NativeTensorOperations.java | 12 ++++++------ .../com/github/tjake/jlama/models/TestModels.java | 9 +++++---- 8 files changed, 23 insertions(+), 25 deletions(-) diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ModelBaseCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ModelBaseCommand.java index 4dd422c..6c94fb9 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ModelBaseCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ModelBaseCommand.java @@ -6,6 +6,7 @@ import java.lang.reflect.InvocationTargetException; import java.nio.file.Path; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; @@ -91,8 +92,8 @@ protected AbstractModel loadModel(File model) { Tokenizer t = modelType.tokenizerClass.getConstructor(Path.class).newInstance(baseDir.toPath()); WeightLoader wl = SafeTensorSupport.loadWeights(baseDir); - return modelType.modelClass.getConstructor(Config.class, WeightLoader.class, Tokenizer.class, DType.class, DType.class) - .newInstance(c, wl, t, workingMemoryType, workingQuantizationType); + return modelType.modelClass.getConstructor(Config.class, WeightLoader.class, Tokenizer.class, DType.class, DType.class, Optional.class) + .newInstance(c, wl, t, workingMemoryType, workingQuantizationType, 
Optional.ofNullable(modelQuantization)); } catch (IOException | NoSuchMethodException | InvocationTargetException | InstantiationException | IllegalAccessException e) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 1bff3ef..bdbe17d 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -37,10 +37,6 @@ public abstract class AbstractModel { protected final Optional modelQType; private static final ThreadLocal tmpArray = new ThreadLocal<>(); - protected AbstractModel(Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType) - { - this(c, w, t, workingMemoryDType, workingMemoryQType, Optional.empty()); - } protected AbstractModel(Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType, Optional modelQType) { this.c = c; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java index 4dab730..292f676 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java @@ -22,8 +22,8 @@ public class BertModel extends AbstractModel { private final TransformerBlock[] transformerBlocks; - public BertModel(Config c, Weights w, Tokenizer tokenizer, DType workingDType, DType workingQType) { - super(c, w, tokenizer, workingDType, workingQType); + public BertModel(Config c, Weights w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional modelQType) { + super(c, w, tokenizer, workingDType, workingQType, modelQType); this.we = w.load("embeddings.word_embeddings.weight"); this.wte = w.load("embeddings.token_type_embeddings.weight"); diff --git 
a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java index 980fd22..3614763 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Model.java @@ -15,8 +15,8 @@ public class GPT2Model extends AbstractModel { private final TransformerBlock[] transformerBlocks; - public GPT2Model(Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType) { - super(c, w, tokenizer, workingDType, workingQType); + public GPT2Model(Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional modelQType) { + super(c, w, tokenizer, workingDType, workingQType, modelQType); this.wte = w.load("wte.weight"); this.wpe = w.load("wpe.weight"); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java index 4a71e95..0acf422 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java @@ -29,12 +29,12 @@ public class LlamaModel extends AbstractModel { private final AbstractTensor classificationWeights; - public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType, DType modelQType) { - super(config, weights, tokenizer, workingDType, workingQType, Optional.ofNullable(modelQType)); + public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional modelQType) { + super(config, weights, tokenizer, workingDType, workingQType, modelQType); - DType qType = modelQType != null ? 
modelQType : this.modelDType; + DType qType = modelQType.orElse(this.modelDType); - if (modelQType != this.modelDType) { + if (qType != this.modelDType) { logger.info("Quantizing model with {} - Please hold...", qType); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java index d6924a9..88b8195 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java @@ -42,7 +42,7 @@ private DType findDType() { } //FIXME don't really support B16 atm - return maxType == DType.BF16 ? DType.F32 : maxType; + return maxType == DType.BF16 || maxType == DType.F16 ? DType.F32 : maxType; } @Override diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java index 037b9ad..0a8bc76 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java @@ -70,14 +70,14 @@ public float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int bof case F32 -> NativeSimd.dot_product_f32(flags, a.getMemorySegment(), aoffset, b.getMemorySegment(), boffset, limit); case I8 -> NativeSimd.dot_product_f32_q8(flags, a.getMemorySegment(), aoffset, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit); case Q4 -> NativeSimd.dot_product_f32_q4(flags, a.getMemorySegment(), aoffset, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit); - default -> throw new UnsupportedOperationException(); + default -> throw new UnsupportedOperationException(b.dType().name()); }; case I8 -> switch (b.dType()) { case Q4 -> 
NativeSimd.dot_product_q8_q4(flags, ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), aoffset, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit); //case I8 -> NativeSimd.dot_product_q8(flags, ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), aoffset, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit); - default -> throw new UnsupportedOperationException(); + default -> throw new UnsupportedOperationException(b.dType().name()); }; - default -> throw new UnsupportedOperationException(); + default -> throw new UnsupportedOperationException(a.dType().name()); }; } @@ -88,16 +88,16 @@ public void dotProductChunk(AbstractTensor r, AbstractTensor a, AbstractTensor b case F32: NativeSimd.dot_product_f32_chunked(flags, r.getMemorySegment(), a.getMemorySegment(), 0, b.getMemorySegment(), 0, limit, chunkStart, chunkSize); break; case I8: NativeSimd.dot_product_f32_q8_chunked(flags, r.getMemorySegment(), a.getMemorySegment(), 0, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), 0, limit, chunkStart, chunkSize); break; case Q4: NativeSimd.dot_product_f32_q4_chunked(flags, r.getMemorySegment(), a.getMemorySegment(), 0, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), 0, limit, chunkStart, chunkSize); break; - default: throw new UnsupportedOperationException(); + default: throw new UnsupportedOperationException(b.dType().name()); } break; case I8: switch (b.dType()) { case Q4: NativeSimd.dot_product_q8_q4_chunked(flags, r.getMemorySegment(), ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), 0, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), 0, limit, chunkStart, chunkSize); break; - default: throw new UnsupportedOperationException(); + default: throw new UnsupportedOperationException(b.dType().name()); } break; - default: 
throw new UnsupportedOperationException(); + default: throw new UnsupportedOperationException(a.dType().name()); } } diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java index 993589a..aae7af6 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/models/TestModels.java @@ -25,6 +25,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; @@ -51,7 +52,7 @@ public void GPT2Run() throws IOException { Weights v = SafeTensorSupport.readWeights(bb); Tokenizer tokenizer = new GPT2Tokenizer(Paths.get(modelPrefix)); Config c = om.readValue(new File(modelPrefix + "/config.json"), GPT2Config.class); - GPT2Model gpt2 = new GPT2Model(c, v, tokenizer, DType.F32, DType.F32); + GPT2Model gpt2 = new GPT2Model(c, v, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32)); String prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, " + "previously unexplored valley, in the Andes Mountains. 
" + @@ -67,7 +68,7 @@ public void LlamaRun() throws Exception { try (SafeTensorIndex weights = SafeTensorIndex.loadWithWeights(Path.of(modelPrefix))) { LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class); - LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, DType.Q4); + LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.of(DType.Q4)); String prompt = "Simply put, the theory of relativity states that"; model.generate(prompt, 0.7f, 256, false, makeOutHandler()); } @@ -84,7 +85,7 @@ public void TinyLlamaRun() throws Exception { Weights weights = SafeTensorSupport.readWeights(bb); LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class); - LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, DType.F32); + LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32)); String prompt = "Lily picked up a flower and gave it to"; model.generate(prompt, 0.7f, 128, false, makeOutHandler()); @@ -102,7 +103,7 @@ public void BertRun() throws Exception { Weights weights = SafeTensorSupport.readWeights(bb); Tokenizer tokenizer = new BertTokenizer(Paths.get(modelPrefix)); Config c = om.readValue(new File(modelPrefix + "/config.json"), BertConfig.class); - BertModel model = new BertModel(c, weights, tokenizer, DType.F32, DType.F32); + BertModel model = new BertModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32)); String base = "A man is eating food."; String[] examples = new String[]{