diff --git a/README.md b/README.md
index 15c9cca1c..b3404038d 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,7 @@ Some more useful examples are listed below.
 ### Multimodal models
 
 - Joint text and image embeddings with [CLIP](clip).
+- Text generation from image and text inputs with [LLaVA](llava).
 
 ### Other Models
 
diff --git a/llava/.gitignore b/llava/.gitignore
new file mode 100644
index 000000000..ab6d270d8
--- /dev/null
+++ b/llava/.gitignore
@@ -0,0 +1 @@
+**.ipynb
\ No newline at end of file
diff --git a/llava/README.md b/llava/README.md
new file mode 100644
index 000000000..ae58766e1
--- /dev/null
+++ b/llava/README.md
@@ -0,0 +1,61 @@
+# LLaVA
+
+An example of LLaVA: Large Language and Vision Assistant in MLX.[^1] LLaVA is
+a multimodal model that can generate text given combined image and text inputs.
+
+## Setup
+
+Install the dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+## Run
+
+You can use LLaVA to ask questions about images.
+
+For example, using the command line:
+
+```bash
+python generate.py \
+  --model llava-hf/llava-1.5-7b-hf \
+  --image "http://images.cocodataset.org/val2017/000000039769.jpg" \
+  --prompt "USER: <image>\nWhat are these?\nASSISTANT:" \
+  --max-tokens 128 \
+  --temp 0
+```
+
+This uses the following image:
+
+![Two cats lying on a pink couch](http://images.cocodataset.org/val2017/000000039769.jpg)
+
+And generates the output:
+
+```
+These are two cats lying on a pink couch.
+```
+
+You can also use LLaVA in Python:
+
+```python
+from generate import load_model, prepare_inputs, generate_text
+
+processor, model = load_model("llava-hf/llava-1.5-7b-hf")
+
+max_tokens, temperature = 128, 0.0
+
+prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
+image = "http://images.cocodataset.org/val2017/000000039769.jpg"
+input_ids, pixel_values = prepare_inputs(processor, image, prompt)
+
+reply = generate_text(
+    input_ids, pixel_values, model, processor, max_tokens, temperature
+)
+
+print(reply)
+```
+
+[^1]:
+    Refer to the [LLaVA project webpage](https://llava-vl.github.io/) for more
+    information.
diff --git a/llava/generate.py b/llava/generate.py
new file mode 100644
index 000000000..9535bab93
--- /dev/null
+++ b/llava/generate.py
@@ -0,0 +1,130 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+import codecs
+from pathlib import Path
+
+import mlx.core as mx
+import requests
+from PIL import Image
+from transformers import AutoProcessor
+
+from llava import LlavaModel
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(
+        description="Generate text from an image using a model."
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="llava-hf/llava-1.5-7b-hf",
+        help="The path to the local model directory or Hugging Face repo.",
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+        default="http://images.cocodataset.org/val2017/000000039769.jpg",
+        help="URL or path of the image to process.",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="USER: <image>\nWhat are these?\nASSISTANT:",
+        help="Message to be processed by the model.",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=100,
+        help="Maximum number of tokens to generate.",
+    )
+    parser.add_argument(
+        "--temp", type=float, default=0.3, help="Temperature for sampling."
+    )
+    return parser.parse_args()
+
+
+def load_image(image_source):
+    """
+    Helper function to load an image from either a URL or file.
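+    Raises a ValueError if the image cannot be loaded.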
+ """ + if image_source.startswith(("http://", "https://")): + try: + response = requests.get(image_source, stream=True) + response.raise_for_status() + return Image.open(response.raw) + except Exception as e: + raise ValueError( + f"Failed to load image from URL: {image_source} with error {e}" + ) + elif Path(image_source).is_file(): + try: + return Image.open(image_source) + except IOError as e: + raise ValueError(f"Failed to load image {image_source} with error: {e}") + else: + raise ValueError( + f"The image {image_source} must be a valid URL or existing file." + ) + + +def prepare_inputs(processor, image, prompt): + if isinstance(image, str): + image = load_image(image) + inputs = processor(prompt, image, return_tensors="np") + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + return input_ids, pixel_values + + +def load_model(model_path): + processor = AutoProcessor.from_pretrained(model_path) + model = LlavaModel.from_pretrained(model_path) + return processor, model + + +def sample(logits, temperature=0.0): + if temperature == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temperature)) + + +def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): + + logits, cache = model(input_ids, pixel_values) + logits = logits[:, -1, :] + y = sample(logits, temperature=temperature) + tokens = [y.item()] + + for n in range(max_tokens - 1): + logits, cache = model.language_model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits, temperature) + token = y.item() + if token == processor.tokenizer.eos_token_id: + break + tokens.append(token) + + return processor.tokenizer.decode(tokens) + + +def main(): + args = parse_arguments() + processor, model = load_model(args.model) + + prompt = codecs.decode(args.prompt, "unicode_escape") + + input_ids, pixel_values = prepare_inputs(processor, args.image, prompt) + + print(prompt) + generated_text = generate_text( + input_ids, pixel_values, model, processor, args.max_tokens, args.temp + ) + print(generated_text) + + +if __name__ == "__main__": + main() diff --git a/llava/language.py b/llava/language.py new file mode 100644 index 000000000..e9023b99c --- /dev/null +++ b/llava/language.py @@ -0,0 +1,231 @@ +# Copyright © 2024 Apple Inc. 
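+#
+# Llama-style text decoder used as LLaVA's language model: RMSNorm, rotary
+# position embeddings, attention with grouped key/value heads and a key/value
+# cache, and a SwiGLU feed-forward network in each transformer block.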
+ +import inspect +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class TextConfig: + model_type: str + hidden_size: int = 4096 + num_hidden_layers: int = 32 + intermediate_size: int = 11008 + num_attention_heads: int = 32 + rms_norm_eps: float = 1e-6 + vocab_size: int = 32000 + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + + dim = config.hidden_size + self.n_heads = n_heads = config.num_attention_heads + self.n_kv_heads = n_kv_heads = config.num_key_value_heads + + self.repeats = n_heads // n_kv_heads + + head_dim = config.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + rope_scale = ( + 1 / config.rope_scaling["factor"] + if config.rope_scaling is not None + and config.rope_scaling["type"] == "linear" + else 1 + ) + self.rope = nn.RoPE( + head_dim, + traditional=config.rope_traditional, + base=config.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if self.repeats > 1: + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), 
axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.self_attn = Attention(config) + self.mlp = MLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.config = config + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class Llama(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [ + TransformerBlock(config=config) for _ in range(config.num_hidden_layers) + ] + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + # for passing merged input embeddings + if inputs_embeds is None: + h = self.embed_tokens(inputs) + else: + h = inputs_embeds + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class LanguageModel(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.model_type = config.model_type + if self.model_type != "llama": + raise ValueError( + f"Model type {self.model_type} not supported. Currently only 'llama' is supported" + ) + self.model = Llama(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + out, cache = self.model(inputs, cache, inputs_embeds) + return self.lm_head(out), cache + + @staticmethod + def sanitize(weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } diff --git a/llava/llava.py b/llava/llava.py new file mode 100644 index 000000000..06e560590 --- /dev/null +++ b/llava/llava.py @@ -0,0 +1,179 @@ +# Copyright © 2024 Apple Inc. 
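+#
+# LLaVA model: a CLIP vision tower encodes the image, a two-layer MLP projects
+# the selected vision hidden states into the text embedding space, and the
+# projected image features are spliced into the token embeddings at the
+# positions of the image token (image_token_index) before running the
+# language model.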
+ +import glob +import inspect +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from huggingface_hub import snapshot_download +from language import LanguageModel, TextConfig +from vision import VisionConfig, VisionModel + + +@dataclass +class LlaVAConfig: + text_config: TextConfig + vision_config: VisionConfig + ignore_index: int = -100 + image_token_index: int = 32000 + vision_feature_select_strategy: str = "default" + vision_feature_layer: int = -2 + vocab_size: int = 32000 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlaVAConfig): + super().__init__() + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=True + ) + self.gelu = nn.GELU() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=True + ) + + def __call__(self, x: mx.array) -> mx.array: + x = self.linear_1(x) + x = self.gelu(x) + x = self.linear_2(x) + return x + + +class LlavaModel(nn.Module): + def __init__(self, config: LlaVAConfig): + self.config = config + self.vision_tower = VisionModel(config.vision_config) + self.language_model = LanguageModel(config.text_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy + + def get_input_embeddings( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + ): + if pixel_values is None: + return self.language_model(input_ids) + + # Get the input embeddings from the language model + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + # Get the ouptut hidden states from the vision model + *_, hidden_states = self.vision_tower( + pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + ) + + # Select the hidden states from the desired layer + selected_image_feature = hidden_states[self.vision_feature_layer] + + if self.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + "Unexpected feature selection strategy: " + f"{self.vision_feature_select_strategy}" + ) + + # Pass image features through the multi-modal projector + image_features = self.multi_modal_projector(selected_image_feature) + + # Insert special image tokens in the input_ids + final_inputs_embeds = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids + ) + return final_inputs_embeds + + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids + ): + image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + + # Positions of tokens in input_ids, assuming batch size is 1 + image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + + if len(image_positions) != num_images: + raise ValueError( + f"The number of image tokens ({len(image_positions)}) does not " + f" match the number of image inputs ({num_images})." 
+ ) + + text_segments = [] + start_idx = 0 + + for position in image_positions: + text_segments.append(inputs_embeds[:, start_idx:position]) + start_idx = position + 1 + + image_embeddings = mx.split(image_features, image_features.shape[0]) + final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] + final_embeddings += [inputs_embeds[:, start_idx:]] + + # Create a final embedding of shape + # (1, num_image_patches*num_images + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=1) + + def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): + input_embddings = self.get_input_embeddings(input_ids, pixel_values) + logits, cache = self.language_model( + input_ids, cache=cache, inputs_embeds=input_embddings + ) + return logits, cache + + @staticmethod + def from_pretrained(path_or_hf_repo: str): + path = Path(path_or_hf_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], + ) + ) + + with open(path / "config.json", "r") as f: + model_config = json.load(f) + + model_config = LlaVAConfig.from_dict(model_config) + + model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) + model_config.text_config = TextConfig.from_dict(model_config.text_config) + + model = LlavaModel(model_config) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors found in {path}") + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + weights = VisionModel.sanitize(weights) + weights = LanguageModel.sanitize(weights) + + model.load_weights(list(weights.items())) + return model diff --git a/llava/requirements.txt b/llava/requirements.txt new file mode 100644 index 000000000..a11d91482 --- /dev/null +++ b/llava/requirements.txt @@ -0,0 +1,6 @@ +mlx>=0.5.0 +numpy +transformers +torch +huggingface_hub +Pillow diff --git a/llava/test.py b/llava/test.py new file mode 100644 index 000000000..3cb2863c2 --- /dev/null +++ b/llava/test.py @@ -0,0 +1,162 @@ +# Copyright © 2024 Apple Inc. 
+
+import unittest
+
+import mlx.core as mx
+import requests
+import torch
+from PIL import Image
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+from llava import LlavaModel
+
+MODEL_PATH = "llava-hf/llava-1.5-7b-hf"
+PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
+IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
+
+
+def load_mlx_models(path):
+    model = LlavaModel.from_pretrained(path)
+    model.eval()
+    return model
+
+
+def load_hf_models(path):
+    model = LlavaForConditionalGeneration.from_pretrained(path)
+    model.eval()
+    return model
+
+
+class TestVisionTower(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.mx_llava = load_mlx_models(MODEL_PATH)
+        cls.hf_llava = load_hf_models(MODEL_PATH)
+        cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
+
+    def test_image_features(self):
+        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
+        vision_feature_layer = -2
+        with torch.no_grad():
+            pixel_values = self.proc(PROMPT, raw_image, return_tensors="pt")[
+                "pixel_values"
+            ]
+
+            hf_pixel_values = pixel_values
+            mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)
+
+            _, _, hidden_states = self.mx_llava.vision_tower(
+                mx_pixel_values,
+                output_hidden_states=True,
+            )
+
+            mx_selected_image_feature = hidden_states[vision_feature_layer]
+            mx_image_features = self.mx_llava.multi_modal_projector(
+                mx_selected_image_feature
+            )
+
+            hf_image_outputs = self.hf_llava.vision_tower(
+                hf_pixel_values, output_hidden_states=True
+            )
+            hf_selected_image_feature = hf_image_outputs.hidden_states[
+                vision_feature_layer
+            ]
+            hf_image_features = self.hf_llava.multi_modal_projector(
+                hf_selected_image_feature
+            )
+
+            self.assertTrue(
+                mx.allclose(
+                    mx_image_features,
+                    mx.array(hf_image_features.numpy()),
+                    atol=1e-2,
+                )
+            )
+
+
+class TestLlava(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.mx_llava = load_mlx_models(MODEL_PATH)
+        cls.hf_llava = load_hf_models(MODEL_PATH)
+        cls.proc = AutoProcessor.from_pretrained(MODEL_PATH)
+
+    def test_merge_input_ids_with_image_features(self):
+        raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw)
+        vision_feature_layer = -2
+        with torch.no_grad():
+            values = self.proc(PROMPT, raw_image, return_tensors="pt")
+            pixel_values = values["pixel_values"]
+            input_ids = values["input_ids"]
+
+            hf_pixel_values = pixel_values
+            mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1)
+
+            _, _, hidden_states = self.mx_llava.vision_tower(
+                mx_pixel_values,
+                output_hidden_states=True,
+            )
+            mx_input_ids = mx.array(input_ids.numpy())
+            mx_selected_image_feature = hidden_states[vision_feature_layer]
+            mx_image_features = self.mx_llava.multi_modal_projector(
+                mx_selected_image_feature
+            )
+            mx_inputs_embeds = self.mx_llava.language_model.model.embed_tokens(
+                mx_input_ids
+            )
+            mx_final_embedding = self.mx_llava._merge_input_ids_with_image_features(
+                mx_image_features, mx_inputs_embeds, mx_input_ids
+            )
+
+            hf_image_outputs = self.hf_llava.vision_tower(
+                hf_pixel_values, output_hidden_states=True
+            )
+            hf_selected_image_feature = hf_image_outputs.hidden_states[
+                vision_feature_layer
+            ]
+            hf_image_features = self.hf_llava.multi_modal_projector(
+                hf_selected_image_feature
+            )
+            hf_inputs_embeds = self.hf_llava.get_input_embeddings()(input_ids)
+            hf_final_embedding, _, _, _ = (
+                self.hf_llava._merge_input_ids_with_image_features(
+                    hf_image_features,
+                    hf_inputs_embeds,
+                    input_ids,
+                    attention_mask=input_ids,
+                    labels=torch.ones_like(input_ids),
) + ) + + self.assertTrue( + mx.allclose( + mx_final_embedding, + mx.array(hf_final_embedding.numpy()), + atol=1e-1, + ) + ) + + def test_generated_tokens(self): + raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw) + with torch.no_grad(): + hf_inputs = self.proc(PROMPT, raw_image, return_tensors="pt") + hf_outputs = self.hf_llava(**hf_inputs) + hf_logits = hf_outputs.logits + + mx_inputs = self.proc(PROMPT, raw_image, return_tensors="np") + pixel_values = mx.array(mx_inputs["pixel_values"]) + input_ids = mx.array(mx_inputs["input_ids"]) + + mx_logits, _ = self.mx_llava(input_ids, pixel_values) + + self.assertTrue( + mx.allclose( + mx_logits[:, -1, :].argmax(axis=-1), + mx.array(hf_logits.numpy())[:, -1, :].argmax(axis=-1), + atol=1e-2, + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/llava/vision.py b/llava/vision.py new file mode 100644 index 000000000..66287dee6 --- /dev/null +++ b/llava/vision.py @@ -0,0 +1,223 @@ +# Copyright © 2024 Apple Inc. + +import inspect +import math +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class VisionConfig: + model_type: str + num_hidden_layers: int = 24 + hidden_size: int = 1024 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + layer_norm_eps: float = 1e-5 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class Attention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError( + "The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0" + ) + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads + self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) + self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) + self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat) + + +class MLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.activation_fn = nn.GELU(approx="fast") + self.fc1 = 
nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def __call__(self, x: mx.array) -> mx.array: + x = self.activation_fn(self.fc1(x)) + x = self.fc2(x) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Attention( + config.hidden_size, config.num_attention_heads, bias=True + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + y = self.layer_norm1(x) + y = self.self_attn(y, y, y, mask) + x = x + y + y = self.layer_norm2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] + + +class VisionEmbeddings(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = mx.zeros((config.hidden_size,)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def __call__(self, x: mx.array) -> mx.array: + batch_size = x.shape[0] + patch_embeddings = self.patch_embedding(x) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) + embed_dim = patch_embeddings.shape[-1] + cls_embeddings = mx.broadcast_to( + self.class_embedding, (batch_size, 1, embed_dim) + ) + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + embeddings += self.position_embedding.weight + return embeddings + + +class ClipVisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embeddings = VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + self.encoder = Encoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + x = self.embeddings(x) + x = self.pre_layrnorm(x) + + encoder_states = (x,) if output_hidden_states else None + + for l in self.encoder.layers: + x = l(x, mask=None) + if output_hidden_states: + encoder_states = encoder_states + (x,) + + pooler_output = self.post_layernorm(x[:, 0, :]) + return pooler_output, x, encoder_states + + +class VisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + + self.model_type = config.model_type + if self.model_type != "clip_vision_model": + raise ValueError(f"Unsupported model type: {self.model_type}") + + self.vision_model = ClipVisionModel(config) + + def __call__( + self, x: mx.array, output_hidden_states: Optional[bool] = None + ) -> mx.array: + return self.vision_model(x, output_hidden_states) + + @staticmethod + def sanitize(weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + # Remove unused position_ids + continue + elif 
"patch_embedding.weight" in k: + # PyTorch conv2d weight tensors have shape: + # [out_channels, in_channels, kH, KW] + # MLX conv2d expects the weight be of shape: + # [out_channels, kH, KW, in_channels] + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights