[Feature Request] LLaVA 1.6 LoRA fine-tuning example #605
Comments
I was waiting for LLaVA 1.6 support to be merged into transformers so we can have consistent model weight naming conventions. I haven't looked at the details yet, but the most challenging part will likely be https://github.com/ml-explore/mlx-examples/blob/main/llava/llava.py#L104. Do you mind starting a draft PR for it? I'm happy to help in any way if you get stuck on the implementation.
@mzbac sounds good! I'll take a look over the weekend and see how it goes -- a lot of this is going to be pretty new for me. Quick question as I'm getting started: do you expect it to be complex to support both the Mistral and Vicuna models? I ask because of the line below, where we've explicitly said we only support the Llama language model.
mlx-examples/llava/language.py, line 210 (at fbed720)
I'm personally most interested in getting the Mistral 7B version working, but I'd be open to working with Vicuna if you expect that to be a better starting point.
Mistral uses the same Llama architecture. If the function that merges the image features with the input ids works properly, the language model part should just work out of the box.
I couldn't get 1.6 to work either, as it said that only Llama, and not Mistral, models were supported. On a further test, trying llava-hf/llava-v1.6-34b-hf, I got:
... updating the transformers library and installing protobuf seems more promising ...
@jrp2014 I ran into a similar processor error before I realized I hadn't updated HF transformers -- updating to |
I now get
I'll have a look to see whether there are any obvious fixes. |
Hey, great initiative! Excited to try out these new and improved LLaVA models. Agreed that using the standard transformers format makes it easier. I've been thinking a bit about the fine-tuning aspect, and can share some code snippets I have so far which might help the discussion. There's a ton of variety in fine-tuning vision models, but to me the easiest direction would be to start with fine-tuning only the language model. If you wanted to do just that, I'm following the code from https://github.com/ml-explore/mlx-examples/tree/main/llms, where, at least to start, you can attach LoRA layers to the LLM. Here's roughly how I have that for now; a sketch is below.
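(A minimal, self-contained sketch of that idea, since the original snippet didn't survive in the thread -- not the author's actual code. It wraps the attention projections of the LLaVA language model with LoRA adapters; the attribute names `model.layers`, `self_attn.q_proj`, and `self_attn.v_proj` are assumptions based on the Llama implementation in the llava example, and the rank/alpha values are arbitrary.)

```python
import math
import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable low-rank update."""

    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        out_dims, in_dims = linear.weight.shape
        lora = LoRALinear(in_dims, out_dims, rank, alpha)
        lora.linear = linear  # keep the original (frozen) weights
        return lora

    def __init__(self, in_dims: int, out_dims: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.linear = nn.Linear(in_dims, out_dims, bias=False)
        self.scale = alpha / rank
        # Low-rank factors: A is small-random, B starts at zero so the
        # wrapped layer initially behaves exactly like the original.
        self.lora_a = mx.random.uniform(
            low=-1 / math.sqrt(in_dims),
            high=1 / math.sqrt(in_dims),
            shape=(in_dims, rank),
        )
        self.lora_b = mx.zeros((rank, out_dims))

    def __call__(self, x):
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)


def attach_lora(language_model, num_layers: int = 8, rank: int = 8):
    """Freeze everything, then swap LoRA adapters into the last
    `num_layers` attention blocks of the language model."""
    language_model.freeze()
    for layer in language_model.model.layers[-num_layers:]:
        layer.self_attn.q_proj = LoRALinear.from_linear(layer.self_attn.q_proj, rank)
        layer.self_attn.v_proj = LoRALinear.from_linear(layer.self_attn.v_proj, rank)
    return language_model
```

After this, only the `lora_a`/`lora_b` arrays are trainable, so the optimizer state stays small and the base weights are untouched.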
I'm still not sure what the best approach would be for the language-modeling piece. Maybe you mask everything except the answer? The fact that the shape of the inputs changes once you insert all the image embeddings makes this tough, and I haven't found any proper demos so far. The caveat being: if you want your vision tower to get better, maybe you fine-tune that separately, or already start with a CLIP-like model fine-tuned on a different domain?
So for 1.6 it seems that we need to produce an image_newline token at the end of each image row. (I have no idea whether it is each actual row, or what the role of patches is …)
I’m unfortunately busy and won’t be able to take a closer look again until later today, but re: fine-tuning demos, this video might be helpful:
https://www.youtube.com/watch?v=eIziN2QUt8U
Thanks. If I cheat and pass False to the model load, to let it ignore the model-parameter (image_newline) mismatch, I get:
We need to preprocess the image features before merging them with the input ids. You can find more information at https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L514-L553. I personally feel it would be better to copy the existing llava 1.5 example and create a new one for 1.6, since quite a few changes were introduced in llava 1.6.
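(To make the referenced preprocessing concrete, here's a simplified sketch of the LLaVA 1.6 feature packing based on my reading of the linked transformers code: spatial patch features are arranged into a grid and the learned image_newline embedding is appended at the end of each row. The unpadding step is omitted, and the function name and shapes are illustrative, not the actual transformers or mlx API.)

```python
import mlx.core as mx


def pack_image_features(image_features, image_newline, grid_h, grid_w):
    # image_features: (1 + grid_h * grid_w, tokens_per_patch, hidden),
    # where index 0 is the base (resized full-image) patch.
    base_features = image_features[0]       # (tokens, hidden)
    patch_features = image_features[1:]     # (grid_h * grid_w, tokens, hidden)

    tokens, hidden = base_features.shape
    side = int(tokens ** 0.5)               # tokens form a side x side grid

    # Rearrange patches into one large (grid_h*side) x (grid_w*side) feature map.
    spatial = patch_features.reshape(grid_h, grid_w, side, side, hidden)
    spatial = spatial.transpose(0, 2, 1, 3, 4)
    spatial = spatial.reshape(grid_h * side, grid_w * side, hidden)

    # Append the learned image_newline vector at the end of each row.
    newline = mx.broadcast_to(image_newline, (grid_h * side, 1, hidden))
    spatial = mx.concatenate([spatial, newline], axis=1)
    spatial = spatial.reshape(-1, hidden)

    # Base features first, then the packed spatial features.
    return mx.concatenate([base_features, spatial], axis=0)
```

The result can then be spliced into the input embeddings the same way the 1.5 example does, just with more image tokens per image.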
It did, in that it confirmed the only way to do this, it seems, is to train on what comes in the answer, or after the image. With that in mind I have a quick implementation working here, which fine-tunes on a dataset where every token after the <image> one is trained. Since LLaVA works with <image> in arbitrary places, supporting that would be a cool enhancement, but it isn't feasible with how this is set up, from what I could tell.
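(Not the implementation linked above, but a generic sketch of what loss masking along these lines could look like in MLX: everything before a chosen start index -- e.g. the first token after the image, or the start of the answer -- is excluded from the cross-entropy loss. The `model(inputs)` call is a placeholder for a forward pass whose image embeddings have already been merged in.)

```python
import mlx.core as mx
import mlx.nn as nn


def masked_loss(model, inputs, targets, answer_start):
    """inputs/targets: (batch, seq_len) token ids, shifted by one.
    answer_start: (batch,) index of the first position to train on."""
    logits = model(inputs)  # (batch, seq_len, vocab)

    # 1 for positions we train on, 0 for the prompt/image region.
    positions = mx.arange(targets.shape[1])[None, :]
    mask = (positions >= answer_start[:, None]).astype(logits.dtype)

    ce = nn.losses.cross_entropy(logits, targets, reduction="none")
    return (ce * mask).sum() / mask.sum()
```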
Unfortunately, it's taking me longer than I had hoped to get up to speed -- I still might be able to get a draft PR up but it won't be for a while longer. |
I am really interested in this work. However, I don't know where to start. Have you had time to advance this by any chance, @tctrautman? I know you are doing this in your spare time and we are all grateful for that. Please let me know if I can be helpful with anything.
@ahmetkca Unfortunately I haven't had time to make any progress on this, and I'm honestly not sure when I will 😞 |
The PR to support LLaVA v1.6 will be merged into mlx-vlm early tomorrow :) And the trainer (FFT and LoRA) should follow soon.
Building on the amazing work by @mzbac and @nkasmanoff in #461, I'd really love an example of how LLaVA 1.6 (aka llava next) can be fine-tuned with a LoRA.
I might be able to make progress on this myself, but it'll take me some time. Any help or thoughts on how to best approach this would be appreciated. (Especially from @mzbac and/or @nkasmanoff.)