Skip to content

Commit

Permalink
Merge pull request #82 from tjake/qwen2
Browse files Browse the repository at this point in the history
Add Qwen2 support and fix bug with small models using I8Q4
  • Loading branch information
tjake authored Oct 20, 2024
2 parents 47e7a92 + a806fd7 commit 1754cf2
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.MachineSpec;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;

Expand All @@ -44,6 +45,7 @@
import java.util.function.Consumer;
import java.util.stream.Collectors;

import jdk.incubator.vector.FloatVector;
import net.jafama.FastMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -121,6 +123,13 @@ protected AbstractModel(
workingMemoryQType = DType.BF16;
}

// Check to make sure the model is big enough to support Q4I8 computations
// If not, fall back to F32
if (modelDType == DType.Q4 && workingMemoryQType == DType.I8 &&
(c.embeddingLength/Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize()/Float.SIZE) != 0){
workingMemoryQType = DType.F32;
}

if (workingMemoryQType != workingMemoryDType) {
boolean supportsQType;
AbstractTensor tmp = makeDenseTensor(Q8ByteBufferTensor.BLOCK_SIZE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
import com.github.tjake.jlama.model.mistral.MistralModel;
import com.github.tjake.jlama.model.mixtral.MixtralConfig;
import com.github.tjake.jlama.model.mixtral.MixtralModel;
import com.github.tjake.jlama.model.qwen2.Qwen2Config;
import com.github.tjake.jlama.model.qwen2.Qwen2Model;
import com.github.tjake.jlama.model.qwen2.Qwen2Tokenizer;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
Expand All @@ -57,7 +60,8 @@ public enum ModelType {
MIXTRAL(MixtralModel.class, MixtralConfig.class, LlamaTokenizer.class),
LLAMA(LlamaModel.class, LlamaConfig.class, LlamaTokenizer.class),
GPT2(GPT2Model.class, GPT2Config.class, GPT2Tokenizer.class),
BERT(BertModel.class, BertConfig.class, BertTokenizer.class);
BERT(BertModel.class, BertConfig.class, BertTokenizer.class),
QWEN2(Qwen2Model.class, Qwen2Config.class, Qwen2Tokenizer.class);

public final Class<? extends AbstractModel> modelClass;
public final Class<? extends Config> configClass;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.github.tjake.jlama.model.qwen2;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.safetensors.Config;

import java.util.List;

/**
 * Model configuration for Qwen2-family models.
 *
 * Deserialized by Jackson from a HuggingFace-style {@code config.json}; each
 * {@code @JsonProperty} name below is the corresponding JSON key. All values
 * are forwarded positionally to the shared {@link Config} base class, so the
 * argument order of the {@code super(...)} call must not be changed.
 */
public class Qwen2Config extends Config {

    @JsonCreator
    public Qwen2Config(
            @JsonProperty("max_position_embeddings") int contextLength,
            @JsonProperty("hidden_size") int embeddingLength,
            @JsonProperty("intermediate_size") int hiddenLength,
            @JsonProperty("num_attention_heads") int numberOfHeads,
            @JsonProperty("num_key_value_heads") int numberOfKeyValueHeads,
            @JsonProperty("num_hidden_layers") int numberOfLayers,
            @JsonProperty("rms_norm_eps") float layerNormEps,
            @JsonProperty("vocab_size") int vocabularySize,
            @JsonProperty("bos_token_id") int bosToken,
            @JsonProperty("eos_token_id") int eosToken,
            @JsonProperty("hidden_act") ActivationFunction.Type activationFunction,
            @JsonProperty("rope_theta") Double ropeTheta) {
        super(
                contextLength,
                embeddingLength,
                hiddenLength,
                numberOfHeads,
                numberOfKeyValueHeads,
                numberOfLayers,
                layerNormEps,
                vocabularySize,
                bosToken,
                // Config accepts a list of EOS ids; Qwen2's config.json declares a single one.
                List.of(eosToken),
                activationFunction,
                ropeTheta,
                1.0,  // NOTE(review): presumably a rope scaling factor of 1.0 (none) — confirm against Config
                null, // NOTE(review): presumably an absent/optional rope-scaling config — confirm against Config
                embeddingLength / numberOfHeads // per-head dimension
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package com.github.tjake.jlama.model.qwen2;

import com.github.tjake.jlama.model.*;
import com.github.tjake.jlama.model.llama.LlamaModel;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Optional;
import java.util.stream.IntStream;

/**
 * Qwen2 model implementation.
 *
 * Reuses the Llama forward pass by extending {@link LlamaModel}; the only
 * override is weight loading, because Qwen2 checkpoints carry bias tensors
 * for the q/k/v attention projections that Llama does not have.
 */
public class Qwen2Model extends LlamaModel {

    private static final Logger logger = LoggerFactory.getLogger(Qwen2Model.class);

    public Qwen2Model(
        Config config,
        WeightLoader weights,
        Tokenizer tokenizer,
        DType workingDType,
        DType workingQType,
        Optional<DType> modelQType
    ) {
        super(config, weights, tokenizer, workingDType, workingQType, modelQType);
    }

    public Qwen2Model(
        InferenceType inferenceType,
        Config config,
        WeightLoader weights,
        Tokenizer tokenizer,
        DType workingDType,
        DType workingQType,
        Optional<DType> modelQType
    ) {
        super(inferenceType, config, weights, tokenizer, workingDType, workingQType, modelQType);
    }

    /**
     * Loads (and optionally quantizes) the transformer blocks assigned to this
     * node's layer range, in parallel across layers.
     */
    @Override
    protected TransformerBlock[] loadTransformerBlockWeights() {
        DType qType = modelQType.orElse(this.modelDType);
        if (qType != this.modelDType) {
            logger.info("Quantizing model with {} - Please hold...", qType);
        }

        TransformerBlock[] blocks = new TransformerBlock[c.dctx().numberOfLayers];

        IntStream.range(c.dctx().layerStart, c.dctx().layerEnd)
            .parallel()
            .forEach(layer -> blocks[layer - c.dctx().layerStart] = loadBlock(layer, qType));

        return blocks;
    }

    /** Builds a single transformer block from the checkpoint tensors of one layer. */
    private TransformerBlock loadBlock(int layer, DType qType) {
        int relativeLayer = layer - c.dctx().layerStart; // FIXME: add a helper to the context

        String base = "model.layers." + layer + ".";
        String attnPrefix = base + "self_attn.";

        // Qwen2 attention projections ship with biases (unlike vanilla Llama).
        CausalSelfAttention attention = new CausalSelfAttention(
            this,
            relativeLayer,
            Optional.of(weights.load(attnPrefix + "q_proj.bias").quantize(qType)),
            Optional.of(weights.load(attnPrefix + "k_proj.bias").quantize(qType)),
            Optional.of(weights.load(attnPrefix + "v_proj.bias").quantize(qType)),
            weights.load(attnPrefix + "q_proj.weight", c.dctx(), true, false).quantize(qType),
            weights.load(attnPrefix + "k_proj.weight", c.dctx(), true, false).quantize(qType),
            weights.load(attnPrefix + "v_proj.weight", c.dctx(), true, false).quantize(qType),
            Optional.empty(),
            weights.load(attnPrefix + "o_proj.weight", c.dctx(), false, true).quantize(qType)
        );

        String mlpPrefix = base + "mlp.";

        MLPBlock mlp = new MLPBlock(
            this,
            c.activationFunction,
            weights.load(mlpPrefix + "gate_proj.weight", c.dctx(), true, false).quantize(qType), // w1
            weights.load(mlpPrefix + "down_proj.weight", c.dctx(), false, true).quantize(qType), // w2
            weights.load(mlpPrefix + "up_proj.weight", c.dctx(), true, false).quantize(qType)    // w3
        );

        return new TransformerBlock(
            this,
            relativeLayer,
            new RMSNorm(this, weights.load(base + "input_layernorm.weight").quantize(qType)),
            attention,
            new RMSNorm(this, weights.load(base + "post_attention_layernorm.weight").quantize(qType)),
            mlp
        );
    }

    @Override
    public ModelSupport.ModelType getModelType() {
        return ModelSupport.ModelType.QWEN2;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.github.tjake.jlama.model.qwen2;

import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer;

import java.nio.file.Path;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * BPE tokenizer for Qwen2 models, layered on the shared {@link BPETokenizer}.
 *
 * For legacy models without byte fallback, code points are remapped through
 * the {@code alteredBytes} table on the way in (pre-process) and back through
 * its inverse on the way out (post-process).
 */
public class Qwen2Tokenizer extends BPETokenizer {

    public Qwen2Tokenizer(Path modelRoot) {
        super(modelRoot);
    }

    @Override
    protected String preProcess(String sentence) {
        String text = sentence;

        if (model.normalizer() != null) {
            text = model.normalizer().normalize(text);
        }

        if (model.isLegacy() && !model.byteFallback) {
            // Remap each code point via the altered-bytes table (identity when absent).
            text = text.codePoints()
                .mapToObj(cp -> Character.toString(alteredBytes.getOrDefault(cp, cp)))
                .collect(Collectors.joining());
        }

        return text;
    }

    @Override
    protected long encodeCharacterAsToken(byte c) {
        // Token id is simply the unsigned byte value.
        return c & 0xFFL;
    }

    @Override
    protected Optional<Character> maybeDecodeTokenAsCharacter(long id) {
        // Qwen2 never decodes a raw token id directly to a character.
        return Optional.empty();
    }

    @Override
    protected String postProcessToken(String decoded) {
        String text = decoded == null ? model.unkToken : decoded;

        if (model.isLegacy() && !model.byteFallback) {
            // Undo the pre-process remapping using the inverse table.
            text = text.codePoints()
                .mapToObj(cp -> Character.toString(alteredBytes.inverse().getOrDefault(cp, cp)))
                .collect(Collectors.joining());
        }

        return text;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ protected AbstractTensor make(int offset, int length, TensorShape shape, boolean
public float get(int... dims) {
Preconditions.checkArgument(dims.length <= shape.dims(), "Too many dimensions specified");
Preconditions.checkArgument(dims.length == shape.dims(), "Must specify all dimensions");
return b.hasArray() ? b.array()[b.arrayOffset() + getOffset(dims)] : b.get(getOffset(dims));
return b.get(getOffset(dims));
}

@Override
Expand Down Expand Up @@ -136,18 +136,14 @@ public int getMemorySegmentOffset(int offset) {
@Override
public FloatVector getVector(VectorSpecies<Float> species, int... voffset) {
int offset = getOffset(voffset);
if (b.hasArray()) return FloatVector.fromArray(species, b.array(), offset);

return FloatVector.fromMemorySegment(species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
}

@Override
public void intoTensor(FloatVector vector, int... aoffset) {
// Preconditions.checkArgument(!b.isReadOnly());
int offset = getOffset(aoffset);

if (b.hasArray()) vector.intoArray(b.array(), offset);
else vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class TestModels {

static {
System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0");
// System.setProperty("jlama.force_panama_tensor_operations", "true");
//System.setProperty("jlama.force_panama_tensor_operations", "true");
}

private static final Logger logger = LoggerFactory.getLogger(TestModels.class);
Expand All @@ -69,14 +69,26 @@ public void GPT2Run() throws IOException {
gpt2.generate(UUID.randomUUID(), prompt, 0.8f, 256, makeOutHandler());
}

@Test
public void Qwen2Run() throws IOException {
    // Skipped unless the quantized Qwen2.5 test model has been downloaded locally.
    String modelDir = "../models/Qwen_Qwen2.5-0.5B-Instruct-JQ4";
    Assume.assumeTrue(Files.exists(Paths.get(modelDir)));

    AbstractModel model = ModelSupport.loadModel(new File(modelDir), DType.F32, DType.I8);
    PromptContext ctx = model.promptSupport().get().builder().addUserMessage("What is the capital of France?").build();

    Generator.Response response = model.generate(UUID.randomUUID(), ctx, 0.9f, 1024, makeOutHandler());
    logger.info("Response: {}", response);
}

@Test
public void LlamaRun() throws Exception {
String modelPrefix = "../models/tjake_Llama-3.2-1B-Instruct-Jlama-Q4";
Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) {
LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.empty());
LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.empty());

PromptSupport.Builder builder = model.promptSupport().get().builder();

Expand Down

0 comments on commit 1754cf2

Please sign in to comment.