diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
index 25a3004..9f9585a 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
@@ -33,6 +33,7 @@ import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
 import com.github.tjake.jlama.util.DebugSupport;
 import com.github.tjake.jlama.util.JsonSupport;
+import com.github.tjake.jlama.util.MachineSpec;
 import com.google.common.base.Preconditions;
 import com.google.common.primitives.Ints;
@@ -44,6 +45,7 @@ import java.util.function.Consumer;
 import java.util.stream.Collectors;
+import jdk.incubator.vector.FloatVector;
 import net.jafama.FastMath;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -121,6 +123,13 @@ protected AbstractModel(
             workingMemoryQType = DType.BF16;
         }

+        // Check to make sure the model is big enough to support Q4I8 computations
+        // If not, fall back to F32
+        if (modelDType == DType.Q4 && workingMemoryQType == DType.I8 &&
+            (c.embeddingLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0) {
+            workingMemoryQType = DType.F32;
+        }
+
         if (workingMemoryQType != workingMemoryDType) {
             boolean supportsQType;
             AbstractTensor tmp = makeDenseTensor(Q8ByteBufferTensor.BLOCK_SIZE);
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/ModelSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/ModelSupport.java
index c3f0efa..bd1aa5f 100644
--- a/jlama-core/src/main/java/com/github/tjake/jlama/model/ModelSupport.java
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/ModelSupport.java
@@ -33,6 +33,9 @@ import com.github.tjake.jlama.model.mistral.MistralModel;
 import com.github.tjake.jlama.model.mixtral.MixtralConfig;
 import com.github.tjake.jlama.model.mixtral.MixtralModel;
+import com.github.tjake.jlama.model.qwen2.Qwen2Config;
+import com.github.tjake.jlama.model.qwen2.Qwen2Model;
+import com.github.tjake.jlama.model.qwen2.Qwen2Tokenizer;
 import com.github.tjake.jlama.safetensors.Config;
 import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.SafeTensorSupport;
@@ -57,7 +60,8 @@ public enum ModelType {
     MIXTRAL(MixtralModel.class, MixtralConfig.class, LlamaTokenizer.class),
     LLAMA(LlamaModel.class, LlamaConfig.class, LlamaTokenizer.class),
     GPT2(GPT2Model.class, GPT2Config.class, GPT2Tokenizer.class),
-    BERT(BertModel.class, BertConfig.class, BertTokenizer.class);
+    BERT(BertModel.class, BertConfig.class, BertTokenizer.class),
+    QWEN2(Qwen2Model.class, Qwen2Config.class, Qwen2Tokenizer.class);

     public final Class modelClass;
     public final Class configClass;
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Config.java
new file mode 100644
index 0000000..33661ab
--- /dev/null
+++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Config.java
@@ -0,0 +1,44 @@
+package com.github.tjake.jlama.model.qwen2;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.github.tjake.jlama.math.ActivationFunction;
+import com.github.tjake.jlama.safetensors.Config;
+
+import java.util.List;
+
+public class Qwen2Config extends Config {
+
+    @JsonCreator
+    public Qwen2Config(
@JsonProperty("max_position_embeddings") int contextLength, + @JsonProperty("hidden_size") int embeddingLength, + @JsonProperty("intermediate_size") int hiddenLength, + @JsonProperty("num_attention_heads") int numberOfHeads, + @JsonProperty("num_key_value_heads") int numberOfKeyValueHeads, + @JsonProperty("num_hidden_layers") int numberOfLayers, + @JsonProperty("rms_norm_eps") float layerNormEps, + @JsonProperty("vocab_size") int vocabularySize, + @JsonProperty("bos_token_id") int bosToken, + @JsonProperty("eos_token_id") int eosToken, + @JsonProperty("hidden_act") ActivationFunction.Type activationFunction, + @JsonProperty("rope_theta") Double ropeTheta) { + super( + contextLength, + embeddingLength, + hiddenLength, + numberOfHeads, + numberOfKeyValueHeads, + numberOfLayers, + layerNormEps, + vocabularySize, + bosToken, + List.of(eosToken), + activationFunction, + ropeTheta, + 1.0, + null, + embeddingLength / numberOfHeads + ); + } +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Model.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Model.java new file mode 100644 index 0000000..9e6e3ea --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Model.java @@ -0,0 +1,98 @@ +package com.github.tjake.jlama.model.qwen2; + +import com.github.tjake.jlama.model.*; +import com.github.tjake.jlama.model.llama.LlamaModel; +import com.github.tjake.jlama.safetensors.Config; +import com.github.tjake.jlama.safetensors.DType; +import com.github.tjake.jlama.safetensors.WeightLoader; +import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Optional; +import java.util.stream.IntStream; + +public class Qwen2Model extends LlamaModel { + + private static final Logger logger = LoggerFactory.getLogger(Qwen2Model.class); + + public Qwen2Model( + Config config, + WeightLoader weights, + Tokenizer tokenizer, + DType workingDType, + DType workingQType, + Optional modelQType + ) { + super(config, weights, tokenizer, workingDType, workingQType, modelQType); + } + + public Qwen2Model( + InferenceType inferenceType, + Config config, + WeightLoader weights, + Tokenizer tokenizer, + DType workingDType, + DType workingQType, + Optional modelQType + ) { + super(inferenceType, config, weights, tokenizer, workingDType, workingQType, modelQType); + } + + @Override + protected TransformerBlock[] loadTransformerBlockWeights() { + DType qType = modelQType.orElse(this.modelDType); + if (qType != this.modelDType) { + logger.info("Quantizing model with {} - Please hold...", qType); + } + + TransformerBlock[] transformerBlocks = new TransformerBlock[c.dctx().numberOfLayers]; + + IntStream.range(c.dctx().layerStart, c.dctx().layerEnd).parallel().forEach(i -> { + + int relativeLayer = i - c.dctx().layerStart; // FIXME: add a helper to the context + + String base = "model.layers." 
+ i + "."; + String prefix = base + "self_attn."; + CausalSelfAttention attention = new CausalSelfAttention( + this, + relativeLayer, + Optional.of(weights.load(prefix + "q_proj.bias").quantize(qType)), + Optional.of(weights.load(prefix + "k_proj.bias").quantize(qType)), + Optional.of(weights.load(prefix + "v_proj.bias").quantize(qType)), + weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType), + weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType), + weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType), + Optional.empty(), + weights.load(prefix + "o_proj.weight", c.dctx(), false, true).quantize(qType) + ); + + prefix = base + "mlp."; + + MLPBlock mlp = new MLPBlock( + this, + c.activationFunction, + weights.load(prefix + "gate_proj.weight", c.dctx(), true, false).quantize(qType), // w1 + weights.load(prefix + "down_proj.weight", c.dctx(), false, true).quantize(qType), // w2 + weights.load(prefix + "up_proj.weight", c.dctx(), true, false).quantize(qType) + ); // w3 + + transformerBlocks[relativeLayer] = new TransformerBlock( + this, + relativeLayer, + new RMSNorm(this, weights.load(base + "input_layernorm.weight").quantize(qType)), + attention, + new RMSNorm(this, weights.load(base + "post_attention_layernorm.weight").quantize(qType)), + mlp + ); + }); + + return transformerBlocks; + } + + + @Override + public ModelSupport.ModelType getModelType() { + return ModelSupport.ModelType.QWEN2; + } +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Tokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Tokenizer.java new file mode 100644 index 0000000..76be12c --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/qwen2/Qwen2Tokenizer.java @@ -0,0 +1,53 @@ +package com.github.tjake.jlama.model.qwen2; + +import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer; + +import java.nio.file.Path; +import java.util.Optional; +import java.util.stream.Collectors; + +public class Qwen2Tokenizer extends BPETokenizer { + + + public Qwen2Tokenizer(Path modelRoot) { + super(modelRoot); + } + + @Override + protected String preProcess(String sentence) { + if (model.normalizer() != null) sentence = model.normalizer().normalize(sentence); + + if (model.isLegacy() && !model.byteFallback) { + sentence = sentence.codePoints() + .map(c -> alteredBytes.getOrDefault(c, c)) + .mapToObj(Character::toString) + .collect(Collectors.joining()); + } + + return sentence; + } + + @Override + protected long encodeCharacterAsToken(byte c) { + return Byte.toUnsignedLong(c); + } + + @Override + protected Optional maybeDecodeTokenAsCharacter(long id) { + return Optional.empty(); + } + + @Override + protected String postProcessToken(String decoded) { + if (decoded == null) decoded = model.unkToken; + + if (model.isLegacy() && !model.byteFallback) { + decoded = decoded.codePoints() + .map(c -> alteredBytes.inverse().getOrDefault(c, c)) + .mapToObj(Character::toString) + .collect(Collectors.joining()); + } + + return decoded; + } +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java index 21044f3..a3b2783 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java @@ -104,7 +104,7 @@ protected AbstractTensor make(int offset, int length, TensorShape 
     public float get(int... dims) {
         Preconditions.checkArgument(dims.length <= shape.dims(), "Too many dimensions specified");
         Preconditions.checkArgument(dims.length == shape.dims(), "Must specify all dimensions");
-        return b.hasArray() ? b.array()[b.arrayOffset() + getOffset(dims)] : b.get(getOffset(dims));
+        return b.get(getOffset(dims));
     }

     @Override
@@ -136,8 +136,6 @@ public int getMemorySegmentOffset(int offset) {
     @Override
     public FloatVector getVector(VectorSpecies<Float> species, int... voffset) {
         int offset = getOffset(voffset);
-        if (b.hasArray()) return FloatVector.fromArray(species, b.array(), offset);
-
         return FloatVector.fromMemorySegment(species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
     }
@@ -145,9 +143,7 @@ public FloatVector getVector(VectorSpecies<Float> species, int... voffset) {
     public void intoTensor(FloatVector vector, int... aoffset) {
         // Preconditions.checkArgument(!b.isReadOnly());
         int offset = getOffset(aoffset);
-
-        if (b.hasArray()) vector.intoArray(b.array(), offset);
-        else vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
+        vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
     }

     @Override
diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java
index a2c327d..84dc397 100644
--- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java
+++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java
@@ -49,7 +49,7 @@ public class TestModels {
     static {
         System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0");
-        // System.setProperty("jlama.force_panama_tensor_operations", "true");
+        //System.setProperty("jlama.force_panama_tensor_operations", "true");
     }

     private static final Logger logger = LoggerFactory.getLogger(TestModels.class);
@@ -69,6 +69,18 @@ public void GPT2Run() throws IOException {
         gpt2.generate(UUID.randomUUID(), prompt, 0.8f, 256, makeOutHandler());
     }

+    @Test
+    public void Qwen2Run() throws IOException {
+        String modelPrefix = "../models/Qwen_Qwen2.5-0.5B-Instruct-JQ4";
+        Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
+
+        AbstractModel qwen2 = ModelSupport.loadModel(new File(modelPrefix), DType.F32, DType.I8);
+        PromptContext prompt = qwen2.promptSupport().get().builder().addUserMessage("What is the capital of France?").build();
+
+        Generator.Response r = qwen2.generate(UUID.randomUUID(), prompt, 0.9f, 1024, makeOutHandler());
+        logger.info("Response: {}", r);
+    }
+
     @Test
     public void LlamaRun() throws Exception {
         String modelPrefix = "../models/tjake_Llama-3.2-1B-Instruct-Jlama-Q4";
@@ -76,7 +88,7 @@ public void LlamaRun() throws Exception {
         try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) {
             LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
             Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
-            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.empty());
+            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.empty());
             PromptSupport.Builder builder = model.promptSupport().get().builder();
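For reference, a minimal standalone usage sketch of the new Qwen2 support, mirroring the Qwen2Run test above. The only calls taken from this diff are ModelSupport.loadModel, promptSupport().get().builder()...build(), and generate(); the model path, the import locations for Generator and PromptContext, and the token-handler lambda are assumptions, not part of the patch.

// Hypothetical usage sketch (not part of this diff).
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.functions.Generator;          // assumed package
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;   // assumed package

import java.io.File;
import java.util.UUID;

public class Qwen2Example {
    public static void main(String[] args) {
        // Assumed local path to a JQ4-quantized Qwen2 checkpoint, same layout as the test above
        File modelDir = new File("../models/Qwen_Qwen2.5-0.5B-Instruct-JQ4");

        // F32 working memory with I8 quantized compute; per the AbstractModel change above,
        // Q4/I8 falls back to F32 when the embedding size is too small for the vector width.
        AbstractModel qwen2 = ModelSupport.loadModel(modelDir, DType.F32, DType.I8);

        PromptContext prompt = qwen2.promptSupport()
                .get()
                .builder()
                .addUserMessage("What is the capital of France?")
                .build();

        // The (token, timing) handler shape is an assumption; the test uses makeOutHandler().
        Generator.Response r = qwen2.generate(UUID.randomUUID(), prompt, 0.9f, 1024,
                (token, time) -> System.out.print(token));
        System.out.println();
        System.out.println(r);
    }
}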