Add support for Llama-3.2-vision & Resize image #83

Merged (17 commits, Oct 17, 2024)
1 change: 1 addition & 0 deletions mlx_vlm/LORA.MD
@@ -16,6 +16,7 @@
 - Idefics 2
 - Deepseek-VL
 - Paligemma
+- Mllama (Llama-3.2-vision)
 
 ## Coming Soon
 - LLaVA-Next
15 changes: 15 additions & 0 deletions mlx_vlm/generate.py
@@ -36,6 +36,13 @@ def parse_arguments():
         default=DEFAULT_IMAGE,
         help="URL or path of the image to process.",
     )
+    parser.add_argument(
+        "--resize-shape",
+        type=int,
+        nargs=2,
+        default=None,
+        help="Resize shape for the image.",
+    )
     parser.add_argument(
         "--prompt",
         type=str,
@@ -78,6 +85,13 @@ def main():
 
     prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image))
 
+    kwargs = {}
+    if args.resize_shape is not None:
+        assert (
+            len(args.resize_shape) == 2
+        ), "Resize shape must be a tuple of two integers"
+        kwargs["resize_shape"] = args.resize_shape
+
     output = generate(
         model,
         processor,
@@ -87,6 +101,7 @@
         args.temp,
         args.max_tokens,
         args.verbose,
+        **kwargs,
     )
     if not args.verbose:
         print(output)
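With this in place, --resize-shape flows into generate() through **kwargs as a resize_shape keyword argument. A minimal Python sketch of the equivalent call (the checkpoint id and image path are illustrative, temperature and max tokens are left at their defaults, and load/generate are assumed to be the package's usual entry points):

from mlx_vlm import generate, load

model, processor = load("mlx-community/Llama-3.2-11B-Vision-Instruct-4bit")  # illustrative checkpoint
output = generate(
    model,
    processor,
    "photo.jpg",              # image URL or path
    "Describe this image.",   # prompt (generate.py applies the chat template first)
    resize_shape=(224, 224),  # forwarded exactly like the CLI flag above
)
print(output)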
8 changes: 8 additions & 0 deletions mlx_vlm/lora.py
@@ -65,6 +65,7 @@ def process_data(examples):
         config,
         processor,
         image_processor=image_processor,
+        image_resize_shape=args.image_resize_shape,
     )
 
     logger.info(f"\033[32mSetting up LoRA\033[0m")
@@ -130,6 +131,13 @@
     parser.add_argument(
         "--split", type=str, default="train", help="Split to use for training"
     )
+    parser.add_argument(
+        "--image-resize-shape",
+        type=int,
+        nargs=2,
+        default=None,
+        help="Resize images to this shape",
+    )
     parser.add_argument(
         "--apply-chat-template",
         action="store_false",
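The new image_resize_shape option presumably bottoms out in a PIL resize before the processor sees each training image; a hedged sketch of that step (load_and_resize is a hypothetical helper, not the PR's exact implementation):

from typing import Optional, Tuple

from PIL import Image

def load_and_resize(path: str, resize_shape: Optional[Tuple[int, int]] = None) -> Image.Image:
    # Open the image and optionally resize it to (width, height) before
    # preprocessing, mirroring what the resize_shape kwarg is for.
    image = Image.open(path).convert("RGB")
    if resize_shape is not None:
        image = image.resize(resize_shape)
    return image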
9 changes: 8 additions & 1 deletion mlx_vlm/models/base.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
 
 import mlx.core as mx
 from PIL import Image
@@ -205,3 +206,9 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     else:
         mask = None
     return mask
+
+
+@dataclass
+class LanguageModelOutput:
+    logits: mx.array
+    cross_attention_states: Optional[List[mx.array]] = None
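The LanguageModelOutput dataclass gives every language model a uniform return type: callers read .logits the same way everywhere, while Mllama can additionally hand back cross-attention states (left as None by the other models). A sketch of the consuming side (illustrative; next_token is a hypothetical helper, not the PR's generation loop):

import mlx.core as mx

from mlx_vlm.models.base import LanguageModelOutput  # added in this diff

def next_token(output: LanguageModelOutput) -> mx.array:
    # Works identically for every model now that they all return this
    # dataclass; cross_attention_states stays None for non-mllama models.
    logits = output.logits[:, -1, :]
    return mx.argmax(logits, axis=-1)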
5 changes: 3 additions & 2 deletions mlx_vlm/models/idefics2/language.py
@@ -6,7 +6,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 
-from ..base import KVCache, create_attention_mask
+from ..base import KVCache, LanguageModelOutput, create_attention_mask
 
 
 @dataclass
@@ -163,7 +163,8 @@ def __call__(
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)
 
-        return self.lm_head(self.norm(h))
+        logits = self.lm_head(self.norm(h))
+        return LanguageModelOutput(logits=logits)
 
     def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
4 changes: 2 additions & 2 deletions mlx_vlm/models/llava/language.py
@@ -5,7 +5,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 
-from ..base import KVCache, create_attention_mask
+from ..base import KVCache, LanguageModelOutput, create_attention_mask
 
 
 @dataclass
@@ -210,7 +210,7 @@ def __call__(
             out = self.model.embed_tokens.as_linear(out)
         else:
             out = self.lm_head(out)
-        return out
+        return LanguageModelOutput(logits=out)
 
     @staticmethod
     def sanitize(weights):
6 changes: 3 additions & 3 deletions mlx_vlm/models/llava_bunny/language.py
@@ -5,7 +5,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 
-from ..base import KVCache, create_attention_mask
+from ..base import KVCache, LanguageModelOutput, create_attention_mask
 
 
 @dataclass
@@ -200,8 +200,8 @@ def __call__(
         inputs_embeds: Optional[mx.array] = None,
         mask: Optional[mx.array] = None,
     ):
-        out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=None)
-        return out
+        out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=mask)
+        return LanguageModelOutput(logits=out)
 
     def sanitize(self, weights):
         if (
5 changes: 3 additions & 2 deletions mlx_vlm/models/llava_next/language.py
@@ -5,7 +5,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 
-from ..base import KVCache, create_attention_mask
+from ..base import KVCache, LanguageModelOutput, create_attention_mask
 
 
 @dataclass
@@ -199,7 +199,8 @@ def __call__(
         mask: Optional[mx.array] = None,
     ):
         out = self.model(inputs, cache, inputs_embeds)
-        return self.lm_head(out)
+        logits = self.lm_head(out)
+        return LanguageModelOutput(logits=logits)
 
     @staticmethod
     def sanitize(weights):
8 changes: 8 additions & 0 deletions mlx_vlm/models/mllama/__init__.py
@@ -0,0 +1,8 @@
+from .mllama import (
+    LanguageModel,
+    Model,
+    ModelConfig,
+    TextConfig,
+    VisionConfig,
+    VisionModel,
+)
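These exports mirror the other model packages, so the loader can resolve a checkpoint whose model type is "mllama" by module name, roughly like the sketch below (an assumption about the loading pattern, not the PR's exact code):

import importlib

def get_model_classes(model_type: str):
    # "mllama" now resolves to mlx_vlm.models.mllama via the __init__.py above.
    module = importlib.import_module(f"mlx_vlm.models.{model_type}")
    return module.Model, module.ModelConfig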