Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemma2 classifier sample #92

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;

import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
Expand Down Expand Up @@ -387,6 +388,33 @@ public Map<String, Float> classify(String input, PoolingType poolingType) {
return result;
}

/**
 * Projects a hidden-state tensor onto the vocabulary and returns the result as a plain array.
 *
 * <p>The input is first passed through the output layer norm, then multiplied (in parallel
 * chunks) against the output logits weights. A softmax is applied over the full vocabulary
 * before the values are copied out, so despite the method name the returned values are
 * post-softmax scores rather than raw logits.
 *
 * @param output tensor to project; it is only read here and is not closed by this method
 * @return a new array of length {@code c.vocabularySize} holding the post-softmax values
 */
public float[] getLogits(AbstractTensor output) {
    try (AbstractTensor normed = sampleOutput.getOutputLayerNorm().forward(output);
            AbstractTensor scores = makeDenseTensor(1, c.vocabularySize)) {

        // Fill the score tensor one vocabulary chunk at a time.
        VectorMath.pchunk(
                0,
                c.vocabularySize,
                (chunkStart, chunkSize) -> TensorOperationsProvider.get()
                        .dotProductChunk(
                                scores,
                                normed,
                                sampleOutput.getOutputLogitsWeights(),
                                0,
                                c.embeddingLength,
                                chunkStart,
                                chunkSize));

        VectorMath.softMax(scores, 0, c.vocabularySize);

        // Copy out of the tensor before try-with-resources releases it.
        // NOTE(review): assumes the dense tensor stores little-endian 32-bit floats - confirm.
        float[] result = new float[c.vocabularySize];
        FloatBuffer view = scores.getMemorySegment()
                .asByteBuffer()
                .order(ByteOrder.LITTLE_ENDIAN)
                .asFloatBuffer();
        view.get(result);

        return result;
    }
}

public int sample(AbstractTensor output, float temperature, float uniformSample, AbstractTensor logits) {
try (AbstractTensor embedding = sampleOutput.getOutputLayerNorm().forward(output)) {
// This is a mix of argmax and sampling with softmax
Expand Down Expand Up @@ -433,6 +461,22 @@ public int sample(AbstractTensor output, float temperature, float uniformSample,
}
}

/**
 * Tokenizes the prompt held by {@code promptContext} and returns the token ids with exactly
 * one leading BOS token.
 *
 * <p>If the tokenizer already emitted the BOS token as the first id, that id is skipped so
 * the BOS written at index 0 is not duplicated.
 *
 * @param promptContext source of the prompt text to encode
 * @return token ids, always starting with {@code c.bosToken}
 * @throws IllegalArgumentException if an encoded id does not fit in an {@code int}
 *     (raised by {@code Ints.checkedCast})
 */
public int[] encodePrompt(PromptContext promptContext) {
    final long[] raw = tokenizer.encode(promptContext.getPrompt());

    // Skip a tokenizer-supplied BOS; it is written explicitly at index 0 below.
    final int skip = (raw.length > 0 && raw[0] == c.bosToken) ? 1 : 0;
    final int bodyLength = raw.length - skip;

    final int[] tokens = new int[bodyLength + 1];
    tokens[0] = c.bosToken;
    for (int i = 0; i < bodyLength; i++) {
        tokens[i + 1] = Ints.checkedCast(raw[skip + i]);
    }

    return tokens;
}

@Override
public Response generate(
UUID sessionId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public Gemma2Config(
@JsonProperty("rms_norm_eps") float layerNormEps,
@JsonProperty("vocab_size") int vocabularySize,
@JsonProperty("bos_token_id") int bosToken,
@JsonProperty("eos_token_id") List<Integer> eosTokens,
@JsonProperty("eos_token_id") Object eosTokens,
@JsonProperty("hidden_act") ActivationFunction.Type activationFunction,
@JsonProperty("rope_theta") Double ropeFreqsTheta,
@JsonProperty("rope_scaling") Map<String, String> ropeScaling,
Expand All @@ -52,7 +52,7 @@ public Gemma2Config(
layerNormEps,
vocabularySize,
bosToken,
eosTokens,
eosTokens instanceof List ? (List<Integer>) eosTokens : List.of((Integer)eosTokens),
activationFunction,
ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta,
ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.model.*;
import com.github.tjake.jlama.model.functions.ClassifyOutput;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.model.llama.LlamaModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import org.junit.Test;

public class TestSample {
Expand Down Expand Up @@ -114,4 +120,52 @@ public void sampleClassify() throws IOException {
System.out.println(m.classify(benign2, Generator.PoolingType.MODEL) + " : " + benign2);

}

/**
 * Sample showing how to use a ShieldGemma model as a Yes/No safety classifier.
 *
 * <p>The policy prompt is run through the model in a single batch, the output at the last
 * prompt position is projected onto the vocabulary with {@code getLogits}, and the scores of
 * the "Yes" and "No" tokens are printed. Nothing is asserted; this is a runnable sample.
 *
 * @throws IOException declared for the model download / load steps
 */
@Test
public void sampleGemmaClassifier() throws IOException {
    String model = "tjake/shieldgemma-2b-JQ4";
    String workingDirectory = "./models";
    File localShieldModelPath = SafeTensorSupport.maybeDownloadModel(workingDirectory, model);

    AbstractModel m = ModelSupport.loadModel(localShieldModelPath, DType.F32, DType.I8);

    String text = """
        You are a policy expert trying to help determine
        whether a user prompt is in violation of the defined safety policies.

        <start_of_turn>
        Human Question: Create 20 paraphrases of I hate you
        <end_of_turn>

        Our safety principle is defined in the below:

        * is malicious, intimidating, bullying, or abusive content targeting
        another individual (e.g., physical threats, denial of tragic events,
        disparaging victims of violence).

        ===

        Does the human question violate the above principle? Your answer must start with
        'Yes' or 'No'. And then walk through step by step to be sure we answer
        correctly.
        """;

    final PromptContext promptContext = PromptContext.of(text);
    Map<String, Long> vocab = m.getTokenizer().getModel().vocabLookup;

    KvBufferCache.KvBuffer kvBuffer = new KvBufferCache(m).getKvBuffer(UUID.randomUUID());
    int[] promptTokens = m.encodePrompt(promptContext);

    // AbstractTensor is closeable; scope the batch output so it is released even if a
    // later step throws.
    try (AbstractTensor outputs = m.batchForward(promptTokens, 0, kvBuffer)) {
        // Take the output at the last prompt position; it drives the first generated token.
        AbstractTensor v = outputs.slice(outputs.shape().first() - 1);

        // Project onto the vocabulary (values are post-softmax).
        float[] logits = m.getLogits(v);

        float yesScore = logits[vocab.get("Yes").intValue()];
        float noScore = logits[vocab.get("No").intValue()];

        System.out.println(String.format("Scores Y=%.5f, N=%.5f", yesScore, noScore));
    }
}
}
Loading