diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD index 4833846..c3dc1e8 100644 --- a/mlx_vlm/LORA.MD +++ b/mlx_vlm/LORA.MD @@ -16,6 +16,7 @@ - Idefics 2 - Deepseek-VL - Paligemma +- Mllama (Llama-3.2-vision) ## Coming Soon - LLaVA-Next diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 3e85b30..07be29b 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -36,6 +36,13 @@ def parse_arguments(): default=DEFAULT_IMAGE, help="URL or path of the image to process.", ) + parser.add_argument( + "--resize-shape", + type=int, + nargs=2, + default=None, + help="Resize shape for the image.", + ) parser.add_argument( "--prompt", type=str, @@ -78,6 +85,13 @@ def main(): prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image)) + kwargs = {} + if args.resize_shape is not None: + assert ( + len(args.resize_shape) == 2 + ), "Resize shape must be a tuple of two integers" + kwargs["resize_shape"] = args.resize_shape + output = generate( model, processor, @@ -87,6 +101,7 @@ def main(): args.temp, args.max_tokens, args.verbose, + **kwargs, ) if not args.verbose: print(output) diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py index 99822ed..8d3597e 100644 --- a/mlx_vlm/lora.py +++ b/mlx_vlm/lora.py @@ -65,6 +65,7 @@ def process_data(examples): config, processor, image_processor=image_processor, + image_resize_shape=args.image_resize_shape, ) logger.info(f"\033[32mSetting up LoRA\033[0m") @@ -130,6 +131,13 @@ def process_data(examples): parser.add_argument( "--split", type=str, default="train", help="Split to use for training" ) + parser.add_argument( + "--image-resize-shape", + type=int, + nargs=2, + default=None, + help="Resize images to this shape", + ) parser.add_argument( "--apply-chat-template", action="store_false", diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index c01c0b0..d96aa66 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from dataclasses import dataclass +from typing import Any, Dict, List, Optional import mlx.core as mx from PIL import Image @@ -205,3 +206,9 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: mask = None return mask + + +@dataclass +class LanguageModelOutput: + logits: mx.array + cross_attention_states: Optional[List[mx.array]] = None diff --git a/mlx_vlm/models/idefics2/language.py b/mlx_vlm/models/idefics2/language.py index 66f1bb8..9011fcd 100644 --- a/mlx_vlm/models/idefics2/language.py +++ b/mlx_vlm/models/idefics2/language.py @@ -6,7 +6,7 @@ import mlx.core as mx import mlx.nn as nn -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask @dataclass @@ -163,7 +163,8 @@ def __call__( for layer, c in zip(self.layers, cache): h = layer(h, mask, c) - return self.lm_head(self.norm(h)) + logits = self.lm_head(self.norm(h)) + return LanguageModelOutput(logits=logits) def sanitize(self, weights): # Remove unused precomputed rotary freqs diff --git a/mlx_vlm/models/llava/language.py b/mlx_vlm/models/llava/language.py index a7f11b4..3efef2e 100644 --- a/mlx_vlm/models/llava/language.py +++ b/mlx_vlm/models/llava/language.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask @dataclass @@ -210,7 +210,7 @@ def __call__( out = self.model.embed_tokens.as_linear(out) else: out = self.lm_head(out) - return 
out + return LanguageModelOutput(logits=out) @staticmethod def sanitize(weights): diff --git a/mlx_vlm/models/llava_bunny/language.py b/mlx_vlm/models/llava_bunny/language.py index a5a4fb0..cf7fce8 100644 --- a/mlx_vlm/models/llava_bunny/language.py +++ b/mlx_vlm/models/llava_bunny/language.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask @dataclass @@ -200,8 +200,8 @@ def __call__( inputs_embeds: Optional[mx.array] = None, mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=None) - return out + out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=mask) + return LanguageModelOutput(logits=out) def sanitize(self, weights): if ( diff --git a/mlx_vlm/models/llava_next/language.py b/mlx_vlm/models/llava_next/language.py index 497e3b3..9703edd 100644 --- a/mlx_vlm/models/llava_next/language.py +++ b/mlx_vlm/models/llava_next/language.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask @dataclass @@ -199,7 +199,8 @@ def __call__( mask: Optional[mx.array] = None, ): out = self.model(inputs, cache, inputs_embeds) - return self.lm_head(out) + logits = self.lm_head(out) + return LanguageModelOutput(logits=logits) @staticmethod def sanitize(weights): diff --git a/mlx_vlm/models/mllama/__init__.py b/mlx_vlm/models/mllama/__init__.py new file mode 100644 index 0000000..929eda9 --- /dev/null +++ b/mlx_vlm/models/mllama/__init__.py @@ -0,0 +1,8 @@ +from .mllama import ( + LanguageModel, + Model, + ModelConfig, + TextConfig, + VisionConfig, + VisionModel, +) diff --git a/mlx_vlm/models/mllama/language.py b/mlx_vlm/models/mllama/language.py new file mode 100644 index 0000000..0b53b39 --- /dev/null +++ b/mlx_vlm/models/mllama/language.py @@ -0,0 +1,416 @@ +import inspect +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from ..base import KVCache, LanguageModelOutput, create_attention_mask + + +@dataclass +class TextConfig: + model_type: str = "mllama" + vocab_size: int = 32000 + hidden_size: int = 4096 + intermediate_size: int = 14336 + num_hidden_layers: int = 40 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + hidden_act: str = "silu" + max_position_embeddings: int = 131072 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + tie_word_embeddings: bool = False + rope_theta: float = 10000.0 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + cross_attention_layers: List[int] = field( + default_factory=lambda: [3, 8, 13, 18, 23, 28, 33, 38] + ) + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class MllamaTextCrossAttention(nn.Module): + def __init__(self, config: TextConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = 
self.num_heads // self.num_key_value_heads + self.layer_idx = layer_idx + self.scale = self.head_dim**-0.5 + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def __call__( + self, + hidden_states: mx.array, + cross_attention_states: Optional[mx.array] = None, + attention_mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + + bsz, q_len, _ = hidden_states.shape + query_states = ( + self.q_proj(hidden_states) + .reshape(bsz, q_len, self.num_heads, self.head_dim) + .transpose(0, 2, 1, 3) + ) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = ( + self.k_proj(cross_attention_states) + .reshape(bsz, -1, self.num_key_value_heads, self.head_dim) + .transpose(0, 2, 1, 3) + ) + value_states = ( + self.v_proj(cross_attention_states) + .reshape(bsz, -1, self.num_key_value_heads, self.head_dim) + .transpose(0, 2, 1, 3) + ) + key_states = self.k_norm(key_states) + if cache is not None: + key_states, value_states = cache.update_and_fetch( + key_states, value_states + ) + else: + raise ValueError( + "Cross attention states must be provided for cross attention layer." + ) + + attn_output = mx.fast.scaled_dot_product_attention( + query_states, + key_states, + value_states, + scale=self.scale, + mask=attention_mask, # add a dim for batch processing + ) + attn_output = attn_output.transpose(0, 2, 1, 3).reshape( + bsz, q_len, self.hidden_size + ) + return self.o_proj(attn_output) + + +class MllamaTextSelfAttention(nn.Module): + def __init__(self, config: TextConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scale = self.head_dim**-0.5 + self.layer_idx = layer_idx + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.rope = nn.RoPE( + self.head_dim, + traditional=config.rope_traditional, + base=config.rope_theta, + scale=1, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + bsz, q_len, _ = x.shape + query_states = ( + self.q_proj(x).reshape(bsz, q_len, self.num_heads, -1).transpose(0, 2, 1, 3) + ) + key_states = ( + self.k_proj(x) + .reshape(bsz, q_len, self.num_key_value_heads, -1) + .transpose(0, 2, 1, 3) + ) + value_states = ( + self.v_proj(x) + .reshape(bsz, q_len, self.num_key_value_heads, -1) + .transpose(0, 2, 1, 3) + ) + + if cache is not None: + query_states = self.rope(query_states, offset=cache.offset) + key_states = 
self.rope(key_states, offset=cache.offset) + key_states, value_states = cache.update_and_fetch(key_states, value_states) + else: + query_states = self.rope(query_states) + key_states = self.rope(key_states) + + attn_output = mx.fast.scaled_dot_product_attention( + query_states, key_states, value_states, scale=self.scale, mask=mask + ) + attn_output = attn_output.transpose(0, 2, 1, 3).reshape( + bsz, q_len, self.hidden_size + ) + return self.o_proj(attn_output) + + +class MllamaTextMLP(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.gate_proj = nn.Linear( + config.hidden_size, config.intermediate_size, bias=False + ) + self.up_proj = nn.Linear( + config.hidden_size, config.intermediate_size, bias=False + ) + self.down_proj = nn.Linear( + config.intermediate_size, config.hidden_size, bias=False + ) + self.act_fn = lambda x: x * mx.sigmoid(x) + + def __call__(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class MllamaSelfAttentionDecoderLayer(nn.Module): + def __init__(self, config: TextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = MllamaTextSelfAttention(config, layer_idx=layer_idx) + self.mlp = MllamaTextMLP(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def __call__( + self, + hidden_states: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + x=hidden_states, + mask=mask, + cache=cache, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class MllamaCrossAttentionDecoderLayer(nn.Module): + def __init__(self, config: TextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.cross_attn = MllamaTextCrossAttention(config, layer_idx=layer_idx) + self.mlp = MllamaTextMLP(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.cross_attn_attn_gate = mx.zeros(1) + self.cross_attn_mlp_gate = mx.zeros(1) + + def __call__( + self, + hidden_states: mx.array, + cross_attention_states: mx.array, + attention_mask: Optional[mx.array] = None, + full_text_row_masked_out_mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.cross_attn( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + cache=cache, + ) + hidden_states = residual + mx.tanh(self.cross_attn_attn_gate) * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if full_text_row_masked_out_mask is not None: + hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states + hidden_states = residual + mx.tanh(self.cross_attn_mlp_gate) * hidden_states + + return hidden_states + + +class MllamaTextModel(nn.Module): + def __init__(self, config: TextConfig): + 
super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + + self.embed_tokens = nn.Embedding(config.vocab_size + 8, config.hidden_size) + self.layers = [ + ( + MllamaCrossAttentionDecoderLayer(config, layer_idx) + if layer_idx in config.cross_attention_layers + else MllamaSelfAttentionDecoderLayer(config, layer_idx) + ) + for layer_idx in range(config.num_hidden_layers) + ] + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + input_ids: Optional[mx.array] = None, + mask: Optional[mx.array] = None, + position_ids: Optional[mx.array] = None, + cross_attention_states: Optional[mx.array] = None, + cross_attention_mask: Optional[mx.array] = None, + full_text_row_masked_out_mask: Optional[mx.array] = None, + inputs_embeds: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + inputs_embeds = self.embed_tokens(input_ids) + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if position_ids is None: + position_ids = mx.expand_dims(mx.arange(seq_length), 0) + position_ids = mx.repeat(position_ids, batch_size, axis=0) + + hidden_states = inputs_embeds + + if cache is None: + cache = [None] * len(self.layers) + + mask = create_attention_mask(hidden_states) + + for idx, (decoder_layer, c) in enumerate(zip(self.layers, cache)): + if idx in self.config.cross_attention_layers: + if cross_attention_states is None: + raise ValueError( + f"Cross attention states must be provided for layer {idx}" + ) + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + cache=c, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + mask=mask, + cache=c, + ) + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class LanguageModel(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.config = config + self.model = MllamaTextModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + input_ids: Optional[mx.array] = None, + mask: Optional[mx.array] = None, + cross_attention_states: Optional[mx.array] = None, + cross_attention_mask: Optional[mx.array] = None, + full_text_row_masked_out_mask: Optional[mx.array] = None, + inputs_embeds: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> Tuple[mx.array, Optional[mx.array]]: + + hidden_states = self.model( + input_ids=input_ids, + mask=mask, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + inputs_embeds=inputs_embeds, + cache=cache, + ) + + logits = self.lm_head(hidden_states) + + return LanguageModelOutput( + logits=logits, cross_attention_states=cross_attention_states + ) + + @staticmethod + def sanitize(weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + 
return self.model.layers + + @property + def head_dim(self): + return self.config.hidden_size // self.config.num_attention_heads + + @property + def n_kv_heads(self): + return self.config.num_key_value_heads diff --git a/mlx_vlm/models/mllama/mllama.py b/mlx_vlm/models/mllama/mllama.py new file mode 100644 index 0000000..4bb8bc2 --- /dev/null +++ b/mlx_vlm/models/mllama/mllama.py @@ -0,0 +1,207 @@ +import glob +import inspect +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from huggingface_hub import snapshot_download + +from ..base import KVCache +from .language import LanguageModel, TextConfig +from .vision import VisionConfig, VisionModel + + +@dataclass +class ModelConfig: + text_config: TextConfig + vision_config: VisionConfig + model_type: str + ignore_index: int = -100 + image_token_index: int = 128256 + vision_feature_select_strategy: str = "default" + vision_feature_layer: int = -2 + vocab_size: int = 32000 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class Model(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + self.vision_tower = VisionModel(config.vision_config) + self.language_model = LanguageModel(config.text_config) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + + def __call__( + self, + input_ids: mx.array, + pixel_values: mx.array, + mask: mx.array, + cache: Optional[KVCache] = None, + **kwargs, + ) -> Tuple[mx.array, Optional[mx.array]]: + + aspect_ratio_ids = kwargs.pop("aspect_ratio_ids", None) + aspect_ratio_mask = kwargs.pop("aspect_ratio_mask", None) + cross_attention_mask = kwargs.pop("cross_attention_mask", None) + + inputs_embeds = None + + # Process vision input if provided + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + + vision_outputs = self.vision_tower( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + ) + cross_attention_states = vision_outputs[0] + + cross_attention_states = self.multi_modal_projector( + cross_attention_states + ).reshape( + -1, + cross_attention_states.shape[-2], + self.config.text_config.hidden_size, + ) + + else: + cross_attention_states = None + + # Prepare cross attention mask + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = ( + self._prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) + ** 2 + + 1, + ) + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None: + cache_position = mx.arange(input_ids.shape[1], dtype=mx.int32) + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[ + :, :, cache_position + ] + + # Process language input + outputs = self.language_model( + input_ids=input_ids, + mask=mask, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + inputs_embeds=inputs_embeds, + cache=cache, + ) + + return outputs + + def 
_prepare_cross_attention_mask( + self, + cross_attention_mask: mx.array, + num_vision_tokens: int, + ) -> Tuple[mx.array, mx.array]: + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = mx.repeat( + cross_attention_mask, num_vision_tokens, axis=3 + ) + cross_attention_mask = cross_attention_mask.reshape( + batch_size, text_total_length, -1 + ) + cross_attention_mask = mx.expand_dims(cross_attention_mask, 1) + + # Invert the mask + inverted_cross_attn_mask = 1.0 - cross_attention_mask + fill_array = mx.array(-1e9) + fill_array = mx.broadcast_to(fill_array, inverted_cross_attn_mask.shape) + cross_attention_mask = mx.where( + inverted_cross_attn_mask, + fill_array, + cross_attention_mask, + ) + + # Apply full-row bias + full_text_row_masked_out_mask = mx.any( + cross_attention_mask != -1e9, + axis=-1, + keepdims=True, + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + @staticmethod + def from_pretrained(path_or_hf_repo: str): + path = Path(path_or_hf_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], + ) + ) + + with open(path / "config.json", "r") as f: + model_config = json.load(f) + + model_config = ModelConfig.from_dict(model_config) + + model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) + model_config.text_config = TextConfig.from_dict(model_config) + + model = Model(model_config) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors found in {path}") + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + weights = VisionModel.sanitize(weights) + weights = LanguageModel.sanitize(weights) + + model.load_weights(list(weights.items())) + return model + + def sanitize(self, weights): + def transform_key(key): + if "vision_tower" not in key: + key = key.replace("vision_model", "vision_tower") + return key + + return {transform_key(k): v for k, v in weights.items()} diff --git a/mlx_vlm/models/mllama/vision.py b/mlx_vlm/models/mllama/vision.py new file mode 100644 index 0000000..4fa2582 --- /dev/null +++ b/mlx_vlm/models/mllama/vision.py @@ -0,0 +1,499 @@ +import inspect +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class VisionConfig: + image_size: int = 560 + patch_size: int = 14 + num_channels: int = 3 + hidden_size: int = 1280 + intermediate_size: int = 5120 + num_hidden_layers: int = 32 + num_attention_heads: int = 16 + max_num_tiles: int = 4 + max_aspect_ratio_id: int = 8 + num_global_layers: int = 8 + norm_eps: float = 1e-5 + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + vision_output_dim: int = 7680 + intermediate_layers_indices: List[int] = field( + default_factory=lambda: [3, 7, 15, 23, 30] + ) + supported_aspect_ratios: Tuple[List[int]] = ( + [1, 1], + [1, 2], + [1, 3], + [1, 4], + [2, 1], + [2, 2], + [3, 1], + [4, 1], + ) + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +def check_array_shape(arr): + shape = arr.shape + + # Check if the shape has 4 dimensions + if len(shape) != 4: + return False + + out_channels, kH, KW, _ = shape + + # Check if out_channels is the largest, and kH and KW 
are the same + if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): + return True + else: + return False + + +class MllamaVisionAttention(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.scale = self.head_dim**-0.5 + + self.q_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.embed_dim, bias=False + ) + + def __call__( + self, + hidden_state: mx.array, + attention_mask: Optional[mx.array] = None, + ) -> mx.array: + query = self.q_proj(hidden_state) + key = self.k_proj(hidden_state) + value = self.v_proj(hidden_state) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.reshape( + batch_size, q_seq_len, self.num_heads, self.head_dim + ).transpose(0, 2, 1, 3) + key = key.reshape( + batch_size, kv_seq_len, self.num_heads, self.head_dim + ).transpose(0, 2, 1, 3) + value = value.reshape( + batch_size, kv_seq_len, self.num_heads, self.head_dim + ).transpose(0, 2, 1, 3) + + if attention_mask is not None: + attention_mask = attention_mask[:, :, : key.shape[-2], :] + + attn_output = mx.fast.scaled_dot_product_attention( + query, key, value, scale=self.scale, mask=attention_mask + ) + + attn_output = attn_output.transpose(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + return self.o_proj(attn_output) + + +class MllamaVisionMLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True) + self.gelu = nn.GELU() + + def __call__(self, hidden_states: mx.array) -> mx.array: + hidden_states = self.fc1(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__(self, config: VisionConfig, is_gated: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.is_gated = is_gated + + self.self_attn = MllamaVisionAttention(config) + self.mlp = MllamaVisionMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + self.hidden_size, eps=config.norm_eps + ) + + if is_gated: + self.gate_attn = mx.zeros(1) + self.gate_ffn = mx.zeros(1) + + def __call__( + self, + hidden_state: mx.array, + attention_mask: Optional[mx.array] = None, + ) -> mx.array: + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + if self.is_gated: + hidden_state = mx.tanh(self.gate_attn) * hidden_state + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + if self.is_gated: + hidden_state = mx.tanh(self.gate_ffn) * hidden_state + hidden_state = residual + hidden_state + + return hidden_state + + 
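
# A minimal standalone sketch (not taken from the patch) of the tanh-gated residual
# pattern used by the gated encoder layers above via `gate_attn` / `gate_ffn`: the gate
# parameter starts at zero, so tanh(gate) == 0 and the branch is initially disabled.
import mlx.core as mx

def gated_residual(residual: mx.array, branch: mx.array, gate: mx.array) -> mx.array:
    # tanh keeps the learned scale in (-1, 1); at init the layer is a pure skip connection
    return residual + mx.tanh(gate) * branch

x = mx.ones((1, 4))
print(gated_residual(x, 2 * x, mx.zeros(1)))  # gate == 0 -> output identical to x
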
+class MllamaVisionEncoder(nn.Module): + def __init__(self, config: VisionConfig, num_layers=32, is_gated=False): + super().__init__() + self.layers = [ + MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers) + ] + + def __call__( + self, + hidden_states: mx.array, + attention_mask: Optional[mx.array] = None, + ) -> Tuple[mx.array, List[mx.array]]: + encoder_states = () + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask=attention_mask) + encoder_states = encoder_states + (hidden_states,) + return hidden_states, encoder_states + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + def __init__(self, config: VisionConfig, is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size + ) + if is_gated: + self.gate = mx.zeros(1) + + def __call__(self, hidden_state: mx.array, aspect_ratio_ids: mx.array) -> mx.array: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + if self.is_gated: + embeddings = embeddings * mx.tanh(self.gate) + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = mx.zeros(1) + + # position embedding + self.embedding = ( + mx.random.normal((self.num_patches, self.hidden_size)) * self.scale + ) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size, + ) + + def __call__(self, hidden_state: mx.array, aspect_ratio_ids: mx.array) -> mx.array: + # position embeddings + gated_position_embedding = (1 - mx.tanh(self.gate)) * self.embedding + hidden_state = hidden_state + gated_position_embedding.reshape( + 1, 1, self.num_patches, self.hidden_size + ) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = mx.tanh(self.gate) * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class VisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.class_embedding = mx.random.normal((self.hidden_size,)) * 
self.scale + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + config, is_gated=True + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + config, is_gated=True + ) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) + self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) + + # encoders + self.transformer = MllamaVisionEncoder( + config, config.num_hidden_layers, is_gated=False + ) + self.global_transformer = MllamaVisionEncoder( + config, config.num_global_layers, is_gated=True + ) + + def __call__( + self, + pixel_values: mx.array, + aspect_ratio_ids: mx.array, + aspect_ratio_mask: mx.array, + ) -> mx.array: + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( + pixel_values.shape + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.moveaxis(1, 3)).moveaxis(3, 1) + + hidden_state = patch_embeds.reshape( + patch_embeds.shape[0], patch_embeds.shape[1], -1 + ).transpose(0, 2, 1) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, -1, dim + ) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + + # Add cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + class_embedding = mx.broadcast_to( + self.class_embedding, + (batch_size * num_concurrent_media * num_tiles, 1, dim), + ) + hidden_state = mx.concatenate([class_embedding, hidden_state], axis=1) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + + # Pad the tensor + padding = [(0, 0), (0, 0), (0, num_padding_patches), (0, 0)] + hidden_state = mx.pad(hidden_state, padding) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1 + ) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + ) + + # Apply encoder + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, -1, self.hidden_size + ) + output = self.transformer(hidden_state, attention_mask=attention_mask) + + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + self.hidden_size, + ) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), + self.hidden_size, + ) + global_output = self.global_transformer( 
+ hidden_state, attention_mask=attention_mask + ) + + hidden_state = global_output[0] + + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = output[1] + intermediate_hidden_states = mx.stack(all_intermediate_hidden_states, axis=-1) + intermediate_hidden_states = intermediate_hidden_states[ + ..., self.intermediate_layers_indices + ] + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = mx.concatenate( + [hidden_state, intermediate_hidden_states], axis=-1 + ) + + return hidden_state + + @staticmethod + def sanitize(weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + # Remove unused position_ids + continue + elif "patch_embedding.weight" in k: + # PyTorch conv2d weight tensors have shape: + # [out_channels, in_channels, kH, KW] + # MLX conv2d expects the weight be of shape: + # [out_channels, kH, KW, in_channels] + if check_array_shape(v): + sanitized_weights[k] = v + else: + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: mx.array, + num_patches: int, + target_length: int, +) -> mx.array: + dtype = mx.float32 + aspect_ratio_mask = aspect_ratio_mask.astype(dtype) + + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.reshape(batch_size, max_num_tiles, 1, 1).astype( + dtype + ) + attention_mask = mx.tile(attention_mask, (1, 1, target_length, 1)) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape( + batch_size, max_num_tiles * target_length, 1 + ) + + min_value = -1e9 + attention_mask = attention_mask @ attention_mask.transpose(0, 2, 1) * min_value + attention_mask = attention_mask[:, None, :, :] + + return attention_mask diff --git a/mlx_vlm/models/multi_modality/language.py b/mlx_vlm/models/multi_modality/language.py index 22a85d8..7ae0283 100644 --- a/mlx_vlm/models/multi_modality/language.py +++ b/mlx_vlm/models/multi_modality/language.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask @dataclass @@ -199,7 +199,8 @@ def __call__( mask: Optional[mx.array] = None, ): out = self.model(inputs, cache, inputs_embeds) - return self.lm_head(out) + logits = self.lm_head(out) + return LanguageModelOutput(logits=logits) 
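
# A minimal sketch of how a caller consumes the LanguageModelOutput dataclass that the
# language models above now return instead of a bare logits array; `model` and
# `input_ids` here are stand-ins for any mlx-vlm language model and its token ids.
import mlx.core as mx

def greedy_next_token(model, input_ids: mx.array) -> mx.array:
    outputs = model(input_ids)              # LanguageModelOutput(logits=..., cross_attention_states=...)
    last_logits = outputs.logits[:, -1, :]  # logits for the final position only
    return mx.argmax(last_logits, axis=-1)  # greedy choice of the next token id
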
@staticmethod def sanitize(weights): diff --git a/mlx_vlm/models/paligemma/language.py b/mlx_vlm/models/paligemma/language.py index 96a32ad..7e0fd1b 100644 --- a/mlx_vlm/models/paligemma/language.py +++ b/mlx_vlm/models/paligemma/language.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask @dataclass @@ -186,7 +186,7 @@ def __call__( ): out = self.model(inputs, cache, inputs_embeds=inputs_embeds, mask=mask) out = self.model.embed_tokens.as_linear(out) - return out + return LanguageModelOutput(logits=out) def sanitize(self, weights): return { diff --git a/mlx_vlm/models/phi3_v/phi3_v.py b/mlx_vlm/models/phi3_v/phi3_v.py index 01285f5..6f2d412 100644 --- a/mlx_vlm/models/phi3_v/phi3_v.py +++ b/mlx_vlm/models/phi3_v/phi3_v.py @@ -8,7 +8,7 @@ import mlx.nn as nn import numpy as np -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask from .language import LanguageModel, TextConfig from .su_rope import Phi3SuScaledRotaryEmbedding from .vision import VisionConfig, VisionModel @@ -213,7 +213,8 @@ def __call__( **kwargs, ): out = self.model(inputs, pixel_values, image_sizes, cache) - return self.lm_head(out).astype(self.lm_head.weight.dtype) + logits = self.lm_head(out).astype(self.lm_head.weight.dtype) + return LanguageModelOutput(logits=logits) @property def layers(self): diff --git a/mlx_vlm/models/pixtral/language.py b/mlx_vlm/models/pixtral/language.py index da8482c..8500646 100644 --- a/mlx_vlm/models/pixtral/language.py +++ b/mlx_vlm/models/pixtral/language.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask @dataclass @@ -198,7 +198,8 @@ def __call__( mask: Optional[mx.array] = None, ): out = self.model(inputs, cache, inputs_embeds) - return self.lm_head(out) + logits = self.lm_head(out) + return LanguageModelOutput(logits=logits) @staticmethod def sanitize(weights): diff --git a/mlx_vlm/models/qwen2_vl/language.py b/mlx_vlm/models/qwen2_vl/language.py index ef2518d..15e21b7 100644 --- a/mlx_vlm/models/qwen2_vl/language.py +++ b/mlx_vlm/models/qwen2_vl/language.py @@ -6,7 +6,7 @@ import mlx.nn as nn import numpy as np -from ..base import KVCache, create_attention_mask +from ..base import KVCache, LanguageModelOutput, create_attention_mask @dataclass @@ -301,7 +301,7 @@ def __call__( out = self.model.embed_tokens.as_linear(out) else: out = self.lm_head(out) - return out + return LanguageModelOutput(logits=out) @property def layers(self): diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index fe4ce6c..62416b0 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -3,11 +3,10 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Optional, Tuple +from typing import Optional import mlx.core as mx import mlx.nn as nn -import numpy as np from huggingface_hub import snapshot_download from .language import LanguageModel, TextConfig @@ -77,8 +76,13 @@ def _merge_input_ids_with_image_features( # Positions of tokens in input_ids, assuming batch size is 1 image_positions = input_ids == image_token_index - image_indices = np.where(image_positions)[1].tolist() - inputs_embeds[:, image_indices, :] = image_features.astype(mx.float32) + # image_indices = 
mx.where(image_positions) + image_features = image_features.astype(mx.float32) + pad_size = inputs_embeds.shape[1] - image_features.shape[1] + image_features = mx.pad(image_features, ((0, 0), (0, pad_size), (0, 0))) + inputs_embeds = mx.where( + image_positions[:, :, None], image_features, inputs_embeds + ) # TODO: Add video features diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index 1289fed..c8a8a57 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -65,6 +65,7 @@ def add_image_tokens(message, token_format): "multi_modality": "message_with_image_token", "pixtral": "message_list_with_image_type", "paligemma": "prompt_only", + "mllama": "message_list_with_image", } if num_images > 1 and model_name in [ @@ -73,6 +74,7 @@ def add_image_tokens(message, token_format): "bunny-llama", "paligemma", "multi_modality", + "mllama", ]: raise ValueError( f"Model {model_name} does not support multi-image chat. Please only use 1 image." diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 79b13ba..6079a05 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -17,12 +17,14 @@ def language_test_runner(self, model, model_type, vocab_size, num_layers): inputs = mx.array([[0, 1]]) outputs = model(inputs) - self.assertEqual(outputs.shape, (batch_size, 2, vocab_size)) - self.assertEqual(outputs.dtype, t) + logits = outputs.logits + self.assertEqual(logits.shape, (batch_size, 2, vocab_size)) + self.assertEqual(logits.dtype, t) - outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=None) - self.assertEqual(outputs.shape, (batch_size, 1, vocab_size)) - self.assertEqual(outputs.dtype, t) + outputs = model(mx.argmax(logits[0, -1:, :], keepdims=True), cache=None) + logits = outputs.logits + self.assertEqual(logits.shape, (batch_size, 1, vocab_size)) + self.assertEqual(logits.dtype, t) def mm_projector_test_runner( self, mm_projector, vision_hidden_size, text_hidden_size @@ -722,6 +724,76 @@ def test_qwen2_vl(self): grid_thw=mx.ones((1, 3)), # image temporals shape (num_images, 3) ) + def test_mllama(self): + from mlx_vlm.models import mllama + + vision_config = mllama.VisionConfig( + image_size=50, + patch_size=14, + num_channels=3, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=10, + num_attention_heads=16, + max_num_tiles=4, + max_aspect_ratio_id=8, + num_global_layers=8, + norm_eps=1e-5, + attention_dropout=0.0, + hidden_dropout=0.0, + vision_output_dim=7680, + intermediate_layers_indices=[3, 7, 15, 23, 30], + ) + + text_config = mllama.TextConfig( + model_type="mllama", + hidden_size=4096, + num_hidden_layers=10, + intermediate_size=14336, + num_attention_heads=16, + rms_norm_eps=1e-6, + vocab_size=32000, + ) + + model_config = mllama.ModelConfig( + text_config=text_config, + vision_config=vision_config, + model_type="mllama", + ignore_index=-100, + image_token_index=128256, + vision_feature_select_strategy="default", + vision_feature_layer=-2, + vocab_size=32000, + ) + + # Create the model + model = mllama.Model(model_config) + + # Create dummy input data + batch_size = 1 + seq_length = 5 + num_tiles = 4 + input_ids = mx.random.randint(0, 1000, (batch_size, seq_length)) + pixel_values = mx.random.normal((batch_size, 1, num_tiles, 3, 50, 50)) + mask = mx.ones((batch_size, seq_length)) + aspect_ratio_ids = mx.zeros((batch_size, 1), dtype=mx.int32) + aspect_ratio_mask = mx.ones((batch_size, 1, num_tiles)) + cross_attention_mask = mx.ones((batch_size, seq_length, 1, num_tiles)) + + # Forward pass 
+ output = model( + input_ids=input_ids, + pixel_values=pixel_values, + mask=mask, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + cross_attention_mask=cross_attention_mask, + ) + + # Check output shape + expected_shape = (batch_size, seq_length, model_config.vocab_size) + self.assertEqual(output.logits.shape, expected_shape) + if __name__ == "__main__": unittest.main() diff --git a/mlx_vlm/tests/test_trainer.py b/mlx_vlm/tests/test_trainer.py index 519bfe4..97339e7 100644 --- a/mlx_vlm/tests/test_trainer.py +++ b/mlx_vlm/tests/test_trainer.py @@ -1,11 +1,10 @@ import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import mlx.core as mx import mlx.nn as nn -from mlx_vlm.trainer.trainer import Dataset, Trainer, TrainingArgs -from mlx_vlm.utils import prepare_inputs +from mlx_vlm.trainer.trainer import Dataset, Trainer class TestDataset(unittest.TestCase): @@ -15,8 +14,7 @@ def setUp(self): self.mock_processor = MagicMock() self.mock_image_processor = MagicMock() - @patch("mlx_vlm.utils.prepare_inputs") - def test_dataset_initialization(self, mock_prepare_inputs): + def test_dataset_initialization(self): dataset = Dataset( self.mock_hf_dataset, self.mock_config, @@ -93,18 +91,23 @@ def test_trainer_initialization(self): self.assertFalse(self.trainer.train_on_completions) self.assertEqual(self.trainer.assistant_id, 77091) - @patch("mlx.nn.losses.cross_entropy") - def test_loss_fn(self, mock_cross_entropy): + def test_loss_fn(self): batch = { "pixel_values": mx.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), "input_ids": mx.array([[1, 2, 3], [4, 5, 6]]), "attention_mask": mx.array([[1, 1, 1], [1, 1, 0]]), "image_grid_thw": (1, 1, 1), "image_sizes": [224, 224], + "aspect_ratio_ids": mx.array([[1, 2], [3, 4]]), + "aspect_ratio_mask": mx.array([[1, 1], [1, 0]]), + "cross_attention_mask": mx.array([[1, 1], [1, 0]]), } - self.mock_model.return_value = mx.array([[[0.1, 0.2, 0.3]], [[0.4, 0.5, 0.6]]]) - mock_cross_entropy.return_value = mx.array([[0.1, 0.2], [0.3, 0.4]]) + mock_logits = mx.array([[[0.1, 0.2, 0.3]], [[0.4, 0.5, 0.6]]]) + # Create a mock LanguageModelOutput with the logits + mock_output = Mock() + mock_output.logits = mock_logits + self.mock_model.return_value = mock_output loss = self.trainer.loss_fn(self.mock_model, batch) diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index f02674e..213b0ef 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -1,6 +1,4 @@ import json -import os -import time import warnings from dataclasses import dataclass, field from pathlib import Path @@ -9,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from mlx.utils import tree_flatten +from mlx.utils import tree_flatten, tree_map def get_prompt(model_type, processor, conversation): @@ -41,6 +39,7 @@ def __init__( image_processor=None, take=None, split=None, + image_resize_shape=None, ): if split is not None: self.dataset = hf_dataset[split] @@ -51,6 +50,7 @@ def __init__( self.processor = processor self.config = config self.image_processor = image_processor + self.image_resize_shape = image_resize_shape def __len__(self): return len(self.dataset) @@ -89,10 +89,27 @@ def __getitem__(self, idx): image_token_index = self.config["image_token_index"] inputs = prepare_inputs( - self.image_processor, self.processor, images, prompts, image_token_index + self.image_processor, + self.processor, + images, + prompts, + image_token_index, + self.image_resize_shape, ) 
input_ids, pixel_values, mask = inputs[:3] - kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])} + kwargs = { + k: v + for k, v in zip( + [ + "image_grid_thw", + "image_sizes", + "aspect_ratio_ids", + "aspect_ratio_mask", + "cross_attention_mask", + ], + inputs[3:], + ) + } if mask is None: mask = mx.ones_like(input_ids) @@ -168,12 +185,18 @@ def default_loss(model, inputs, targets, lengths): class Trainer: def __init__( - self, model, optimizer, train_on_completions=False, assistant_id=77091 + self, + model, + optimizer, + train_on_completions=False, + assistant_id=77091, + clip_gradients=None, ): self.model = model self.optimizer = optimizer self.train_on_completions = train_on_completions self.assistant_id = assistant_id + self.clip_gradients = clip_gradients def loss_fn(self, model, batch): pixel_values = batch["pixel_values"] @@ -203,20 +226,22 @@ def loss_fn(self, model, batch): input_ids = input_ids[:, :-1] - kwargs = ( - { - "image_grid_thw": batch["image_grid_thw"], - "image_sizes": batch["image_sizes"], - } - if "image_grid_thw" in batch or "image_sizes" in batch - else {} - ) + kwargs = {} + image_keys = [ + "image_grid_thw", + "image_sizes", + "aspect_ratio_ids", + "aspect_ratio_mask", + "cross_attention_mask", + ] + if any(key in batch for key in image_keys): + kwargs = {key: batch[key] for key in image_keys if key in batch} # Forward pass - logits = model(input_ids, pixel_values, attention_mask, **kwargs) + outputs = model(input_ids, pixel_values, attention_mask, **kwargs) # Cast to float32 - logits.astype(mx.float32) + logits = outputs.logits.astype(mx.float32) # Ensure logits and labels have the same sequence length def align_logits_with_labels(logits, labels): @@ -249,6 +274,13 @@ def align_logits_with_labels(logits, labels): def train_step(self, batch): loss_and_grad_fn = nn.value_and_grad(self.model, self.loss_fn) loss, grads = loss_and_grad_fn(self.model, batch) + + # Add gradient clipping + if self.clip_gradients is not None: + grads = tree_map( + lambda g: mx.clip(g, -self.clip_gradients, self.clip_gradients), grads + ) + self.optimizer.update(self.model, grads) return loss diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py index 9873ed7..6f2a898 100644 --- a/mlx_vlm/trainer/utils.py +++ b/mlx_vlm/trainer/utils.py @@ -1,3 +1,4 @@ +import json from pathlib import Path import mlx.nn as nn @@ -140,9 +141,18 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module: if not adapter_path.exists(): raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") + # Check if the adapter has lora params in the config (adapter_config.json) + with open(adapter_path / "adapter_config.json", "r") as f: + config = json.load(f) + if "rank" not in config: + raise ValueError("The adapter does not have lora params in the config") + # TODO: add lora params to the config and load them here list_of_modules = find_all_linear_names(model.language_model.model) - model = get_peft_model(model, list_of_modules) + if config is not None: + model = get_peft_model(model, list_of_modules, **config) + else: + model = get_peft_model(model, list_of_modules) # TODO: Use custom adapter name model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False) diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 186bd01..4756524 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -717,10 +717,10 @@ def resize_image(img, max_size): return img.resize(new_size) -def process_image(img, resize_shape): +def process_image(img, 
resize_shape, image_processor): if isinstance(img, str): img = load_image(img) - if resize_shape is not None: + if resize_shape is not None and image_processor is None: img = resize_image(img, resize_shape) return img @@ -735,13 +735,14 @@ def prepare_inputs( images = [images] # Process images - images = [ - process_image(img, resize_shape) if isinstance(img, str) else img - for img in images - ] + images = [process_image(img, resize_shape, image_processor) for img in images] image_grid_thw = None image_sizes = None + aspect_ratio_ids = None + aspect_ratio_mask = None + cross_attention_mask = None + if image_processor is not None: if not isinstance(prompts, list): prompts = [prompts] @@ -790,7 +791,28 @@ def prepare_inputs( if image_grid_thw is not None: image_grid_thw = mx.array(image_grid_thw) - return input_ids, pixel_values, mask, image_grid_thw, image_sizes + aspect_ratio_ids = inputs.get("aspect_ratio_ids", None) + if aspect_ratio_ids is not None: + aspect_ratio_ids = mx.array(aspect_ratio_ids) + + aspect_ratio_mask = inputs.get("aspect_ratio_mask", None) + if aspect_ratio_mask is not None: + aspect_ratio_mask = mx.array(aspect_ratio_mask) + + cross_attention_mask = inputs.get("cross_attention_mask", None) + if cross_attention_mask is not None: + cross_attention_mask = mx.array(cross_attention_mask) + + return ( + input_ids, + pixel_values, + mask, + image_grid_thw, + image_sizes, + aspect_ratio_ids, + aspect_ratio_mask, + cross_attention_mask, + ) def generate_step( @@ -866,10 +888,15 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: if repetition_context_size: repetition_context = repetition_context[-repetition_context_size:] - def _step(y): + def _step(y, **kwargs): nonlocal repetition_context - logits = model.language_model(y[None], cache=cache, mask=mask) - logits = logits[:, -1, :] + outputs = model.language_model( + y[None], + cache=cache, + mask=mask, + **kwargs, + ) + logits = outputs.logits[:, -1, :] if repetition_penalty: logits = apply_repetition_penalty( @@ -885,12 +912,22 @@ def _step(y): repetition_context = repetition_context[-repetition_context_size:] return y, logprobs.squeeze(0) - logits = model(input_ids, pixel_values, cache=cache, mask=mask, **kwargs) - logits = logits[:, -1, :] + outputs = model(input_ids, pixel_values, cache=cache, mask=mask, **kwargs) + if outputs.cross_attention_states is not None: + kwargs = { + k: v + for k, v in zip( + ["cross_attention_states"], [outputs.cross_attention_states] + ) + } + else: + kwargs = {} + + logits = outputs.logits[:, -1, :] y, logprobs = sample(logits) mx.async_eval(y) while True: - next_y, next_logprobs = _step(y) + next_y, next_logprobs = _step(y, **kwargs) mx.async_eval(next_y) yield y.item(), logprobs y, logprobs = next_y, next_logprobs @@ -924,9 +961,12 @@ def stream_generate( else: tokenizer = processor.tokenizer + resize_shape = kwargs.pop("resize_shape", None) image_token_index = model.config.image_token_index + + # Prepare inputs inputs = prepare_inputs( - image_processor, processor, image, prompt, image_token_index + image_processor, processor, image, prompt, image_token_index, resize_shape ) input_ids, pixel_values, mask = inputs[:3] kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])} @@ -962,6 +1002,7 @@ def generate( repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = None, top_p: float = 1.0, + **kwargs, ) -> str: """ Generate text from the model. 
@@ -992,13 +1033,27 @@ def generate(
         prompt_tokens = mx.array(processor.tokenizer.encode(prompt))
         tokenizer = processor.tokenizer
 
+    resize_shape = kwargs.pop("resize_shape", None)
     image_token_index = model.config.image_token_index
 
+    # Prepare inputs
     inputs = prepare_inputs(
-        image_processor, processor, image, prompt, image_token_index
+        image_processor, processor, image, prompt, image_token_index, resize_shape
     )
     input_ids, pixel_values, mask = inputs[:3]
-    kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])}
+    kwargs = {
+        k: v
+        for k, v in zip(
+            [
+                "image_grid_thw",
+                "image_sizes",
+                "aspect_ratio_ids",
+                "aspect_ratio_mask",
+                "cross_attention_mask",
+            ],
+            inputs[3:],
+        )
+    }
 
     # Initialize timing and detokenizer
     tic = time.perf_counter()
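
# A minimal usage sketch of the resize_shape plumbing added above: the keyword is popped
# inside generate() and forwarded to prepare_inputs(). The repo id and image path are
# placeholders, and the prompt is assumed to already contain the model's image token
# (normally produced by apply_chat_template).
from mlx_vlm import generate, load

model, processor = load("mlx-community/Llama-3.2-11B-Vision-Instruct-4bit")  # placeholder repo id
output = generate(
    model,
    processor,
    "path/to/image.jpg",                # image path or URL
    "<|image|>Describe this image.",    # prompt with the image token already applied
    max_tokens=100,
    verbose=False,
    resize_shape=(224, 224),            # forwarded via **kwargs to prepare_inputs
)
print(output)
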