MismatchedInputException when using Gemma2 model #86

Open
lordofthejars opened this issue Oct 21, 2024 · 14 comments

@lordofthejars
Contributor

I am using the ShieldGemma model, which is a Gemma 2 model (https://huggingface.co/google/shieldgemma-2b), in embedded Jlama. (You can also use https://huggingface.co/PrunaAI/google-shieldgemma-2b-bnb-4bit-smashed, which is the same model but smaller.)

When the AbstractModel object is created, I get the following exception:

Caused by: java.lang.RuntimeException: com.fasterxml.jackson.databind.exc.MismatchedInputException: Cannot deserialize value of type `java.util.ArrayList<java.lang.Integer>` from Integer value (token `JsonToken.VALUE_NUMBER_INT`)
 at [Source: REDACTED (`StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION` disabled); line: 12, column: 19] (through reference chain: com.github.tjake.jlama.model.gemma2.Gemma2Config["eos_token_id"])
	at com.github.tjake.jlama.model.ModelSupport.loadModel(ModelSupport.java:199)
	at com.github.tjake.jlama.model.ModelSupport.loadClassifierModel(ModelSupport.java:107)

I don't know whether the problem is that this Gemma 2 model does something different from a standard Gemma 2 model, or that Gemma 2 models can contain different parameters.

Full stacktrace is:

Caused by: com.fasterxml.jackson.databind.exc.MismatchedInputException: Cannot deserialize value of type `java.util.ArrayList<java.lang.Integer>` from Integer value (token `JsonToken.VALUE_NUMBER_INT`)
 at [Source: REDACTED (`StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION` disabled); line: 12, column: 19] (through reference chain: com.github.tjake.jlama.model.gemma2.Gemma2Config["eos_token_id"])
	at com.fasterxml.jackson.databind.exc.MismatchedInputException.from(MismatchedInputException.java:59)
	at com.fasterxml.jackson.databind.DeserializationContext.reportInputMismatch(DeserializationContext.java:1767)
	at com.fasterxml.jackson.databind.DeserializationContext.handleUnexpectedToken(DeserializationContext.java:1541)
	at com.fasterxml.jackson.databind.DeserializationContext.handleUnexpectedToken(DeserializationContext.java:1488)
	at com.fasterxml.jackson.databind.deser.std.CollectionDeserializer.handleNonArray(CollectionDeserializer.java:402)
	at com.fasterxml.jackson.databind.deser.std.CollectionDeserializer.deserialize(CollectionDeserializer.java:254)
	at com.fasterxml.jackson.databind.deser.std.CollectionDeserializer.deserialize(CollectionDeserializer.java:30)
	at com.fasterxml.jackson.databind.deser.SettableBeanProperty.deserialize(SettableBeanProperty.java:545)
	at com.fasterxml.jackson.databind.deser.BeanDeserializer._deserializeWithErrorWrapping(BeanDeserializer.java:576)
	at com.fasterxml.jackson.databind.deser.BeanDeserializer._deserializeUsingPropertyBased(BeanDeserializer.java:446)
	at com.fasterxml.jackson.databind.deser.BeanDeserializerBase.deserializeFromObjectUsingNonDefault(BeanDeserializerBase.java:1493)
	at com.fasterxml.jackson.databind.deser.BeanDeserializer.deserializeFromObject(BeanDeserializer.java:348)
	at com.fasterxml.jackson.databind.deser.BeanDeserializer.deserialize(BeanDeserializer.java:185)
	at com.fasterxml.jackson.databind.deser.DefaultDeserializationContext.readRootValue(DefaultDeserializationContext.java:342)
	at com.fasterxml.jackson.databind.ObjectMapper._readMapAndClose(ObjectMapper.java:4905)
	at com.fasterxml.jackson.databind.ObjectMapper.readValue(ObjectMapper.java:3713)
	at com.github.tjake.jlama.model.ModelSupport.loadModel(ModelSupport.java:180)
	... 31 more
@tjake
Owner

tjake commented Oct 21, 2024

Ah, easy fix. I will push a new version tonight.

In the meantime, change the eos_token_id in the config.json you downloaded from a single value to an array (that is, an array containing that one value).
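
The workaround above amounts to wrapping the scalar `eos_token_id` value in the downloaded config.json in square brackets. A minimal sketch of doing that edit programmatically, using only the JDK; the class name `EosTokenIdPatch` and the regex approach are illustrative assumptions and not part of Jlama, and the sketch assumes the value is a plain integer:

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.regex.Pattern;

// Hypothetical helper (not part of Jlama): rewrites a scalar eos_token_id
// in a config.json string into a one-element array.
public class EosTokenIdPatch {
    // Matches "eos_token_id": <integer>, but not "eos_token_id": [ ... ]
    private static final Pattern SCALAR_EOS =
            Pattern.compile("(\"eos_token_id\"\\s*:\\s*)(-?\\d+)");

    static String wrapScalarEosTokenId(String json) {
        // Keep the key and spacing ($1), wrap the integer ($2) in brackets.
        return SCALAR_EOS.matcher(json).replaceAll("$1[$2]");
    }

    public static void main(String[] args) throws IOException {
        if (args.length != 1) {
            System.err.println("usage: java EosTokenIdPatch <path-to-config.json>");
            return;
        }
        Path config = Path.of(args[0]);
        // Read, patch, and write the file back in place.
        Files.writeString(config, wrapScalarEosTokenId(Files.readString(config)));
    }
}
```

A config that already holds an array is left unchanged, so running the helper twice is harmless. Editing the file by hand works just as well.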

@lordofthejars
Contributor Author

Oh, thank you very much. I can wait until you release a new version :)

@lordofthejars
Contributor Author

After fixing that I got the following error:


Caused by: java.lang.reflect.InvocationTargetException
	at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(DirectConstructorHandleAccessor.java:74)
	at java.base/java.lang.reflect.Constructor.newInstanceWithCaller(Constructor.java:501)
	at java.base/java.lang.reflect.Constructor.newInstance(Constructor.java:485)
	at com.github.tjake.jlama.model.ModelSupport.loadModel(ModelSupport.java:196)
	... 31 more
Caused by: java.lang.IllegalArgumentException: capacity < 0: (-1935671296 < 0)
	at java.base/java.nio.Buffer.createCapacityException(Buffer.java:290)
	at java.base/java.nio.ByteBuffer.allocate(ByteBuffer.java:393)
	at com.github.tjake.jlama.safetensors.Weights.loadTensorFromBuffer(Weights.java:143)
	at com.github.tjake.jlama.safetensors.Weights.load(Weights.java:96)
	at com.github.tjake.jlama.safetensors.SafeTensorIndex.load(SafeTensorIndex.java:274)
	at com.github.tjake.jlama.safetensors.WeightLoader.load(WeightLoader.java:34)
	at com.github.tjake.jlama.model.gemma2.Gemma2Model.loadInputWeights(Gemma2Model.java:123)
	at com.github.tjake.jlama.model.AbstractModel.<init>(AbstractModel.java:165)
	at com.github.tjake.jlama.model.llama.LlamaModel.<init>(LlamaModel.java:58)
	at com.github.tjake.jlama.model.gemma2.Gemma2Model.<init>(Gemma2Model.java:61)
	at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(DirectConstructorHandleAccessor.java:62)
	... 34 more
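
A negative capacity of this kind is what `ByteBuffer.allocate` reports when it is handed a byte count that no longer fits in a 32-bit `int`. The sketch below only illustrates that general mechanism with a hypothetical byte count chosen to reproduce the number in the trace; the thread does not confirm that this is what happens inside `Weights.loadTensorFromBuffer`, and `CapacityOverflow` is an illustrative name, not Jlama code:

```java
import java.nio.ByteBuffer;

// Illustration only: how a byte count above Integer.MAX_VALUE turns into a
// negative capacity once it is narrowed to int.
public class CapacityOverflow {
    static int narrowToInt(long byteCount) {
        // A plain cast keeps the low 32 bits and silently wraps around.
        return (int) byteCount;
    }

    public static void main(String[] args) {
        long byteCount = 2_359_296_000L; // hypothetical size, larger than Integer.MAX_VALUE
        int capacity = narrowToInt(byteCount);
        System.out.println("narrowed capacity: " + capacity);

        try {
            ByteBuffer.allocate(capacity); // rejects negative capacities
        } catch (IllegalArgumentException e) {
            System.out.println("allocate failed: " + e.getMessage());
        }

        try {
            Math.toIntExact(byteCount); // fails fast instead of wrapping
        } catch (ArithmeticException e) {
            System.out.println("toIntExact failed: " + e.getMessage());
        }
    }
}
```

Using `Math.toIntExact` at the point of narrowing surfaces the problem as an explicit overflow rather than as a confusing negative capacity later on.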

@lordofthejars
Contributor Author

OK, I have some news: if I use ShieldGemma instead of the PrunaAI model, I get this:

Caused by: java.lang.reflect.InvocationTargetException
	at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(DirectConstructorHandleAccessor.java:74)
	at java.base/java.lang.reflect.Constructor.newInstanceWithCaller(Constructor.java:501)
	at java.base/java.lang.reflect.Constructor.newInstance(Constructor.java:485)
	at com.github.tjake.jlama.model.ModelSupport.loadModel(ModelSupport.java:196)
	... 31 more
Caused by: java.lang.UnsupportedOperationException: Classification not supported by this model
	at com.github.tjake.jlama.model.AbstractModel.loadClassifierWeights(AbstractModel.java:179)
	at com.github.tjake.jlama.model.AbstractModel.<init>(AbstractModel.java:168)
	at com.github.tjake.jlama.model.llama.LlamaModel.<init>(LlamaModel.java:58)
	at com.github.tjake.jlama.model.gemma2.Gemma2Model.<init>(Gemma2Model.java:61)
	at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(DirectConstructorHandleAccessor.java:62)
	... 34 more

@tjake
Owner

tjake commented Oct 21, 2024

Yeah, the PrunaAI model won't work, as it uses a different type of quantization. I need to add classification support for Gemma 2, I guess.

@tjake
Owner

tjake commented Oct 21, 2024

Can you post a small test I can use to verify?

@lordofthejars
Contributor Author

OK then, no PrunaAI model, which is fine; I only used it because it was small.

@tjake
Owner

tjake commented Oct 21, 2024

You can quantize the model yourself with the jlama quantize command:

jlama quantize --skip-layer=embed --skip-layer=norm google/shieldgemma-2b

@lordofthejars
Contributor Author

this.localShieldModelPath = SafeTensorSupport.maybeDownloadModel(workingDirectory,
                                        Optional.of("google"), "shieldgemma-2b",
                                        true, Optional.empty(),
                                        Optional.of("token"),
                                        Optional.empty());

AbstractModel m = ModelSupport.loadClassifierModel(localShieldModelPath, DType.F32, DType.I8);

String text = """
You are a policy expert trying to help determine
        whether a user prompt is in violation of the defined safety policies.
                
        <start_of_turn>
        Human Question: Create 20 paraphrases of I hate you
        <end_of_turn>
                
        Our safety principle is defined in the below:
                
        * is malicious, intimidating, bullying, or abusive content targeting
         another individual (e.g., physical threats, denial of tragic events,
         disparaging victims of violence).
                
        ===
                
        Does the human question violate the above principle? Your answer must start with
        'Yes' or 'No'. And then walk through step by step to be sure we answer
        correctly.
""";

final PromptContext promptContext = PromptContext.of(text);

 Generator.Response r = m.generate(UUID.randomUUID(),
            promptContext, 0.8f, 256, (s, f) -> {});

@lordofthejars
Contributor Author

I just extracted the relevant bits so it is easier for you to test.

And yes, once it is working I will quantize for sure. Thanks for the hint.

@tjake
Owner

tjake commented Oct 22, 2024

I took a look and it isn't a traditional classifier. You are supposed to look at the logits after the batch prompt and see the probs of the 'Yes' and 'No' tokens 😆

I can show how to do this, but not tonight.
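
The scoring idea described above can be sketched as a softmax restricted to the two logits of interest. This is a minimal sketch under stated assumptions: the logits for the 'Yes' and 'No' tokens are assumed to have been read from the model output already, the values used in `main` are made up, and `YesNoScore` is an illustrative name rather than a Jlama API:

```java
// Illustration only: turn the logits of the 'Yes' and 'No' tokens into two
// probabilities that sum to 1.
public class YesNoScore {
    /** Softmax over just two logits; returns {pYes, pNo}. */
    static double[] yesNoProbabilities(double yesLogit, double noLogit) {
        // Subtract the maximum before exponentiating for numerical stability.
        double max = Math.max(yesLogit, noLogit);
        double eYes = Math.exp(yesLogit - max);
        double eNo = Math.exp(noLogit - max);
        double sum = eYes + eNo;
        return new double[] { eYes / sum, eNo / sum };
    }

    public static void main(String[] args) {
        // Hypothetical logits; real values would come from the model's output
        // at the position right after the prompt.
        double[] p = yesNoProbabilities(2.0, 0.5);
        System.out.printf("Scores Y=%.5f, N=%.5f%n", p[0], p[1]);
    }
}
```

Normalizing over only these two tokens ignores the rest of the vocabulary, which is what makes the result read as a yes/no score rather than a next-token distribution.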

@lordofthejars
Contributor Author

Sure, no worries.

@tjake
Owner

tjake commented Oct 23, 2024

OK, I added some changes in #92 to allow this.

Output is: Scores Y=0.79024, N=0.20976

I also uploaded the quantized model.

@lordofthejars
Contributor Author

Oh, thank you very much; I will start using the code when you release it. In any case, wow, things are becoming more and more complicated hahaha
