
[Feature Request] LLaVA 1.6 LoRA fine-tuning example #605

Open
tctrautman opened this issue Mar 21, 2024 · 16 comments

@tctrautman

Building on the amazing work by @mzbac and @nkasmanoff in #461, I'd really love an example of how LLaVA 1.6 (aka LLaVA-NeXT) can be fine-tuned with a LoRA.

I might be able to make progress on this myself, but it'll take me some time. Any help or thoughts on how to best approach this would be appreciated. (Especially from @mzbac and/or @nkasmanoff.)

@mzbac
Contributor

mzbac commented Mar 22, 2024

I was waiting for Llava 1.6 support to be merged into the transformers library so we can have consistent model weight naming conventions. I haven't looked at the details of it, but the most challenging part would be in https://github.com/ml-explore/mlx-examples/blob/main/llava/llava.py#L104. Do you mind starting a draft PR for it? I am happy to help in any way if you get stuck on the implementation.
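For context, the general idea behind that merge (a deliberately simplified, single-image sketch, not the repo's actual implementation) is to splice the projected vision features in place of the image placeholder embedding, which is also why the sequence length grows:

import mlx.core as mx

def splice_image_features(input_embeds, image_features, image_token_position):
    # input_embeds: (seq_len, dim) text embeddings containing one image placeholder
    # image_features: (num_patches, dim) projected vision features for that image
    # image_token_position: index of the placeholder token in the sequence
    before = input_embeds[:image_token_position]
    after = input_embeds[image_token_position + 1:]
    return mx.concatenate([before, image_features, after], axis=0)

For 1.6 the hard part is less the splice itself and more that image_features is assembled differently (multiple patch grids plus the image_newline handling discussed further down).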

@tctrautman
Author

@mzbac sounds good! I'll take a look over the weekend and see how it goes -- a lot of this is going to be pretty new for me.

Quick question as I'm getting started: do you expect it to be complex to support both the Mistral and Vicuna models?

I ask because I see the line below where we've explicitly said we only support the Llama language model.

if self.model_type != "llama":

I'm personally most interested in getting the Mistral 7B version working, but I'd be open to working with Vicuna if you expect that to be a better place to get started.

@mzbac
Contributor

mzbac commented Mar 22, 2024

Mistral uses the same Llama architecture, so if the merging of image features with the input ids works properly, the language model part should just work out of the box.
You can take a look at the model weights here -> https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf
Also, check out this ASCII diagram I created for the llava 1.5 model: https://gist.github.com/mzbac/00ebe60bb36fa4d8f65509f8e47350d5
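If it helps, here is a minimal sketch of how the model_type check in llava.py could be relaxed. Treating "mistral" as Llama-compatible is an assumption based on the llava-v1.6-mistral-7b-hf config; the 34B variant may need more than this.

# Sketch only: accept text backbones that share the Llama decoder layout.
LLAMA_COMPATIBLE_TYPES = {"llama", "mistral"}

def check_text_model_type(model_type: str) -> None:
    if model_type not in LLAMA_COMPATIBLE_TYPES:
        raise ValueError(
            f"Unsupported text model type: {model_type} "
            f"(expected one of {sorted(LLAMA_COMPATIBLE_TYPES)})"
        )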

@jrp2014

jrp2014 commented Mar 22, 2024

I couldn't get 1.6 to work either, as it said that only llama and not mistral models were supported.

On a further test, trying llava-hf/llava-v1.6-34b-hf, I got:

preprocessor_config.json: 100%|█████████████████████████████████████████████████████| 754/754 [00:00<00:00, 958kB/s]
tokenizer_config.json: 100%|███████████████████████████████████████████████████| 1.86k/1.86k [00:00<00:00, 4.99MB/s]
tokenizer.model: 100%|█████████████████████████████████████████████████████████| 1.03M/1.03M [00:00<00:00, 5.56MB/s]
added_tokens.json: 100%|█████████████████████████████████████████████████████████| 23.0/23.0 [00:00<00:00, 81.9kB/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████| 748/748 [00:00<00:00, 3.56MB/s]
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
config.json: 100%|█████████████████████████████████████████████████████████████| 1.41k/1.41k [00:00<00:00, 6.51MB/s]
Traceback (most recent call last):
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/mytest.py", line 3, in <module>
    processor, model = load_model("llava-hf/llava-v1.6-34b-hf")
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/generate.py", line 83, in load_model
    processor = AutoProcessor.from_pretrained(model_path)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/mlx/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py", line 340, in from_pretrained
    raise ValueError(
ValueError: Unrecognized processing class in llava-hf/llava-v1.6-34b-hf. Can't instantiate a processor, a tokenizer, an image processor or a feature extractor for this model. Make sure the repository contains the files of at least one of those processing classes.

... updating the transformers library and installing protobuf seems to be more promising ...

@tctrautman
Author

@jrp2014 I ran into a similar processor error before I realized I hadn't updated HF transformers -- updating to 4.39.1 resolved it on my end.
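For anyone else hitting the same processor error, something like the following should cover both the transformers version and the protobuf dependency mentioned above (the exact version pin is just what worked for me):

pip install -U "transformers>=4.39.1" protobuf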

@jrp2014

jrp2014 commented Mar 23, 2024

I now get

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Fetching 23 files: 100%|█████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 28676.87it/s]
Traceback (most recent call last):
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/mytest.py", line 3, in <module>
    processor, model = load_model("llava-hf/llava-v1.6-34b-hf")
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/generate.py", line 84, in load_model
    model = LlavaModel.from_pretrained(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/llava.py", line 178, in from_pretrained
    model.load_weights(list(weights.items()))
  File "/opt/homebrew/Caskroom/miniconda/base/envs/mlx/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 203, in load_weights
    raise ValueError(f"Received parameters not in model: {extras}.")
ValueError: Received parameters not in model: image_newline.
python mytest.py  1.41s user 3.73s system 171% cpu 3.002 total

I'll have a look to see whether there are any obvious fixes.
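In case it helps whoever picks this up: LLaVA-NeXT checkpoints ship one extra top-level weight, image_newline, a single vector of size hidden_size, so the MLX model needs a matching attribute before load_weights will accept the checkpoint. A minimal sketch (the class name and attribute placement are assumptions; actually using the vector when packing image features is a separate step):

import mlx.core as mx
import mlx.nn as nn

class LlavaNextExtras(nn.Module):
    # Sketch only: holds the extra "image_newline" weight found in LLaVA-NeXT
    # checkpoints so load_weights can match the key instead of raising
    # "Received parameters not in model".
    def __init__(self, hidden_size: int):
        super().__init__()
        self.image_newline = mx.zeros((hidden_size,))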

@nkasmanoff
Contributor

Hey, great initiative! Excited to try out these new and improved LLaVA models. Agreed that using the standard transformers format makes things easier.

I've been thinking a bit about the fine-tuning aspect and can share some code snippets I have so far which might help the discussion.

There's a ton of variety in how you can fine-tune vision-language models, but to me the easiest direction would be to start by fine-tuning only the language model. If you want to do just that, I am following the code from https://github.com/ml-explore/mlx-examples/tree/main/llms, where, at least to start, you can attach LoRA layers to the LLM. Here's how I have that for now.

# Freeze all three submodules; LoRA adapters are attached on top of the
# frozen language model below, so only the adapter weights will train.
model.vision_tower.freeze()
model.multi_modal_projector.freeze()
model.language_model.freeze()

# Assuming the helpers from the llms example: linear_to_lora_layers from
# mlx_lm.tuner.utils, and a print_trainable_parameters like the one in its
# lora.py script.
lora_layers = 32
lora_parameters = {
    "keys": ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.out_proj"],
    "rank": 64,
    "alpha": 16.0,
    "scale": 10.0,
    "dropout": 0.0,
}

linear_to_lora_layers(model.language_model.model, lora_layers, lora_parameters)

print_trainable_parameters(model.vision_tower)
print_trainable_parameters(model.multi_modal_projector)
print_trainable_parameters(model.language_model)

# Trainable parameters: 0.000% (0.000M/303.506M)
# Trainable parameters: 0.000% (0.000M/20.980M)
# Trainable parameters: 0.747% (50.332M/6738.940M)

I'm still not sure what the best approach would be for the language modeling piece. Maybe you mask everything except the answer? The fact that the shape of the inputs changes once you insert all the image embeddings makes this tough, and I haven't found any proper demos so far. The caveat being: if you want your vision tower to get better, maybe you fine-tune it separately, or start from a CLIP-like model already fine-tuned on a different domain?
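To make the masking idea concrete, here is a rough sketch (names and shapes are mine, not from the repo; the answer_start offsets would have to be computed after the image embeddings are spliced in):

import mlx.core as mx
import mlx.nn as nn

def answer_only_loss(logits, targets, answer_start):
    # logits: (batch, seq_len, vocab_size) from the language model
    # targets: (batch, seq_len) target token ids
    # answer_start: (batch,) index of the first answer token in each sequence
    positions = mx.arange(targets.shape[1])[None, :]
    # 1.0 where we train (answer tokens), 0.0 for prompt/image positions
    mask = (positions >= answer_start[:, None]).astype(logits.dtype)
    ce = nn.losses.cross_entropy(logits, targets, reduction="none")
    return (ce * mask).sum() / mask.sum()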

@jrp2014

jrp2014 commented Mar 23, 2024

So for 1.6 it seems that we need to produce an image_newline token at the end of each image row. (I have no idea whether it is each actual row or what the role of patches is …).

@tctrautman
Author

tctrautman commented Mar 23, 2024 via email

@jrp2014

jrp2014 commented Mar 23, 2024

Thanks. If I cheat and pass strict=False to load_weights, so that it ignores the model parameter mismatch for image_newline, I get:

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Fetching 23 files: 100%|███████████████████████████████████████████████| 23/23 [00:00<00:00, 319433.75it/s]
Traceback (most recent call last):
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/mytest.py", line 13, in <module>
    reply = generate_text(
            ^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/generate.py", line 97, in generate_text
    logits, cache = model(input_ids, pixel_values)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/llava.py", line 135, in __call__
    input_embddings = self.get_input_embeddings(input_ids, pixel_values)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrp/Documents/AI/mlx/mlx-examples/llava/llava.py", line 79, in get_input_embeddings
    pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Transpose axes don't match array dimensions.

@mzbac
Contributor

mzbac commented Mar 24, 2024

We need to preprocess the image features before merging them with the input embeddings. You can find more information at https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L514-L553.

I personally feel it would be better to copy the existing llava 1.5 example and create a new one for 1.6, as there have been quite a few changes introduced in llava 1.6.
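To sketch what that preprocessing adds relative to 1.5 (heavily simplified: this ignores the anyres grid selection and unpadding in the linked code, and the names here are placeholders):

import mlx.core as mx

def append_image_newline(patch_features, image_newline):
    # patch_features: (height, width, dim) grid of projected vision features
    # image_newline: (dim,) learned vector from the LLaVA-NeXT checkpoint
    # A "newline" embedding is appended after each row of the feature grid
    # before flattening, so the language model sees row boundaries.
    height, width, dim = patch_features.shape
    newline_col = mx.broadcast_to(image_newline[None, None, :], (height, 1, dim))
    with_newline = mx.concatenate([patch_features, newline_col], axis=1)
    return with_newline.reshape(height * (width + 1), dim)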

@nkasmanoff
Contributor

I’m unfortunately busy and won’t be able to take a closer look again until later today, but re: fine-tuning demos, this video might be helpful: https://www.youtube.com/watch?v=eIziN2QUt8U

It did help, in that it confirmed that the only feasible way to do this seems to be to train on what comes in the answer, or after the image. With that in mind I have a quick implementation working here, which fine-tunes on a dataset where every token after the <image> token is trained on.

Since LLaVA works with <image> in arbitrary places, supporting that would be a cool enhancement, but it doesn't seem feasible with how this is set up, from what I could tell.

@tctrautman
Author

Unfortunately, it's taking me longer than I had hoped to get up to speed -- I still might be able to get a draft PR up but it won't be for a while longer.

@ahmetkca

I am really interested in this work. However, I don't know where to start. Have you had time to make any progress, by any chance, @tctrautman? I know you are doing this in your spare time and we are all grateful for that. Please let me know if I can be helpful with anything.

@tctrautman
Author

@ahmetkca Unfortunately I haven't had time to make any progress on this, and I'm honestly not sure when I will 😞

@Blaizzy
Contributor

Blaizzy commented Jun 22, 2024

PR to support LLaVA v1.6 will be merged to mlx-vlm early tomorrow :)

And the trainer (FFT and LoRA) should follow soon.

Blaizzy/mlx-vlm#43
