Skip to content

Commit

Permalink
Merge pull request #58 from tjake/classifier-support
Browse files Browse the repository at this point in the history
Add classifier support
  • Loading branch information
tjake authored Sep 22, 2024
2 parents e190ac2 + d820ae2 commit 2c3f080
Show file tree
Hide file tree
Showing 16 changed files with 315 additions and 53 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Implements:
* Paged Attention
* Mixture of Experts
* Tool Calling
* Generate Embeddings
* Classifier Support
* Huggingface [SafeTensors](https://github.com/huggingface/safetensors) model and tokenizer format
* Support for F32, F16, BF16 types
* Support for Q8, Q4 model quantization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ public class ActivationFunction {

public enum Type {
SILU,
GELU
GELU,
TANH
}

public static float eval(Type t, float x) {
return switch (t) {
case SILU -> (float) (x * (1.0f / (1.0f + exp(-x))));
case GELU -> (float) (0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * Math.pow(x, 3)))));
case TANH -> (float) Math.tanh(x);
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ public static void l1normalize(float[] x) {
x[i] /= sum;
}

public static void l2normalize(AbstractTensor x) {
float sum = 0.0f;
for (int i = 0; i < x.shape().last(); i++) {
float v = x.get(0, i);
sum += v * v;
}
double magnitude = Math.sqrt(sum);
for (int i = 0; i < x.shape().last(); i++)
x.set((float)(x.get(0, i) / magnitude), 0, i);
}

public static void l2normalize(float[] x) {
float sum = 0.0f;
for (int i = 0; i < x.length; i++)
Expand Down
151 changes: 118 additions & 33 deletions jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
import static com.github.tjake.jlama.util.DebugSupport.debug;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.model.functions.*;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
Expand All @@ -30,15 +29,14 @@
import com.github.tjake.jlama.safetensors.prompt.Tool;
import com.github.tjake.jlama.safetensors.prompt.ToolCall;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.tensor.*;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import com.github.tjake.jlama.util.JsonSupport;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;

import java.nio.FloatBuffer;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
Expand All @@ -53,31 +51,28 @@ public abstract class AbstractModel implements Generator {
private static final Logger logger = LoggerFactory.getLogger(AbstractModel.class);

public enum InferenceType {
INPUT_TO_EMBEDDING(true, false, false),
OUTPUT_TO_TOKEN(false, true, false),
FORWARD_PASS(true, false, true),
FULL_GENERATION(true, true, true);
//Used for distributed inference
INPUT_TO_EMBEDDING(true, false, false, false, false),
OUTPUT_TO_TOKEN(false, false, true, false, false),
FORWARD_PASS(true, true, false, false, false),

//Used for different types of inference
FULL_GENERATION(true, true, true, false,false),
FULL_CLASSIFICATION(true, true, false, true, true),
FULL_EMBEDDING(true, true, false, false, true);

final boolean isInput;
final boolean isOutput;
final boolean isClassify;
final boolean isFwdPass;
final boolean isPooling;

InferenceType(boolean isInput, boolean isOutput, boolean isFwdPass) {
InferenceType(boolean isInput, boolean isFwdPass, boolean isOutput, boolean isClassify, boolean isPooling) {
this.isInput = isInput;
this.isOutput = isOutput;
this.isFwdPass = isFwdPass;
}

public boolean isEmbedding() {
return isInput;
}

public boolean isOutput() {
return isOutput;
}

public boolean isFwdPass() {
return isFwdPass;
this.isClassify = isClassify;
this.isPooling = isPooling;
}
}

Expand All @@ -91,6 +86,8 @@ public boolean isFwdPass() {
protected final Optional<DType> modelQType;
protected EmbedInput embedInput;
protected SampleOutput sampleOutput;
protected ClassifyOutput classifyOutput;
protected Optional<PoolingLayer> poolingLayer;
protected TransformerBlock[] transformerBlocks;
protected KvBufferCache kvBufferCache;

Expand Down Expand Up @@ -143,6 +140,8 @@ protected AbstractModel(
this.embedInput = inferenceType.isInput ? loadInputWeights() : null;
this.transformerBlocks = inferenceType.isFwdPass ? loadTransformerBlockWeights() : null;
this.sampleOutput = inferenceType.isOutput ? loadOutputWeights() : null;
this.classifyOutput = inferenceType.isClassify ? loadClassifierWeights() : null;
this.poolingLayer = inferenceType.isPooling ? Optional.ofNullable(loadPoolingWeights()) : Optional.empty();
}

protected abstract EmbedInput loadInputWeights();
Expand All @@ -151,6 +150,14 @@ protected AbstractModel(

protected abstract SampleOutput loadOutputWeights();

protected ClassifyOutput loadClassifierWeights() {
throw new UnsupportedOperationException("Classification not supported by this model");
}

protected PoolingLayer loadPoolingWeights() {
return null;
}

public abstract ModelSupport.ModelType getModelType();

public InferenceType getInferenceType() {
Expand All @@ -169,6 +176,10 @@ public Tokenizer getTokenizer() {
return tokenizer;
}

public WeightLoader getWeights() {
return weights;
}

public Optional<PromptSupport> promptSupport() {
return tokenizer.promptSupport();
}
Expand Down Expand Up @@ -253,27 +264,101 @@ public AbstractTensor batchForward(
}

@Override
public float[] embed(String input) {
public float[] embed(String input, PoolingType poolingType) {
int[] encoded = Arrays.stream(tokenizer.encode(input)).mapToInt(Ints::checkedCast).toArray();

Preconditions.checkArgument(encoded.length < c.contextLength);
float[] outputEmbedding = new float[c.embeddingLength];

try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(UUID.randomUUID())) {
int promptLength = encoded.length;
float avgp = 1.0f / promptLength;

AbstractTensor r = batchForward(encoded, 0, kvmem);
for (int i = 0; i < promptLength; i++) {
AbstractTensor output = r.slice(i);
try (AbstractTensor r = batchForward(encoded, 0, kvmem)) {
if (poolingType == PoolingType.MODEL) {
if (poolingLayer.isPresent()) {

// Get the last value should represent the sum of the prompt (due to attention)
AbstractTensor output = r.slice(promptLength - 1);
AbstractTensor pooled = makeDenseTensor(1, c.embeddingLength);

// Pooling
TensorOperationsProvider.get()
.batchDotProduct(
pooled,
output,
poolingLayer.get().getPoolingWeights(),
0,
0,
c.embeddingLength);

poolingLayer.get().getPoolingBias().ifPresent(bias -> {
TensorOperationsProvider.get().accumulate(pooled, bias, 0, c.embeddingLength);
});

VectorMath.pfor(0, c.embeddingLength, i -> {
//BERT seems to use tanh for pooling rather than gelu
outputEmbedding[i] = ActivationFunction.eval(ActivationFunction.Type.TANH, pooled.get(0, i));
});

return outputEmbedding;
}

throw new UnsupportedOperationException("Pooling layer not found");
}

// Average Pooling
for (int ii = 0; ii < c.embeddingLength; ii++)
outputEmbedding[ii] += output.get(0, ii) * avgp;
// No pooling layer, so we just pool manually embeddings
for (int i = 0; i < promptLength; i++) {
AbstractTensor output = r.slice(i);
// Pooling
for (int ii = 0; ii < c.embeddingLength; ii++) {
switch (poolingType) {
case AVG:
outputEmbedding[ii] += output.get(0, ii) * avgp;
break;
case MAX:
outputEmbedding[ii] = Math.max(outputEmbedding[ii], output.get(0, ii));
break;
case SUM:
outputEmbedding[ii] += output.get(0, ii);
break;
}
}
}
}
r.close();
VectorMath.l2normalize(outputEmbedding);
return outputEmbedding;
}
}

@Override
public Map<String, Float> classify(String input, PoolingType poolingType) {
if (!c.isClassifier() || classifyOutput == null) {
throw new UnsupportedOperationException("Classification not supported by this model");
}
return outputEmbedding;

float[] embedding = embed(input, poolingType);
FloatBufferTensor b = new FloatBufferTensor(FloatBuffer.wrap(embedding), TensorShape.of(embedding.length), false);

int classes = classifyOutput.getClassificationWeights().shape().first();
AbstractTensor scores = makeDenseTensor(classes);

TensorOperationsProvider.get().batchDotProduct(scores, b, classifyOutput.getClassificationWeights(), 0, 0, c.embeddingLength);

classifyOutput.getClassificationBias().ifPresent(bias -> {
TensorOperationsProvider.get().accumulate(scores, bias, 0, classes);
});

VectorMath.softMax(scores, 0, classes);
Map<String, Float> result = new HashMap<>();
for (int i = 0; i < classes; i++) {
String label = c.classifcationLabels.get().inverse().get(i);
Float score = scores.get(0, i);

result.put(label, score);
}

return result;
}

public int sample(AbstractTensor output, float temperature, float uniformSample, AbstractTensor logits) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ public static AbstractModel loadModel(File model, DType workingMemoryType, DType

/** Shortcut for loading a model for embeddings */
public static AbstractModel loadEmbeddingModel(File model, DType workingMemoryType, DType workingQuantizationType) {
return loadModel(AbstractModel.InferenceType.FORWARD_PASS, model, null, workingMemoryType, workingQuantizationType, Optional.empty(), Optional.empty(), Optional.empty());
return loadModel(AbstractModel.InferenceType.FULL_EMBEDDING, model, null, workingMemoryType, workingQuantizationType, Optional.empty(), Optional.empty(), Optional.empty());
}

/** Shortcut for loading a model for embeddings */
public static AbstractModel loadClassifierModel(File model, DType workingMemoryType, DType workingQuantizationType) {
return loadModel(AbstractModel.InferenceType.FULL_CLASSIFICATION, model, null, workingMemoryType, workingQuantizationType, Optional.empty(), Optional.empty(), Optional.empty());
}

public static AbstractModel loadModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.safetensors.Config;
import java.util.List;
import java.util.Map;

public class BertConfig extends Config {
@JsonCreator
Expand All @@ -31,7 +32,10 @@ public BertConfig(
@JsonProperty("num_hidden_layers") int numberOfLayers,
@JsonProperty("layer_norm_eps") float layerNormEps,
@JsonProperty("hidden_act") ActivationFunction.Type activationFunction,
@JsonProperty("vocab_size") int vocabularySize
@JsonProperty("vocab_size") int vocabularySize,
@JsonProperty("label2id") Map<String, Integer> classificationLabels,
@JsonProperty("sep_token") Integer sepToken,
@JsonProperty("cls_token") Integer clsToken
) {
super(
contextLength,
Expand All @@ -42,11 +46,12 @@ public BertConfig(
numberOfLayers,
layerNormEps,
vocabularySize,
0,
List.of(0),
sepToken == null ? 0 : sepToken,
clsToken == null ? List.of(0) : List.of(clsToken),
activationFunction,
null,
null
null,
classificationLabels
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.*;
import com.github.tjake.jlama.model.functions.ClassifyOutput;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.PoolingLayer;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
Expand Down Expand Up @@ -131,7 +133,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
prefix = b;
MLPBlock mlpBlock = new MLPBlock(
this,
ActivationFunction.Type.GELU,
c.activationFunction,
loadWeight(prefix + "intermediate.dense.bias"),
loadWeight(prefix + "intermediate.dense.weight"),
loadWeight(prefix + "output.dense.bias"),
Expand Down Expand Up @@ -159,4 +161,43 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
protected SampleOutput loadOutputWeights() {
throw new UnsupportedOperationException();
}

@Override
protected PoolingLayer loadPoolingWeights() {

final AbstractTensor poolerDenseWeight = loadWeight("pooler.dense.weight");
final AbstractTensor poolerDenseBias = loadWeight("pooler.dense.bias");

return new PoolingLayer() {
public AbstractTensor getPoolingWeights() {
return poolerDenseWeight;
}

public Optional<AbstractTensor> getPoolingBias() {
return Optional.of(poolerDenseBias);
}
};
}

@Override
protected ClassifyOutput loadClassifierWeights() {
if (c.isClassifier()) {
final AbstractTensor classifierWeight = loadWeight("classifier.weight");
final AbstractTensor classifierBias = loadWeight("classifier.bias");

return new ClassifyOutput() {
@Override
public AbstractTensor getClassificationWeights() {
return classifierWeight;
}

@Override
public Optional<AbstractTensor> getClassificationBias() {
return Optional.of(classifierBias);
}
};
} else {
throw new UnsupportedOperationException("Classification not supported by this model");
}
}
}
Loading

0 comments on commit 2c3f080

Please sign in to comment.