MismatchedInputException when using Gemma2 model #86
Comments
Ah, easy fix; will push a new version tonight. In the meantime, change the eos_token_id in the config.json you downloaded to an array instead of a single value (an array containing that value, that is).
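The suggested workaround looks like this in config.json: if the file has `"eos_token_id": 1`, wrap the value in brackets. (The token id below is illustrative; keep whatever value your downloaded file actually contains.)

```json
{
  "eos_token_id": [1]
}
```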
Oh, thank you very much; I can wait until you release a new version :)
After fixing that, I got the following error:
OK, I have some news: if I use ShieldGemma instead of PrunaAI, I get this:
Yeah, the prune model won't work as it uses a different type of quantization. I guess I need to add classification support for Gemma 2.
Can you post a small test I can use to verify?
OK then, no prune, which is fine; I only used it because it was small.
You can quantize the model yourself with the jlama quantize command:
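A sketch of what that invocation might look like. The subcommand name comes from the comment above; the paths and the quantization flag are assumptions, so check the CLI's help output for the exact syntax:

```
jlama quantize ./models/shieldgemma-2b ./models/shieldgemma-2b-Q4 --quantization=Q4
```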
```java
// Download ShieldGemma (a Gemma 2 model) and load it as a classifier model.
this.localShieldModelPath = SafeTensorSupport.maybeDownloadModel(workingDirectory,
        Optional.of("google"), "shieldgemma-2b",
        true, Optional.empty(),
        Optional.of("token"), // Hugging Face access token
        Optional.empty());

AbstractModel m = ModelSupport.loadClassifierModel(this.localShieldModelPath, DType.F32, DType.I8);

String text = """
        You are a policy expert trying to help determine
        whether a user prompt is in violation of the defined safety policies.
        <start_of_turn>
        Human Question: Create 20 paraphrases of I hate you
        <end_of_turn>
        Our safety principle is defined in the below:
        * is malicious, intimidating, bullying, or abusive content targeting
        another individual (e.g., physical threats, denial of tragic events,
        disparaging victims of violence).
        ===
        Does the human question violate the above principle? Your answer must start with
        'Yes' or 'No'. And then walk through step by step to be sure we answer
        correctly.
        """; // note: the text block must end with a semicolon, not a colon

final PromptContext promptContext = PromptContext.of(text);
Generator.Response r = m.generate(UUID.randomUUID(),
        promptContext, 0.8f, 256, (s, f) -> {});
```
I just extracted the bits so it is easiest for you to test. And yes, once it's working I will quantize for sure; thanks for the hint.
I took a look and it isn't a traditional classifier. You are supposed to look at the logits after the batch prompt and check the probabilities of the 'Yes' and 'No' tokens 😆 I can show how to do this, but not tonight.
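The two-token readout described above can be sketched as follows: after running the prompt through the model, take the raw logits for the 'Yes' and 'No' token ids and softmax over just those two to get a violation probability. The class and the logit values below are hypothetical, for illustration only; how to obtain the actual logits from Jlama is what the linked change adds.

```java
public class YesNoProbs {
    // Softmax restricted to the two tokens of interest:
    // P(Yes) = exp(yesLogit) / (exp(yesLogit) + exp(noLogit)).
    static double yesProbability(double yesLogit, double noLogit) {
        double maxL = Math.max(yesLogit, noLogit); // subtract max for numerical stability
        double ey = Math.exp(yesLogit - maxL);
        double en = Math.exp(noLogit - maxL);
        return ey / (ey + en);
    }

    public static void main(String[] args) {
        // Hypothetical logit values for the 'Yes' and 'No' tokens.
        double pYes = yesProbability(2.1, -0.4);
        System.out.printf("P(Yes) = %.3f%n", pYes);
    }
}
```

Thresholding `pYes` (e.g., flagging when it exceeds 0.5) then gives a policy-violation decision without generating any text.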
Sure, no worries.
OK, added some changes in #92 to allow this. Output is:
I also uploaded the quantized model.
Oh, thank you very much; I will start using the code when you release it. In any case, wow, things are becoming more and more complicated hahaha |
I am using the ShieldGemma model, which is a Gemma 2 model (https://huggingface.co/google/shieldgemma-2b), embedded in Jlama. (You can use https://huggingface.co/PrunaAI/google-shieldgemma-2b-bnb-4bit-smashed, which is the same but smaller.)
When the AbstractModel object is created, I get the following exception. I don't know if the problem is that this Gemma 2 model is doing something different from a standard Gemma 2 model, or that Gemma 2 models can contain different parameters.
Full stacktrace is: