From 4d964bda25b787143782be3c84f59f7fb4fb2234 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 19 Feb 2024 12:03:10 -0500 Subject: [PATCH 01/34] add: llava mlx first draft --- llava/Local LLava.ipynb | 188 +++++++++++++++++++ llava/README.md | 1 + llava/clip.py | 292 +++++++++++++++++++++++++++++ llava/convert.py | 87 +++++++++ llava/llama.py | 404 ++++++++++++++++++++++++++++++++++++++++ llava/llava.py | 88 +++++++++ llava/requirements.txt | 6 + llava/utils.py | 87 +++++++++ 8 files changed, 1153 insertions(+) create mode 100644 llava/Local LLava.ipynb create mode 100644 llava/README.md create mode 100644 llava/clip.py create mode 100644 llava/convert.py create mode 100644 llava/llama.py create mode 100644 llava/llava.py create mode 100644 llava/requirements.txt create mode 100644 llava/utils.py diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb new file mode 100644 index 000000000..bc8252628 --- /dev/null +++ b/llava/Local LLava.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Goal: Download and convert the weights of LlaVA into MLX, and test the forward pass of this model on example data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "from pathlib import Path\n", + "import os\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "mlx_path = Path('mlx_model')\n", + "\n", + "if not os.path.exists(mlx_path):\n", + " os.makedirs(mlx_path)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 207126.12it/s]\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + } + ], + "source": [ + "import mlx.core as mx\n", + "from convert import get_model_path, fetch_from_hub, hf_repo\n", + "\n", + "\n", + "model_path = get_model_path(hf_repo)\n", + "model_config, model_weights, model_weight_files, config, tokenizer = fetch_from_hub(model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] Converting\n", + "[INFO] Saving\n" + ] + } + ], + "source": [ + "from utils import map_weights, should_keep_weight\n", + "\n", + "\n", + "print(\"[INFO] Converting\")\n", + "mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())\n", + "mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}\n", + "print(\"[INFO] Saving\")\n", + "mx.savez(str(mlx_path / \"weights.npz\"), **mlx_weights)\n", + "for fn in [\"config.json\", \"merges.txt\", \"vocab.json\", \"preprocessor_config.json\"]:\n", + " if fn in os.listdir(model_path):\n", + " shutil.copyfile(\n", + " str(model_path / f\"{fn}\"),\n", + " str(mlx_path / f\"{fn}\"),\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from llava import LlaVAConfig, LLMConfig, VisionConfig, ProjectionConfig, LlavaModel\n", + "\n", + "llava_mlx_config = LlaVAConfig(\n", + " llm_config=LLMConfig(\n", + " dim=4096,\n", + " n_layers=32,\n", + " head_dim=4096,\n", + " hidden_dim=11008,\n", + " norm_eps=1e-5,\n", + " n_heads=32, # TODO: should be 32 
https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L14. But only works with 1. Please see llama file for how heads are split. Is this wrong?\n", + " n_kv_heads=32, # TODO: should be 32 https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L16\n", + " vocab_size=32064,\n", + " rope_theta=0,\n", + " rope_traditional=False\n", + " ),\n", + " vision_config=VisionConfig(\n", + " num_hidden_layers=24,\n", + " hidden_size=1024,\n", + " intermediate_size=4096,\n", + " num_attention_heads=16,\n", + " num_channels=3,\n", + " image_size=336,\n", + " patch_size=14\n", + " ),\n", + " projection_config=ProjectionConfig(\n", + " in_features=1024,\n", + " out_features=4096\n", + " )\n", + ")\n", + "\n", + "\n", + "model = LlavaModel(llava_mlx_config)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "Expected shape (131072, 4096) but received shape (4096, 4096) for parameter language_model.layers.0.attention.wq.weight", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_weights\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmlx_model/weights.npz\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/envs/mlx/lib/python3.10/site-packages/mlx/nn/layers/base.py:176\u001b[0m, in \u001b[0;36mModule.load_weights\u001b[0;34m(self, file_or_weights, strict)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 172\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpected mx.array but received \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(v_new)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for parameter \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 174\u001b[0m )\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m v_new\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m v\u001b[38;5;241m.\u001b[39mshape:\n\u001b[0;32m--> 176\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 177\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpected shape \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mv\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m but received \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m shape \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mv_new\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for parameter \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 179\u001b[0m )\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate(tree_unflatten(weights))\n", + "\u001b[0;31mValueError\u001b[0m: Expected shape (131072, 4096) but received shape (4096, 4096) for 
parameter language_model.layers.0.attention.wq.weight" + ] + } + ], + "source": [ + "model.load_weights('mlx_model/weights.npz')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: load images, and test generate " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mlx", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llava/README.md b/llava/README.md new file mode 100644 index 000000000..c10506421 --- /dev/null +++ b/llava/README.md @@ -0,0 +1 @@ +# LLaVA diff --git a/llava/clip.py b/llava/clip.py new file mode 100644 index 000000000..6c46088d0 --- /dev/null +++ b/llava/clip.py @@ -0,0 +1,292 @@ +# Copyright © 2023-2024 Apple Inc. + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx.core import linalg as LA +from mlx.nn.losses import cross_entropy +from mlx.utils import tree_flatten + + +@dataclass +class CLIPVisionOutput: + pooler_output: mx.array + last_hidden_state: mx.array + + +@dataclass +class CLIPTextOutput: + pooler_output: mx.array + last_hidden_state: mx.array + + +@dataclass +class CLIPModelOutput: + loss: Optional[mx.array] + text_embeds: Optional[mx.array] + image_embeds: Optional[mx.array] + text_model_output: CLIPTextOutput + vision_model_output: CLIPVisionOutput + + +@dataclass +class CLIPTextConfig: + num_hidden_layers: int + hidden_size: int + intermediate_size: int + num_attention_heads: int + max_position_embeddings: int + vocab_size: int + + +@dataclass +class CLIPVisionConfig: + num_hidden_layers: int + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_channels: int + image_size: int + patch_size: int + + +@dataclass +class CLIPConfig: + text_config: CLIPTextConfig + vision_config: CLIPVisionConfig + projection_dim: int + + +def quick_gelu(x: mx.array) -> mx.array: + """ + A fast GELU approximation https://github.com/hendrycks/GELUs + """ + return x * mx.sigmoid(1.702 * x) + + +def clip_loss(logits: mx.array) -> mx.array: + N, M = logits.shape + caption_loss = cross_entropy(logits, mx.arange(N), reduction="mean") + image_loss = cross_entropy(logits.T, mx.arange(M), reduction="mean") + return (caption_loss + image_loss) / 2.0 + + +class CLIPEncoderLayer(nn.TransformerEncoderLayer): + """The transformer encoder layer from CLIP.""" + + def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int): + super().__init__( + dims=hidden_dim, + mlp_dims=intermediate_dim, + num_heads=num_heads, + activation=quick_gelu, + norm_first=True, + ) + # Add biases to the attention projections + self.attention = nn.MultiHeadAttention( + hidden_dim, num_heads, bias=True) + + +class CLIPTextModel(nn.Module): + """Implements the text encoder transformer from CLIP.""" + + def __init__(self, config: CLIPTextConfig): + super().__init__() + + self.token_embedding = nn.Embedding( + config.vocab_size, config.hidden_size) + self.position_embedding = mx.zeros( + (config.max_position_embeddings, config.hidden_size) + ) + self.layers = [ + 
CLIPEncoderLayer( + config.hidden_size, config.intermediate_size, config.num_attention_heads + ) + for _ in range(config.num_hidden_layers) + ] + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + + def _embed(self, x: mx.array) -> mx.array: + embeddings = self.token_embedding(x) + embeddings += self.position_embedding[: x.shape[1]] + return embeddings + + def __call__(self, x: mx.array) -> CLIPTextOutput: + B, N = x.shape + eot_tokens = mx.argmax(x, axis=-1) + x = self._embed(x) + mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype) + for l in self.layers: + x = l(x, mask) + last_hidden_state = self.final_layer_norm(x) + pooler_output = last_hidden_state[mx.arange(B), eot_tokens] + + return CLIPTextOutput( + pooler_output=pooler_output, last_hidden_state=last_hidden_state + ) + + +class CLIPVisionModel(nn.Module): + """Implements the vision encoder transformer from CLIP.""" + + def __init__(self, config: CLIPVisionConfig): + super().__init__() + + self.class_embedding = mx.zeros((config.hidden_size,)) + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + num_patches = (config.image_size // config.patch_size) ** 2 + num_positions = num_patches + 1 + self.position_embedding = mx.zeros((num_positions, config.hidden_size)) + self.pre_layernorm = nn.LayerNorm(config.hidden_size) + self.layers = [ + CLIPEncoderLayer( + config.hidden_size, config.intermediate_size, config.num_attention_heads + ) + for _ in range(config.num_hidden_layers) + ] + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + def _embed(self, x: mx.array) -> mx.array: + batch_size = x.shape[0] + # Patchify using conv: + # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim] + patch_embeddings = self.patch_embedding(x) + # [batch_size, num_patches, embed_dim] + patch_embeddings = mx.flatten( + patch_embeddings, start_axis=1, end_axis=2) + embed_dim = patch_embeddings.shape[-1] + # Prepend embeddings + # [batch_size, 1, embed_dim] + cls_embeddings = mx.broadcast_to( + self.class_embedding, (batch_size, 1, embed_dim) + ) + # [batch_size, num_patches + 1, embed_dim] + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + # Add positional encoding + embeddings += self.position_embedding + return embeddings + + def __call__(self, x: mx.array) -> CLIPVisionOutput: + x = self._embed(x) + x = self.pre_layernorm(x) + + for l in self.layers: + x = l(x, mask=None) + + # Extract token embedding + pooler_output = self.post_layernorm(x[:, 0, :]) + return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x) + + +class CLIPModel(nn.Module): + def __init__(self, config: CLIPConfig): + self.text_model = CLIPTextModel(config.text_config) + self.vision_model = CLIPVisionModel(config.vision_config) + + text_embed_dim = config.text_config.hidden_size + vision_embed_dim = config.vision_config.hidden_size + projection_dim = config.projection_dim + + self.visual_projection = nn.Linear( + vision_embed_dim, projection_dim, bias=False) + self.text_projection = nn.Linear( + text_embed_dim, projection_dim, bias=False) + self.logit_scale = mx.array(0.0) + + def get_text_features(self, x: mx.array) -> mx.array: + return self.text_projection(self.text_model(x).pooler_output) + + def get_image_features(self, x: mx.array) -> mx.array: + return self.visual_projection(self.vision_model(x).pooler_output) + + def __call__( + self, + input_ids: Optional[mx.array] 
= None, + pixel_values: Optional[mx.array] = None, + return_loss=False, + ) -> CLIPModelOutput: + if input_ids is not None: + text_model_output = self.text_model(input_ids) + text_embeds = self.text_projection(text_model_output.pooler_output) + text_embeds = text_embeds / \ + LA.norm(text_embeds, axis=-1, keepdims=True) + else: + text_embeds = None + text_model_output = None + + if pixel_values is not None: + vision_model_output = self.vision_model(pixel_values) + image_embeds = self.visual_projection( + vision_model_output.pooler_output) + image_embeds = image_embeds / \ + LA.norm(image_embeds, axis=-1, keepdims=True) + else: + image_embeds = None + vision_model_output = None + + if return_loss and (input_ids is None or pixel_values is None): + raise ValueError( + "Must provide text and image inputs to compute loss.") + + if return_loss: + logit_scale = mx.exp(self.logit_scale) + logits = (text_embeds @ image_embeds.T) * logit_scale + loss = clip_loss(logits) + else: + loss = None + + return CLIPModelOutput( + loss=loss, + text_embeds=text_embeds, + image_embeds=image_embeds, + vision_model_output=vision_model_output, + text_model_output=text_model_output, + ) + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + + with open(path / "config.json", "r") as fid: + config = json.load(fid) + + text_config = config["text_config"] + text_config = CLIPTextConfig( + num_hidden_layers=text_config["num_hidden_layers"], + hidden_size=text_config["hidden_size"], + intermediate_size=text_config["intermediate_size"], + num_attention_heads=text_config["num_attention_heads"], + max_position_embeddings=text_config["max_position_embeddings"], + vocab_size=text_config["vocab_size"], + ) + + vision_config = config["vision_config"] + + vision_config = CLIPVisionConfig( + num_hidden_layers=vision_config["num_hidden_layers"], + hidden_size=vision_config["hidden_size"], + intermediate_size=vision_config["intermediate_size"], + num_attention_heads=vision_config["num_attention_heads"], + num_channels=3, + image_size=vision_config["image_size"], + patch_size=vision_config["patch_size"], + ) + + config = CLIPConfig( + text_config=text_config, + vision_config=vision_config, + projection_dim=config["projection_dim"], + ) + model = CLIPModel(config) + model.load_weights(str(path / "weights.npz")) + return model diff --git a/llava/convert.py b/llava/convert.py new file mode 100644 index 000000000..af9973850 --- /dev/null +++ b/llava/convert.py @@ -0,0 +1,87 @@ + +from safetensors.torch import load_file +from pathlib import Path +import glob +import json +import logging +import mlx.nn as nn +from huggingface_hub import snapshot_download +from typing import Dict, Tuple +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer + + +hf_repo = "llava-hf/llava-1.5-7b-hf" + + +def get_model_path(path_or_hf_repo: str) -> Path: + """ + Ensures the model is available locally. If the path does not exist locally, + it is downloaded from the Hugging Face Hub. + + Args: + path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. + + Returns: + Path: The path to the model. + """ + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], + ) + ) + return model_path + + +def load_model(model_path: Path) -> nn.Module: + """ + Load and initialize the model from a given path. 
+ + Args: + model_path (Path): The path to load the model from. + + Returns: + nn.Module: The loaded and initialized model. + + Raises: + FileNotFoundError: If the weight files (.safetensors) are not found. + ValueError: If the model class or args class are not found or cannot be instantiated. + """ + try: + with open(model_path / "config.json", "r") as f: + config = json.load(f) + except FileNotFoundError: + logging.error(f"Config file not found in {model_path}") + raise + + weight_files = glob.glob(str(model_path / "*.safetensors")) + if not weight_files: + logging.error(f"No safetensors found in {model_path}") + raise FileNotFoundError(f"No safetensors found in {model_path}") + + weights = {} + for wf in weight_files: + weights.update(load_file(wf)) + + return config, weights, weight_files + + +def fetch_from_hub( + model_path: Path, +) -> Tuple[Dict, dict, PreTrainedTokenizer]: + model_config, model_weights, model_weight_files = load_model(model_path) + + config = AutoConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained( + model_path) # TODO: should this be the processor? + + # TODO: replace outputs with the model alone once conversion is complete + return model_config, model_weights, model_weight_files, config, tokenizer diff --git a/llava/llama.py b/llava/llama.py new file mode 100644 index 000000000..7b0052355 --- /dev/null +++ b/llava/llama.py @@ -0,0 +1,404 @@ +# Copyright © 2023 Apple Inc. + +import argparse +import glob +import json +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from sentencepiece import SentencePieceProcessor + + +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + rope_theta: float + rope_traditional: bool = True + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * + args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * + args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.rope = nn.RoPE( + args.head_dim, traditional=args.rope_traditional, base=args.rope_theta + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + B, L, D = x.shape + + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape( + B, L, self.n_kv_heads, 
-1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), + axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.wo(output), (keys, values) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = FeedForward(args=args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.attention(self.attention_norm(x), mask, cache) + h = x + r + r = self.feed_forward(self.ffn_norm(h)) + out = h + r + return out, cache + + +class Llama(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + self.layers = [TransformerBlock(args=args) + for _ in range(args.n_layers)] + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + + def __call__(self, x): + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(self.tok_embeddings.weight.dtype) + + x = self.tok_embeddings(x) + for l in self.layers: + x, _ = l(x, mask) + x = self.norm(x) + return self.output(x) + + def generate(self, x, temp=1.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + cache = [] + + # Make an additive causal mask. We will need that to process the prompt. + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(self.tok_embeddings.weight.dtype) + + # First we process the prompt x the same was as in __call__ but + # save the caches in cache + x = self.tok_embeddings(x) + for l in self.layers: + x, c = l(x, mask=mask) + # We store the per layer cache in a simple python list + cache.append(c) + x = self.norm(x) + # We only care about the last logits that generate the next token + y = self.output(x[:, -1]) + y = sample(y) + + # y now has size [1] + # Since MLX is lazily evaluated nothing is computed yet. 
+ # Calling y.item() would force the computation to happen at + # this point but we can also choose not to do that and let the + # user choose when to start the computation. + yield y + + # Now we parsed the prompt and generated the first token we + # need to feed it back into the model and loop to generate the + # rest. + while True: + # Unsqueezing the last dimension to add a sequence length + # dimension of 1 + x = y[:, None] + + x = self.tok_embeddings(x) + for i in range(len(cache)): + # We are overwriting the arrays in the cache list. When + # the computation will happen, MLX will be discarding the + # old cache the moment it is not needed anymore. + x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) + x = self.norm(x) + y = sample(self.output(x[:, -1])) + + yield y + + +def tic(): + return time.time() + + +def toc(msg, start): + end = time.time() + return f"[INFO] {msg}: {end - start:.3f} s" + + +def generate(args): + input("Press enter to start generation") + print("------") + print(args.prompt) + x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) + skip = 0 + prompt_processing = None + tokens = [] + start = tic() + for token in model.generate(x, args.temp): + tokens.append(token) + + if len(tokens) == 1: + # Actually perform the computation to measure the prompt processing time + mx.eval(token) + prompt_processing = toc("Prompt processing", start) + + if len(tokens) >= args.max_tokens: + break + + elif (len(tokens) % args.write_every) == 0: + # It is perfectly ok to eval things we have already eval-ed. + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s[skip:], end="", flush=True) + skip = len(s) + + mx.eval(tokens) + full_gen = toc("Full generation", start) + s = tokenizer.decode([t.item() for t in tokens]) + print(s[skip:], flush=True) + print("------") + print(prompt_processing) + print(full_gen) + + +def few_shot_generate(args): + def possible_end(s): + word = "[Instruction]" + for i in range(len(word) - 1, 0, -1): + if s[-i:] == word[:i]: + return 0 + if s[-len(word):] == word: + return 1 + return -1 + + def generate(question): + x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)]) + skip = 0 + prompt_processing = None + tokens = [] + start = tic() + for token in model.generate(x, args.temp): + tokens.append(token) + + if len(tokens) == 1: + # Actually perform the computation to measure the prompt processing time + mx.eval(token) + prompt_processing = toc("Prompt processing", start) + + if len(tokens) >= args.max_tokens: + break + + mx.eval(tokens) + token_list = [t.item() for t in tokens] + s = tokenizer.decode(token_list) + + end = possible_end(s) + if end == 0: + continue + if end == 1: + skip = len(s) + break + + print(s[skip:], end="", flush=True) + skip = len(s) + if token_list[-1] == tokenizer.eos_id(): + break + + mx.eval(tokens) + full_gen = toc("Full generation", start) + s = tokenizer.decode([t.item() for t in tokens]) + print(s[skip:], end="", flush=True) + + print("[INFO] Loading few-shot examples from: {}".format(args.few_shot)) + prompt = open(args.few_shot).read().strip() + while True: + question = input("Ask a question: ") + generate(prompt.replace("{}", question)) + print() + + +def sanitize_config(config, weights): + config.pop("model_type", None) + n_heads = config["n_heads"] + if "n_kv_heads" not in config: + config["n_kv_heads"] = n_heads + if "head_dim" not in config: + config["head_dim"] = config["dim"] // n_heads + if "hidden_dim" not in config: + config["hidden_dim"] = 
weights["layers.0.feed_forward.w1.weight"].shape[0] + if config.get("vocab_size", -1) < 0: + config["vocab_size"] = weights["output.weight"].shape[-1] + if "rope_theta" not in config: + config["rope_theta"] = 10000 + unused = ["multiple_of", "ffn_dim_multiplier"] + for k in unused: + config.pop(k, None) + return config + + +def load_model(model_path): + model_path = Path(model_path) + + unsharded_weights_path = Path(model_path / "weights.npz") + if unsharded_weights_path.is_file(): + print("[INFO] Loading model from {}.".format(unsharded_weights_path)) + weights = mx.load(str(unsharded_weights_path)) + else: + sharded_weights_glob = str(model_path / "weights.*.npz") + weight_files = glob.glob(sharded_weights_glob) + print("[INFO] Loading model from {}.".format(sharded_weights_glob)) + + if len(weight_files) == 0: + raise FileNotFoundError( + "No weights found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + with open(model_path / "config.json", "r") as f: + config = sanitize_config(json.loads(f.read()), weights) + quantization = config.pop("quantization", None) + model = Llama(ModelArgs(**config)) + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) + model.update(tree_unflatten(list(weights.items()))) + tokenizer = SentencePieceProcessor( + model_file=str(model_path / "tokenizer.model")) + return model, tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Llama inference script") + parser.add_argument( + "--model-path", + help="Path to the model weights and tokenizer", + default="mlx_model", + ) + parser.add_argument( + "--prompt", + help="The message to be processed by the model. Ignored when --few-shot is provided.", + default="In the beginning the Universe was created.", + ) + parser.add_argument( + "--few-shot", + help="Read a few shot prompt from a file (as in `sample_prompt.txt`).", + ) + parser.add_argument( + "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate" + ) + parser.add_argument( + "--write-every", type=int, default=1, help="After how many tokens to detokenize" + ) + parser.add_argument( + "--temp", type=float, default=0.0, help="The sampling temperature" + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load_model(args.model_path) + if args.few_shot: + few_shot_generate(args) + else: + generate(args) diff --git a/llava/llava.py b/llava/llava.py new file mode 100644 index 000000000..c96280ac2 --- /dev/null +++ b/llava/llava.py @@ -0,0 +1,88 @@ +from clip import CLIPVisionModel +from llama import Llama +from pathlib import Path +import json +import mlx.nn as nn +import mlx.core as mx +from typing import Any, Optional + + +from dataclasses import dataclass + + +@dataclass +class VisionConfig: + num_hidden_layers: int + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_channels: int + image_size: int + patch_size: int + + +@dataclass +class LLMConfig: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + rope_theta: float + rope_traditional: bool = True + + +@dataclass +class ProjectionConfig: + in_features: int + out_features: int + + +@dataclass +class LlaVAConfig: + llm_config: Any + vision_config: VisionConfig + projection_config: ProjectionConfig + + +class LlavaMultiModalProjector(nn.Module): + 
def __init__(self, config: Any):
+        super().__init__()
+        self.linear_1 = nn.Linear(config.in_features, config.out_features)
+        self.gelu = nn.GELU()
+        self.linear_2 = nn.Linear(config.out_features, config.out_features)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.linear_1(x)
+        x = self.gelu(x)
+        x = self.linear_2(x)
+        return x
+
+
+class LlavaModel(nn.Module):
+    def __init__(self, config: LlaVAConfig):
+        self.vision_tower = CLIPVisionModel(config=config.vision_config)
+        self.language_model = Llama(args=config.llm_config)
+        self.multi_modal_projector = LlavaMultiModalProjector(
+            config=config.projection_config)
+
+    def __call__(self,
+                 input_ids: Optional[mx.array] = None,
+                 pixel_values: Optional[mx.array] = None):
+        # TODO: add the forward pass
+        pass
+
+    @staticmethod
+    def from_pretrained(path: str):
+        path = Path(path)
+
+        with open(path / "config.json", "r") as f:
+            config = json.load(f)
+
+        model = LlavaModel(config)
+        model.load_weights(str(path / "weights.npz"))
+
+        return model
diff --git a/llava/requirements.txt b/llava/requirements.txt
new file mode 100644
index 000000000..a6f2f1ca7
--- /dev/null
+++ b/llava/requirements.txt
@@ -0,0 +1,6 @@
+mlx
+numpy
+transformers
+torch
+huggingface_hub
+Pillow
\ No newline at end of file
diff --git a/llava/utils.py b/llava/utils.py
new file mode 100644
index 000000000..27d8f6f58
--- /dev/null
+++ b/llava/utils.py
@@ -0,0 +1,87 @@
+import mlx.core as mx
+import torch
+from typing import Tuple
+
+
+def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
+    # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
+    a = a.to(torch.float32) if dtype == "bfloat16" else a.to(
+        getattr(torch, dtype))
+    return mx.array(a.numpy(), getattr(mx, dtype))
+
+
+def should_keep_weight(key: str):
+    return not ("position_ids" in key)
+
+
+def map_vision_tower_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]:
+    key = key.replace("embeddings.", "")
+    key = key.replace("encoder.", "")
+    key = key.replace("position_embedding.weight", "position_embedding")
+
+    key = key.replace('vision_model.', '')
+
+    # Map attention layers
+    if "self_attn." in key:
+        key = key.replace("self_attn.", "attention.")
+    if "q_proj." in key:
+        key = key.replace("q_proj.", "query_proj.")
+    if "k_proj." in key:
+        key = key.replace("k_proj.", "key_proj.")
+    if "v_proj." in key:
+        key = key.replace("v_proj.", "value_proj.")
+    if "layer_norm1." in key:
+        key = key.replace("layer_norm1.", "ln1.")
+    if "layer_norm2." in key:
+        key = key.replace("layer_norm2.", "ln2.")
+    # Map ffn layers
+    if "mlp.fc1" in key:
+        key = key.replace("mlp.fc1", "linear1")
+    if "mlp.fc2" in key:
+        key = key.replace("mlp.fc2", "linear2")
+    # Fix layernorm typo
+    if "pre_layrnorm" in key:
+        # Fix typo in weights :)
+        key = key.replace("pre_layrnorm", "pre_layernorm")
+    if "patch_embedding.weight" in key:
+        # Initially, value: [out_channels, in_channels, kH, KW].
+ # We want [out_channels, kH, KW, in_channels] + value = value.permute(0, 2, 3, 1) + return (key, value) + + +def map_language_model_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: + key = key.replace('language_model.model.', 'language_model.') + key = key.replace('mlp.', 'feed_forward.') + key = key.replace("down_proj", "w2") + key = key.replace("up_proj", "w3") + key = key.replace("gate_proj", "w1") + key = key.replace("input_layernorm", "attention_norm") + key = key.replace("post_attention_layernorm", "ffn_norm") + key = key.replace("lm_head", "output") + + key = key.replace("embed_tokens", "tok_embeddings") + key = key.replace("self_attn", "attention") + + key = key.replace("q_proj", "wq") + key = key.replace("k_proj", "wk") + key = key.replace("v_proj", "wv") + key = key.replace("o_proj", "wo") + + return (key, value) + + +def map_multi_modal_projector_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: + return (key, value) + + +def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]: + + if 'vision_tower' in key: + key, value = map_vision_tower_weights(key, value) + elif 'language_model' in key: + key, value = map_language_model_weights(key, value) + elif 'multi_modal_projector' in key: + key, value = map_multi_modal_projector_weights(key, value) + + return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", ""))) From 0e2a05449ed08cc68d56647ffb80d69e23f119e7 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 19 Feb 2024 12:16:45 -0500 Subject: [PATCH 02/34] add: weights comparision --- llava/Local LLava.ipynb | 118 ++++++++++++++++++++++++++++++++-------- 1 file changed, 94 insertions(+), 24 deletions(-) diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb index bc8252628..3554bbb8a 100644 --- a/llava/Local LLava.ipynb +++ b/llava/Local LLava.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -32,14 +32,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 207126.12it/s]\n", + "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 206277.25it/s]\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } @@ -55,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -99,8 +99,8 @@ " head_dim=4096,\n", " hidden_dim=11008,\n", " norm_eps=1e-5,\n", - " n_heads=32, # TODO: should be 32 https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L14. But only works with 1. Please see llama file for how heads are split. Is this wrong?\n", - " n_kv_heads=32, # TODO: should be 32 https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L16\n", + " n_heads=1, # TODO: should be 32 https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L14. But only works with 1. Please see llama file for how heads are split. 
Is this wrong?\n", + " n_kv_heads=1, # TODO: should be 32 https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L16\n", " vocab_size=32064,\n", " rope_theta=0,\n", " rope_traditional=False\n", @@ -121,39 +121,107 @@ ")\n", "\n", "\n", - "model = LlavaModel(llava_mlx_config)\n", + "mlx_model = LlavaModel(llava_mlx_config)\n", "\n" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "mlx_model.load_weights('mlx_model/weights.npz')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: load images, and test generate " + ] + }, + { + "cell_type": "code", + "execution_count": 19, "metadata": {}, "outputs": [ { - "ename": "ValueError", - "evalue": "Expected shape (131072, 4096) but received shape (4096, 4096) for parameter language_model.layers.0.attention.wq.weight", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_weights\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmlx_model/weights.npz\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/anaconda3/envs/mlx/lib/python3.10/site-packages/mlx/nn/layers/base.py:176\u001b[0m, in \u001b[0;36mModule.load_weights\u001b[0;34m(self, file_or_weights, strict)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 172\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpected mx.array but received \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(v_new)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for parameter \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 174\u001b[0m )\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m v_new\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m v\u001b[38;5;241m.\u001b[39mshape:\n\u001b[0;32m--> 176\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 177\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpected shape \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mv\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m but received \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m shape \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mv_new\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for parameter \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 179\u001b[0m )\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate(tree_unflatten(weights))\n", - "\u001b[0;31mValueError\u001b[0m: Expected shape (131072, 4096) but received shape (4096, 4096) for parameter language_model.layers.0.attention.wq.weight" + "name": "stderr", + "output_type": "stream", + "text": [ 
+ "Loading checkpoint shards: 100%|██████████| 3/3 [00:23<00:00, 8.00s/it]\n" ] } ], "source": [ - "model.load_weights('mlx_model/weights.npz')\n" + "# TODO: compare with hf version's model weights as well \n", + "\n", + "# Load model directly\n", + "from transformers import AutoProcessor, AutoModelForPreTraining\n", + "\n", + "model = AutoModelForPreTraining.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.00692749, -0.0147705, -0.00254822, ..., 0.00500488, 0.00238037, -0.0027771],\n", + " [0.0155029, -0.00343323, 0.00121307, ..., -0.00964355, -0.0110474, 0.00744629],\n", + " [-0.0157471, 0.0144043, 0.000104904, ..., 0.00619507, 0.0189209, -0.00415039],\n", + " ...,\n", + " [1.54972e-06, 0.00866699, 0.000881195, ..., 0.00946045, -0.0301514, 0.0107422],\n", + " [0.0253906, 0.00994873, 0.00454712, ..., -0.0319824, -0.0148926, -0.0130005],\n", + " [-0.0108643, -0.00534058, 0.00102234, ..., 0.0164795, 0.0150146, -0.00811768]], dtype=float16)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mlx_model.language_model.layers[0].attention.wq.weight" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 32, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Parameter containing:\n", + "tensor([[-6.9275e-03, -1.4771e-02, -2.5482e-03, ..., 5.0049e-03,\n", + " 2.3804e-03, -2.7771e-03],\n", + " [ 1.5503e-02, -3.4332e-03, 1.2131e-03, ..., -9.6436e-03,\n", + " -1.1047e-02, 7.4463e-03],\n", + " [-1.5747e-02, 1.4404e-02, 1.0490e-04, ..., 6.1951e-03,\n", + " 1.8921e-02, -4.1504e-03],\n", + " ...,\n", + " [ 1.5497e-06, 8.6670e-03, 8.8120e-04, ..., 9.4604e-03,\n", + " -3.0151e-02, 1.0742e-02],\n", + " [ 2.5391e-02, 9.9487e-03, 4.5471e-03, ..., -3.1982e-02,\n", + " -1.4893e-02, -1.3000e-02],\n", + " [-1.0864e-02, -5.3406e-03, 1.0223e-03, ..., 1.6479e-02,\n", + " 1.5015e-02, -8.1177e-03]], requires_grad=True)" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# TODO: load images, and test generate " + "model.language_model.model.layers[0].self_attn.q_proj.weight" ] }, { @@ -161,7 +229,9 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# They seem to be the same!" 
+ ] } ], "metadata": { From 6e4a7eec0d0698f4069d18c9fc5dbb63ddfe818c Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 19 Feb 2024 14:31:28 -0500 Subject: [PATCH 03/34] add forward pass skeleton --- llava/Local LLava.ipynb | 56 +++++++++++++++++++++++++++++++++-------- llava/llava.py | 14 ++++++++++- 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb index 3554bbb8a..3e94828d7 100644 --- a/llava/Local LLava.ipynb +++ b/llava/Local LLava.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -32,14 +32,16 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 206277.25it/s]\n", + "/Users/noahkasmanoff/anaconda3/envs/mlx/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 181702.70it/s]\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } @@ -55,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -86,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -127,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -136,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -145,14 +147,15 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 3/3 [00:23<00:00, 8.00s/it]\n" + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00, 2.48s/it]\n" ] } ], @@ -161,7 +164,7 @@ "\n", "# Load model directly\n", "from transformers import AutoProcessor, AutoModelForPreTraining\n", - "\n", + "processor = AutoProcessor.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")\n", "model = AutoModelForPreTraining.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")" ] }, @@ -232,6 +235,37 @@ "source": [ "# They seem to be the same!" ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "from PIL import Image\n", + "\n", + "image = Image.open(requests.get(\"https://llava-vl.github.io/static/images/view.jpg\", stream=True).raw)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = ['USER: What are the things I should think aboutwhen I visit this place? 
ASSISTANT:'\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = processor(prompts, images=[image], padding=True, return_tensors=\"pt\")" + ] } ], "metadata": { diff --git a/llava/llava.py b/llava/llava.py index c96280ac2..63c70db59 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -71,8 +71,20 @@ def __init__(self, config: LlaVAConfig): def __call__(self, input_ids: Optional[mx.array] = None, - pixel_values: Optional[mx.array] = None): + pixel_values: Optional[mx.array] = None, + attention_mask: Optional[mx.array] = None,): # TODO: add the forward pass + + if pixel_values is not None and input_ids.shape[1] != 1: + image_outputs = self.vision_tower(pixel_values) + + image_features = self.multi_modal_projector( + image_outputs) + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels) + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + # TODO: https://github.com/huggingface/transformers/blob/4f09d0fd888dbf2660313f9715992822acfb99ce/src/transformers/models/llava/modeling_llava.py#L279 pass @staticmethod From ed9d376e14986e925c5181f6d3e2551b702d9edb Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Wed, 21 Feb 2024 19:21:50 -0500 Subject: [PATCH 04/34] update: now imports weights correctly --- llava/.gitignore | 1 + llava/Local LLava.ipynb | 233 +++++++++++++++++---- llava/base.py | 0 llava/llama.py | 446 ++++++++++++---------------------------- llava/llava.py | 25 +-- llava/utils.py | 32 +-- 6 files changed, 358 insertions(+), 379 deletions(-) create mode 100644 llava/.gitignore create mode 100644 llava/base.py diff --git a/llava/.gitignore b/llava/.gitignore new file mode 100644 index 000000000..857540df8 --- /dev/null +++ b/llava/.gitignore @@ -0,0 +1 @@ +**mlx_model \ No newline at end of file diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb index 3e94828d7..0622420da 100644 --- a/llava/Local LLava.ipynb +++ b/llava/Local LLava.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -41,7 +48,7 @@ "text": [ "/Users/noahkasmanoff/anaconda3/envs/mlx/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 181702.70it/s]\n", + "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 202950.19it/s]\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } @@ -57,38 +64,37 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[INFO] Converting\n", - "[INFO] Saving\n" - ] - } - ], + "outputs": [], "source": [ "from utils import map_weights, should_keep_weight\n", + "do_convert = False\n", + "if do_convert:\n", "\n", - "\n", - "print(\"[INFO] Converting\")\n", - "mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())\n", - "mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}\n", - "print(\"[INFO] Saving\")\n", - "mx.savez(str(mlx_path / \"weights.npz\"), **mlx_weights)\n", - "for fn in [\"config.json\", \"merges.txt\", \"vocab.json\", \"preprocessor_config.json\"]:\n", - " if fn in os.listdir(model_path):\n", - " shutil.copyfile(\n", - " str(model_path / f\"{fn}\"),\n", - " str(mlx_path / f\"{fn}\"),\n", - " )\n" + " print(\"[INFO] Converting\")\n", + " mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())\n", + " mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}\n", + " print(\"[INFO] Saving\")\n", + " mx.savez(str(mlx_path / \"weights.npz\"), **mlx_weights)\n", + " for fn in [\"config.json\", \"merges.txt\", \"vocab.json\", \"preprocessor_config.json\"]:\n", + " if fn in os.listdir(model_path):\n", + " shutil.copyfile(\n", + " str(model_path / f\"{fn}\"),\n", + " str(mlx_path / f\"{fn}\"),\n", + " )\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -96,17 +102,18 @@ "\n", "llava_mlx_config = LlaVAConfig(\n", " llm_config=LLMConfig(\n", - " dim=4096,\n", - " n_layers=32,\n", - " head_dim=4096,\n", - " hidden_dim=11008,\n", - " norm_eps=1e-5,\n", - " n_heads=1, # TODO: should be 32 https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L14. But only works with 1. Please see llama file for how heads are split. 
Is this wrong?\n", - " n_kv_heads=1, # TODO: should be 32 https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L16\n", + " model_type='vicuna',\n", + " hidden_size=4096,\n", + " num_hidden_layers=32,\n", + " intermediate_size=11008,\n", + " num_attention_heads=32,\n", + " rms_norm_eps=1e-5,\n", " vocab_size=32064,\n", + " num_key_value_heads=32,\n", " rope_theta=0,\n", - " rope_traditional=False\n", - " ),\n", + " rope_traditional=False,\n", + " rope_scaling=None\n", + " ),\n", " vision_config=VisionConfig(\n", " num_hidden_layers=24,\n", " hidden_size=1024,\n", @@ -129,9 +136,42 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "2.6875" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "11008 / 4096" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "Received parameters not in model: language_model.layers.20.feed_forward.w3.weight language_model.layers.18.feed_forward.w2.weight language_model.layers.20.feed_forward.w1.weight language_model.layers.0.feed_forward.w1.weight language_model.layers.23.attention.wq.weight language_model.layers.23.ffn_norm.weight language_model.layers.16.feed_forward.w1.weight language_model.layers.31.attention_norm.weight language_model.output.weight language_model.layers.27.attention_norm.weight language_model.layers.25.attention_norm.weight language_model.layers.28.attention_norm.weight language_model.layers.20.feed_forward.w2.weight language_model.layers.17.attention.wo.weight language_model.layers.1.attention.wq.weight language_model.layers.27.feed_forward.w3.weight language_model.layers.19.feed_forward.w1.weight language_model.layers.14.attention_norm.weight language_model.layers.21.feed_forward.w2.weight language_model.layers.16.attention.wo.weight language_model.layers.22.attention_norm.weight language_model.layers.4.attention.wk.weight language_model.layers.13.feed_forward.w1.weight language_model.layers.30.attention.wv.weight language_model.layers.5.feed_forward.w3.weight language_model.layers.20.attention_norm.weight language_model.layers.13.feed_forward.w2.weight language_model.layers.22.feed_forward.w2.weight language_model.layers.15.attention.wo.weight language_model.layers.26.attention.wo.weight language_model.layers.5.feed_forward.w1.weight language_model.layers.16.attention_norm.weight language_model.layers.4.attention.wq.weight language_model.layers.9.feed_forward.w1.weight language_model.layers.20.attention.wq.weight language_model.layers.9.attention.wv.weight language_model.layers.10.ffn_norm.weight language_model.layers.8.attention.wo.weight language_model.layers.3.attention.wv.weight language_model.layers.0.ffn_norm.weight language_model.layers.4.feed_forward.w3.weight language_model.layers.2.attention.wv.weight language_model.layers.7.attention.wv.weight language_model.layers.24.attention.wq.weight language_model.layers.11.feed_forward.w2.weight language_model.layers.0.attention.wo.weight language_model.layers.7.feed_forward.w3.weight language_model.layers.17.feed_forward.w1.weight language_model.layers.31.attention.wo.weight language_model.layers.26.attention.wk.weight language_model.layers.0.feed_forward.w3.weight language_model.layers.2.ffn_norm.weight language_model.layers.13.attention_norm.weight language_model.layers.19.attention.wk.weight 
language_model.layers.18.attention.wq.weight language_model.layers.10.attention.wo.weight language_model.layers.30.attention.wq.weight language_model.layers.5.feed_forward.w2.weight language_model.layers.5.attention.wv.weight language_model.layers.25.attention.wq.weight language_model.layers.3.feed_forward.w3.weight language_model.layers.9.attention.wo.weight language_model.layers.29.feed_forward.w1.weight language_model.layers.2.feed_forward.w3.weight language_model.layers.0.attention.wk.weight language_model.layers.11.attention.wv.weight language_model.layers.20.attention.wo.weight language_model.layers.16.ffn_norm.weight language_model.layers.2.feed_forward.w2.weight language_model.layers.27.attention.wk.weight language_model.tok_embeddings.weight language_model.layers.14.ffn_norm.weight language_model.layers.12.ffn_norm.weight language_model.layers.22.attention.wo.weight language_model.layers.12.attention.wq.weight language_model.layers.19.attention.wq.weight language_model.layers.11.attention.wq.weight language_model.layers.6.attention.wv.weight language_model.layers.26.feed_forward.w3.weight language_model.layers.26.feed_forward.w2.weight language_model.layers.17.attention.wq.weight language_model.layers.18.feed_forward.w3.weight language_model.layers.29.attention.wk.weight language_model.layers.29.feed_forward.w2.weight language_model.layers.8.attention.wq.weight language_model.layers.2.attention_norm.weight language_model.layers.5.attention.wo.weight language_model.layers.23.feed_forward.w1.weight language_model.layers.30.feed_forward.w1.weight language_model.layers.2.attention.wo.weight language_model.layers.18.attention.wk.weight language_model.layers.13.attention.wo.weight language_model.layers.3.ffn_norm.weight language_model.layers.23.attention_norm.weight language_model.layers.10.feed_forward.w2.weight language_model.layers.1.ffn_norm.weight language_model.layers.21.ffn_norm.weight language_model.layers.30.attention.wo.weight language_model.layers.11.attention.wk.weight language_model.layers.7.attention.wo.weight language_model.layers.17.feed_forward.w3.weight language_model.layers.13.ffn_norm.weight language_model.layers.3.feed_forward.w1.weight language_model.layers.18.attention.wo.weight language_model.layers.22.feed_forward.w1.weight language_model.layers.15.attention.wk.weight language_model.layers.15.attention.wv.weight language_model.layers.28.feed_forward.w2.weight language_model.layers.21.feed_forward.w1.weight language_model.layers.12.feed_forward.w2.weight language_model.layers.23.feed_forward.w3.weight language_model.layers.19.ffn_norm.weight language_model.layers.18.attention_norm.weight language_model.layers.22.feed_forward.w3.weight language_model.layers.14.attention.wo.weight language_model.layers.9.ffn_norm.weight language_model.layers.13.attention.wk.weight language_model.layers.28.attention.wo.weight language_model.layers.26.attention.wq.weight language_model.layers.24.ffn_norm.weight language_model.layers.23.attention.wo.weight language_model.layers.10.attention_norm.weight language_model.layers.16.feed_forward.w2.weight language_model.layers.19.feed_forward.w2.weight language_model.layers.23.attention.wk.weight language_model.layers.2.feed_forward.w1.weight language_model.layers.11.feed_forward.w1.weight language_model.layers.4.feed_forward.w2.weight language_model.layers.23.attention.wv.weight language_model.layers.27.feed_forward.w1.weight language_model.layers.17.feed_forward.w2.weight language_model.layers.12.attention_norm.weight 
language_model.layers.30.feed_forward.w2.weight language_model.layers.15.ffn_norm.weight language_model.layers.12.feed_forward.w1.weight language_model.layers.28.attention.wq.weight language_model.layers.10.attention.wq.weight language_model.layers.4.ffn_norm.weight language_model.layers.14.feed_forward.w1.weight language_model.layers.3.feed_forward.w2.weight language_model.layers.12.feed_forward.w3.weight language_model.layers.21.attention.wq.weight language_model.layers.10.attention.wv.weight language_model.layers.21.feed_forward.w3.weight language_model.layers.6.feed_forward.w2.weight language_model.layers.20.ffn_norm.weight language_model.layers.25.attention.wk.weight language_model.layers.17.attention.wk.weight language_model.layers.26.attention_norm.weight language_model.layers.25.feed_forward.w2.weight language_model.layers.1.attention_norm.weight language_model.layers.26.attention.wv.weight language_model.layers.19.attention.wo.weight language_model.layers.14.feed_forward.w3.weight language_model.layers.14.attention.wv.weight language_model.layers.29.ffn_norm.weight language_model.layers.14.feed_forward.w2.weight language_model.layers.1.attention.wk.weight language_model.layers.4.attention.wv.weight language_model.layers.22.attention.wq.weight language_model.layers.3.attention.wq.weight language_model.layers.16.attention.wv.weight language_model.layers.21.attention.wo.weight language_model.layers.26.ffn_norm.weight language_model.layers.29.attention.wq.weight language_model.layers.7.attention.wq.weight language_model.layers.21.attention_norm.weight language_model.layers.24.attention.wo.weight language_model.layers.5.attention_norm.weight language_model.layers.18.feed_forward.w1.weight language_model.layers.26.feed_forward.w1.weight language_model.layers.31.attention.wv.weight language_model.layers.25.feed_forward.w1.weight language_model.layers.27.ffn_norm.weight language_model.layers.6.feed_forward.w1.weight language_model.layers.28.feed_forward.w1.weight language_model.layers.1.feed_forward.w3.weight language_model.layers.8.feed_forward.w2.weight language_model.layers.20.attention.wk.weight language_model.layers.2.attention.wq.weight language_model.layers.4.feed_forward.w1.weight language_model.layers.9.attention.wq.weight language_model.layers.15.feed_forward.w1.weight language_model.layers.7.ffn_norm.weight language_model.layers.0.feed_forward.w2.weight language_model.layers.30.attention_norm.weight language_model.layers.13.attention.wv.weight language_model.layers.10.feed_forward.w1.weight language_model.layers.5.attention.wq.weight language_model.layers.16.feed_forward.w3.weight language_model.layers.28.ffn_norm.weight language_model.layers.31.feed_forward.w1.weight language_model.layers.12.attention.wo.weight language_model.layers.27.attention.wo.weight language_model.layers.15.feed_forward.w3.weight language_model.layers.29.attention.wo.weight language_model.layers.27.attention.wv.weight language_model.layers.14.attention.wq.weight language_model.layers.5.attention.wk.weight language_model.layers.1.feed_forward.w1.weight language_model.layers.20.attention.wv.weight language_model.layers.23.feed_forward.w2.weight language_model.layers.8.attention.wk.weight language_model.layers.5.ffn_norm.weight language_model.layers.21.attention.wv.weight language_model.layers.29.attention_norm.weight language_model.layers.10.feed_forward.w3.weight language_model.layers.1.feed_forward.w2.weight language_model.layers.24.feed_forward.w3.weight language_model.layers.11.ffn_norm.weight 
language_model.layers.9.attention_norm.weight language_model.layers.4.attention.wo.weight language_model.layers.25.attention.wo.weight language_model.layers.7.feed_forward.w2.weight language_model.layers.9.feed_forward.w2.weight language_model.layers.14.attention.wk.weight language_model.layers.27.feed_forward.w2.weight language_model.layers.13.attention.wq.weight language_model.layers.15.attention_norm.weight language_model.layers.28.attention.wv.weight language_model.layers.0.attention_norm.weight language_model.layers.0.attention.wv.weight language_model.layers.7.attention.wk.weight language_model.layers.29.feed_forward.w3.weight language_model.layers.3.attention.wk.weight language_model.layers.28.feed_forward.w3.weight language_model.layers.22.attention.wv.weight language_model.layers.22.attention.wk.weight language_model.layers.6.attention.wq.weight language_model.layers.1.attention.wo.weight language_model.layers.18.attention.wv.weight language_model.layers.8.attention.wv.weight language_model.layers.6.ffn_norm.weight language_model.layers.25.ffn_norm.weight language_model.layers.8.attention_norm.weight language_model.layers.6.attention.wk.weight language_model.layers.29.attention.wv.weight language_model.layers.19.attention_norm.weight language_model.layers.19.attention.wv.weight language_model.layers.6.attention.wo.weight language_model.layers.12.attention.wk.weight language_model.layers.9.feed_forward.w3.weight language_model.layers.8.feed_forward.w1.weight language_model.layers.10.attention.wk.weight language_model.layers.17.ffn_norm.weight language_model.layers.21.attention.wk.weight language_model.layers.15.attention.wq.weight language_model.layers.11.attention_norm.weight language_model.layers.24.attention.wk.weight language_model.layers.31.feed_forward.w2.weight language_model.layers.18.ffn_norm.weight language_model.layers.30.feed_forward.w3.weight language_model.layers.22.ffn_norm.weight language_model.layers.28.attention.wk.weight language_model.layers.9.attention.wk.weight language_model.layers.24.feed_forward.w2.weight language_model.layers.17.attention_norm.weight language_model.layers.17.attention.wv.weight language_model.layers.1.attention.wv.weight language_model.layers.31.ffn_norm.weight language_model.layers.31.attention.wk.weight language_model.layers.24.feed_forward.w1.weight language_model.layers.8.feed_forward.w3.weight language_model.layers.25.attention.wv.weight language_model.layers.7.feed_forward.w1.weight language_model.layers.31.attention.wq.weight language_model.layers.15.feed_forward.w2.weight language_model.layers.30.ffn_norm.weight language_model.layers.0.attention.wq.weight language_model.layers.31.feed_forward.w3.weight language_model.layers.13.feed_forward.w3.weight language_model.layers.19.feed_forward.w3.weight language_model.layers.6.attention_norm.weight language_model.layers.4.attention_norm.weight language_model.layers.12.attention.wv.weight language_model.layers.8.ffn_norm.weight language_model.layers.30.attention.wk.weight language_model.layers.3.attention.wo.weight language_model.layers.16.attention.wq.weight language_model.layers.11.feed_forward.w3.weight language_model.layers.25.feed_forward.w3.weight language_model.layers.3.attention_norm.weight language_model.layers.2.attention.wk.weight language_model.layers.16.attention.wk.weight language_model.layers.7.attention_norm.weight language_model.layers.27.attention.wq.weight language_model.layers.6.feed_forward.w3.weight language_model.layers.24.attention.wv.weight 
language_model.layers.11.attention.wo.weight language_model.layers.24.attention_norm.weight.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmlx_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_weights\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmlx_model/weights.npz\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/envs/mlx/lib/python3.10/site-packages/mlx/nn/layers/base.py:164\u001b[0m, in \u001b[0;36mModule.load_weights\u001b[0;34m(self, file_or_weights, strict)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m extras \u001b[38;5;241m:=\u001b[39m (new_weights\u001b[38;5;241m.\u001b[39mkeys() \u001b[38;5;241m-\u001b[39m curr_weights\u001b[38;5;241m.\u001b[39mkeys()):\n\u001b[1;32m 163\u001b[0m extras \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(extras)\n\u001b[0;32m--> 164\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mReceived parameters not in model: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mextras\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m missing \u001b[38;5;241m:=\u001b[39m (curr_weights\u001b[38;5;241m.\u001b[39mkeys() \u001b[38;5;241m-\u001b[39m new_weights\u001b[38;5;241m.\u001b[39mkeys()):\n\u001b[1;32m 166\u001b[0m missing \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(missing)\n", + "\u001b[0;31mValueError\u001b[0m: Received parameters not in model: language_model.layers.20.feed_forward.w3.weight language_model.layers.18.feed_forward.w2.weight language_model.layers.20.feed_forward.w1.weight language_model.layers.0.feed_forward.w1.weight language_model.layers.23.attention.wq.weight language_model.layers.23.ffn_norm.weight language_model.layers.16.feed_forward.w1.weight language_model.layers.31.attention_norm.weight language_model.output.weight language_model.layers.27.attention_norm.weight language_model.layers.25.attention_norm.weight language_model.layers.28.attention_norm.weight language_model.layers.20.feed_forward.w2.weight language_model.layers.17.attention.wo.weight language_model.layers.1.attention.wq.weight language_model.layers.27.feed_forward.w3.weight language_model.layers.19.feed_forward.w1.weight language_model.layers.14.attention_norm.weight language_model.layers.21.feed_forward.w2.weight language_model.layers.16.attention.wo.weight language_model.layers.22.attention_norm.weight language_model.layers.4.attention.wk.weight language_model.layers.13.feed_forward.w1.weight language_model.layers.30.attention.wv.weight language_model.layers.5.feed_forward.w3.weight language_model.layers.20.attention_norm.weight language_model.layers.13.feed_forward.w2.weight language_model.layers.22.feed_forward.w2.weight language_model.layers.15.attention.wo.weight language_model.layers.26.attention.wo.weight language_model.layers.5.feed_forward.w1.weight language_model.layers.16.attention_norm.weight 
language_model.layers.4.attention.wq.weight language_model.layers.9.feed_forward.w1.weight language_model.layers.20.attention.wq.weight language_model.layers.9.attention.wv.weight language_model.layers.10.ffn_norm.weight language_model.layers.8.attention.wo.weight language_model.layers.3.attention.wv.weight language_model.layers.0.ffn_norm.weight language_model.layers.4.feed_forward.w3.weight language_model.layers.2.attention.wv.weight language_model.layers.7.attention.wv.weight language_model.layers.24.attention.wq.weight language_model.layers.11.feed_forward.w2.weight language_model.layers.0.attention.wo.weight language_model.layers.7.feed_forward.w3.weight language_model.layers.17.feed_forward.w1.weight language_model.layers.31.attention.wo.weight language_model.layers.26.attention.wk.weight language_model.layers.0.feed_forward.w3.weight language_model.layers.2.ffn_norm.weight language_model.layers.13.attention_norm.weight language_model.layers.19.attention.wk.weight language_model.layers.18.attention.wq.weight language_model.layers.10.attention.wo.weight language_model.layers.30.attention.wq.weight language_model.layers.5.feed_forward.w2.weight language_model.layers.5.attention.wv.weight language_model.layers.25.attention.wq.weight language_model.layers.3.feed_forward.w3.weight language_model.layers.9.attention.wo.weight language_model.layers.29.feed_forward.w1.weight language_model.layers.2.feed_forward.w3.weight language_model.layers.0.attention.wk.weight language_model.layers.11.attention.wv.weight language_model.layers.20.attention.wo.weight language_model.layers.16.ffn_norm.weight language_model.layers.2.feed_forward.w2.weight language_model.layers.27.attention.wk.weight language_model.tok_embeddings.weight language_model.layers.14.ffn_norm.weight language_model.layers.12.ffn_norm.weight language_model.layers.22.attention.wo.weight language_model.layers.12.attention.wq.weight language_model.layers.19.attention.wq.weight language_model.layers.11.attention.wq.weight language_model.layers.6.attention.wv.weight language_model.layers.26.feed_forward.w3.weight language_model.layers.26.feed_forward.w2.weight language_model.layers.17.attention.wq.weight language_model.layers.18.feed_forward.w3.weight language_model.layers.29.attention.wk.weight language_model.layers.29.feed_forward.w2.weight language_model.layers.8.attention.wq.weight language_model.layers.2.attention_norm.weight language_model.layers.5.attention.wo.weight language_model.layers.23.feed_forward.w1.weight language_model.layers.30.feed_forward.w1.weight language_model.layers.2.attention.wo.weight language_model.layers.18.attention.wk.weight language_model.layers.13.attention.wo.weight language_model.layers.3.ffn_norm.weight language_model.layers.23.attention_norm.weight language_model.layers.10.feed_forward.w2.weight language_model.layers.1.ffn_norm.weight language_model.layers.21.ffn_norm.weight language_model.layers.30.attention.wo.weight language_model.layers.11.attention.wk.weight language_model.layers.7.attention.wo.weight language_model.layers.17.feed_forward.w3.weight language_model.layers.13.ffn_norm.weight language_model.layers.3.feed_forward.w1.weight language_model.layers.18.attention.wo.weight language_model.layers.22.feed_forward.w1.weight language_model.layers.15.attention.wk.weight language_model.layers.15.attention.wv.weight language_model.layers.28.feed_forward.w2.weight language_model.layers.21.feed_forward.w1.weight language_model.layers.12.feed_forward.w2.weight 
language_model.layers.23.feed_forward.w3.weight language_model.layers.19.ffn_norm.weight language_model.layers.18.attention_norm.weight language_model.layers.22.feed_forward.w3.weight language_model.layers.14.attention.wo.weight language_model.layers.9.ffn_norm.weight language_model.layers.13.attention.wk.weight language_model.layers.28.attention.wo.weight language_model.layers.26.attention.wq.weight language_model.layers.24.ffn_norm.weight language_model.layers.23.attention.wo.weight language_model.layers.10.attention_norm.weight language_model.layers.16.feed_forward.w2.weight language_model.layers.19.feed_forward.w2.weight language_model.layers.23.attention.wk.weight language_model.layers.2.feed_forward.w1.weight language_model.layers.11.feed_forward.w1.weight language_model.layers.4.feed_forward.w2.weight language_model.layers.23.attention.wv.weight language_model.layers.27.feed_forward.w1.weight language_model.layers.17.feed_forward.w2.weight language_model.layers.12.attention_norm.weight language_model.layers.30.feed_forward.w2.weight language_model.layers.15.ffn_norm.weight language_model.layers.12.feed_forward.w1.weight language_model.layers.28.attention.wq.weight language_model.layers.10.attention.wq.weight language_model.layers.4.ffn_norm.weight language_model.layers.14.feed_forward.w1.weight language_model.layers.3.feed_forward.w2.weight language_model.layers.12.feed_forward.w3.weight language_model.layers.21.attention.wq.weight language_model.layers.10.attention.wv.weight language_model.layers.21.feed_forward.w3.weight language_model.layers.6.feed_forward.w2.weight language_model.layers.20.ffn_norm.weight language_model.layers.25.attention.wk.weight language_model.layers.17.attention.wk.weight language_model.layers.26.attention_norm.weight language_model.layers.25.feed_forward.w2.weight language_model.layers.1.attention_norm.weight language_model.layers.26.attention.wv.weight language_model.layers.19.attention.wo.weight language_model.layers.14.feed_forward.w3.weight language_model.layers.14.attention.wv.weight language_model.layers.29.ffn_norm.weight language_model.layers.14.feed_forward.w2.weight language_model.layers.1.attention.wk.weight language_model.layers.4.attention.wv.weight language_model.layers.22.attention.wq.weight language_model.layers.3.attention.wq.weight language_model.layers.16.attention.wv.weight language_model.layers.21.attention.wo.weight language_model.layers.26.ffn_norm.weight language_model.layers.29.attention.wq.weight language_model.layers.7.attention.wq.weight language_model.layers.21.attention_norm.weight language_model.layers.24.attention.wo.weight language_model.layers.5.attention_norm.weight language_model.layers.18.feed_forward.w1.weight language_model.layers.26.feed_forward.w1.weight language_model.layers.31.attention.wv.weight language_model.layers.25.feed_forward.w1.weight language_model.layers.27.ffn_norm.weight language_model.layers.6.feed_forward.w1.weight language_model.layers.28.feed_forward.w1.weight language_model.layers.1.feed_forward.w3.weight language_model.layers.8.feed_forward.w2.weight language_model.layers.20.attention.wk.weight language_model.layers.2.attention.wq.weight language_model.layers.4.feed_forward.w1.weight language_model.layers.9.attention.wq.weight language_model.layers.15.feed_forward.w1.weight language_model.layers.7.ffn_norm.weight language_model.layers.0.feed_forward.w2.weight language_model.layers.30.attention_norm.weight language_model.layers.13.attention.wv.weight 
language_model.layers.10.feed_forward.w1.weight language_model.layers.5.attention.wq.weight language_model.layers.16.feed_forward.w3.weight language_model.layers.28.ffn_norm.weight language_model.layers.31.feed_forward.w1.weight language_model.layers.12.attention.wo.weight language_model.layers.27.attention.wo.weight language_model.layers.15.feed_forward.w3.weight language_model.layers.29.attention.wo.weight language_model.layers.27.attention.wv.weight language_model.layers.14.attention.wq.weight language_model.layers.5.attention.wk.weight language_model.layers.1.feed_forward.w1.weight language_model.layers.20.attention.wv.weight language_model.layers.23.feed_forward.w2.weight language_model.layers.8.attention.wk.weight language_model.layers.5.ffn_norm.weight language_model.layers.21.attention.wv.weight language_model.layers.29.attention_norm.weight language_model.layers.10.feed_forward.w3.weight language_model.layers.1.feed_forward.w2.weight language_model.layers.24.feed_forward.w3.weight language_model.layers.11.ffn_norm.weight language_model.layers.9.attention_norm.weight language_model.layers.4.attention.wo.weight language_model.layers.25.attention.wo.weight language_model.layers.7.feed_forward.w2.weight language_model.layers.9.feed_forward.w2.weight language_model.layers.14.attention.wk.weight language_model.layers.27.feed_forward.w2.weight language_model.layers.13.attention.wq.weight language_model.layers.15.attention_norm.weight language_model.layers.28.attention.wv.weight language_model.layers.0.attention_norm.weight language_model.layers.0.attention.wv.weight language_model.layers.7.attention.wk.weight language_model.layers.29.feed_forward.w3.weight language_model.layers.3.attention.wk.weight language_model.layers.28.feed_forward.w3.weight language_model.layers.22.attention.wv.weight language_model.layers.22.attention.wk.weight language_model.layers.6.attention.wq.weight language_model.layers.1.attention.wo.weight language_model.layers.18.attention.wv.weight language_model.layers.8.attention.wv.weight language_model.layers.6.ffn_norm.weight language_model.layers.25.ffn_norm.weight language_model.layers.8.attention_norm.weight language_model.layers.6.attention.wk.weight language_model.layers.29.attention.wv.weight language_model.layers.19.attention_norm.weight language_model.layers.19.attention.wv.weight language_model.layers.6.attention.wo.weight language_model.layers.12.attention.wk.weight language_model.layers.9.feed_forward.w3.weight language_model.layers.8.feed_forward.w1.weight language_model.layers.10.attention.wk.weight language_model.layers.17.ffn_norm.weight language_model.layers.21.attention.wk.weight language_model.layers.15.attention.wq.weight language_model.layers.11.attention_norm.weight language_model.layers.24.attention.wk.weight language_model.layers.31.feed_forward.w2.weight language_model.layers.18.ffn_norm.weight language_model.layers.30.feed_forward.w3.weight language_model.layers.22.ffn_norm.weight language_model.layers.28.attention.wk.weight language_model.layers.9.attention.wk.weight language_model.layers.24.feed_forward.w2.weight language_model.layers.17.attention_norm.weight language_model.layers.17.attention.wv.weight language_model.layers.1.attention.wv.weight language_model.layers.31.ffn_norm.weight language_model.layers.31.attention.wk.weight language_model.layers.24.feed_forward.w1.weight language_model.layers.8.feed_forward.w3.weight language_model.layers.25.attention.wv.weight language_model.layers.7.feed_forward.w1.weight 
language_model.layers.31.attention.wq.weight language_model.layers.15.feed_forward.w2.weight language_model.layers.30.ffn_norm.weight language_model.layers.0.attention.wq.weight language_model.layers.31.feed_forward.w3.weight language_model.layers.13.feed_forward.w3.weight language_model.layers.19.feed_forward.w3.weight language_model.layers.6.attention_norm.weight language_model.layers.4.attention_norm.weight language_model.layers.12.attention.wv.weight language_model.layers.8.ffn_norm.weight language_model.layers.30.attention.wk.weight language_model.layers.3.attention.wo.weight language_model.layers.16.attention.wq.weight language_model.layers.11.feed_forward.w3.weight language_model.layers.25.feed_forward.w3.weight language_model.layers.3.attention_norm.weight language_model.layers.2.attention.wk.weight language_model.layers.16.attention.wk.weight language_model.layers.7.attention_norm.weight language_model.layers.27.attention.wq.weight language_model.layers.6.feed_forward.w3.weight language_model.layers.24.attention.wv.weight language_model.layers.11.attention.wo.weight language_model.layers.24.attention_norm.weight." + ] + } + ], "source": [ "mlx_model.load_weights('mlx_model/weights.npz')\n" ] @@ -250,7 +290,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -260,12 +300,123 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "inputs = processor(prompts, images=[image], padding=True, return_tensors=\"pt\")" ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input_ids': tensor([[ 1, 3148, 1001, 29901, 29871, 32000, 29871, 1724, 526, 278,\n", + " 2712, 306, 881, 1348, 1048, 8256, 306, 6493, 445, 2058,\n", + " 29973, 319, 1799, 9047, 13566, 29901]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1]]), 'pixel_values': tensor([[[[ 1.2734, 1.2734, 1.2734, ..., 1.1274, 1.1274, 1.0982],\n", + " [ 1.2734, 1.2734, 1.2880, ..., 1.1274, 1.1274, 1.1128],\n", + " [ 1.2880, 1.2880, 1.2880, ..., 1.1274, 1.1274, 1.1274],\n", + " ...,\n", + " [-0.9456, -0.9164, -0.9164, ..., -1.0769, -1.0769, -1.0769],\n", + " [-0.9602, -0.9310, -0.9018, ..., -1.0915, -1.0915, -1.0915],\n", + " [-0.9602, -0.9748, -0.2448, ..., -1.1061, -1.1061, -1.1207]],\n", + "\n", + " [[ 1.6397, 1.6397, 1.6397, ..., 1.5196, 1.5196, 1.5196],\n", + " [ 1.6397, 1.6397, 1.6547, ..., 1.5196, 1.5196, 1.5196],\n", + " [ 1.6547, 1.6547, 1.6547, ..., 1.5196, 1.5196, 1.5196],\n", + " ...,\n", + " [-0.5065, -0.5065, -0.5215, ..., -0.6715, -0.6715, -0.6715],\n", + " [-0.5215, -0.5215, -0.5065, ..., -0.6865, -0.6865, -0.6865],\n", + " [-0.5215, -0.5665, 0.1689, ..., -0.7016, -0.7016, -0.7166]],\n", + "\n", + " [[ 1.9610, 1.9610, 1.9610, ..., 1.9042, 1.9042, 1.8899],\n", + " [ 1.9610, 1.9610, 1.9753, ..., 1.9042, 1.9042, 1.8899],\n", + " [ 1.9753, 1.9753, 1.9753, ..., 1.9042, 1.9042, 1.9042],\n", + " ...,\n", + " [-0.1009, -0.0724, -0.0867, ..., -0.2573, -0.2573, -0.2573],\n", + " [-0.1009, -0.1009, -0.0867, ..., -0.2715, -0.2715, -0.2715],\n", + " [-0.1009, -0.1578, 0.5390, ..., -0.2857, -0.2857, -0.3000]]]])}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/plain": [ + "LlavaConfig {\n", + " \"_name_or_path\": \"/Users/noahkasmanoff/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/05ae2434cbb430be33edcba0c5203e7023f785b7\",\n", + " \"architectures\": [\n", + " \"LlavaForConditionalGeneration\"\n", + " ],\n", + " \"ignore_index\": -100,\n", + " \"image_token_index\": 32000,\n", + " \"model_type\": \"llava\",\n", + " \"pad_token_id\": 32001,\n", + " \"projector_hidden_act\": \"gelu\",\n", + " \"text_config\": {\n", + " \"_name_or_path\": \"lmsys/vicuna-7b-v1.5\",\n", + " \"architectures\": [\n", + " \"LlamaForCausalLM\"\n", + " ],\n", + " \"max_position_embeddings\": 4096,\n", + " \"model_type\": \"llama\",\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"torch_dtype\": \"float16\",\n", + " \"vocab_size\": 32064\n", + " },\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"float16\",\n", + " \"transformers_version\": \"4.37.2\",\n", + " \"vision_config\": {\n", + " \"hidden_size\": 1024,\n", + " \"image_size\": 336,\n", + " \"intermediate_size\": 4096,\n", + " \"model_type\": \"clip_vision_model\",\n", + " \"num_attention_heads\": 16,\n", + " \"num_hidden_layers\": 24,\n", + " \"patch_size\": 14,\n", + " \"projection_dim\": 768,\n", + " \"vocab_size\": 32000\n", + " },\n", + " \"vision_feature_layer\": -2,\n", + " \"vision_feature_select_strategy\": \"default\",\n", + " \"vocab_size\": 32064\n", + "}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/llava/base.py b/llava/base.py new file mode 100644 index 000000000..e69de29bb diff --git a/llava/llama.py b/llava/llama.py index 7b0052355..ba5baeda5 100644 --- a/llava/llama.py +++ b/llava/llama.py @@ -1,31 +1,52 @@ -# Copyright © 2023 Apple Inc. 
- -import argparse -import glob -import json -import time from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_unflatten -from sentencepiece import SentencePieceProcessor + +import inspect + + +@dataclass +class BaseModelArgs: + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) @dataclass -class ModelArgs: - dim: int - n_layers: int - head_dim: int - hidden_dim: int - n_heads: int - n_kv_heads: int - norm_eps: float +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float vocab_size: int - rope_theta: float - rope_traditional: bool = True + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError( + f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError( + "rope_scaling 'type' currently only supports 'linear'") class RMSNorm(nn.Module): @@ -45,23 +66,31 @@ def __call__(self, x): class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.args = args - self.n_heads: int = args.n_heads - self.n_kv_heads: int = args.n_kv_heads + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.repeats = self.n_heads // self.n_kv_heads + self.repeats = n_heads // n_kv_heads - self.scale = self.args.head_dim**-0.5 + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 - self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) - self.wk = nn.Linear(args.dim, args.n_kv_heads * - args.head_dim, bias=False) - self.wv = nn.Linear(args.dim, args.n_kv_heads * - args.head_dim, bias=False) - self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + rope_scale = ( + 1 / args.rope_scaling["factor"] + if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" + else 1 + ) self.rope = nn.RoPE( - args.head_dim, traditional=args.rope_traditional, base=args.rope_theta + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, ) def __call__( @@ -69,10 +98,10 @@ def __call__( x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + ) -> mx.array: B, L, D = x.shape - queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) @@ -80,11 +109,9 @@ def __call__( values = 
values.reshape( B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - - keys, values = map(repeat, (keys, values)) + if self.repeats > 1: + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) if cache is not None: key_cache, value_cache = cache @@ -102,30 +129,30 @@ def repeat(a): scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.wo(output), (keys, values) + return self.o_proj(output), (keys, values) -class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): super().__init__() - - self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) - self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) - self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.attention = Attention(args) - self.feed_forward = FeedForward(args=args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + args.hidden_size, eps=args.rms_norm_eps) self.args = args def __call__( @@ -134,9 +161,9 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.attention(self.attention_norm(x), mask, cache) + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r - r = self.feed_forward(self.ffn_norm(h)) + r = self.mlp(self.post_attention_layernorm(h)) out = h + r return out, cache @@ -146,259 +173,58 @@ def __init__(self, args: ModelArgs): super().__init__() self.args = args self.vocab_size = args.vocab_size - self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) - self.layers = [TransformerBlock(args=args) - for _ in range(args.n_layers)] - self.norm = RMSNorm(args.dim, eps=args.norm_eps) - self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - def __call__(self, x): - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(self.tok_embeddings.weight.dtype) - - x = self.tok_embeddings(x) - for l in self.layers: - x, _ = l(x, mask) - x = self.norm(x) - return self.output(x) - - def generate(self, x, temp=1.0): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return 
mx.random.categorical(logits * (1 / temp)) - - cache = [] - - # Make an additive causal mask. We will need that to process the prompt. - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(self.tok_embeddings.weight.dtype) - - # First we process the prompt x the same was as in __call__ but - # save the caches in cache - x = self.tok_embeddings(x) - for l in self.layers: - x, c = l(x, mask=mask) - # We store the per layer cache in a simple python list - cache.append(c) - x = self.norm(x) - # We only care about the last logits that generate the next token - y = self.output(x[:, -1]) - y = sample(y) - - # y now has size [1] - # Since MLX is lazily evaluated nothing is computed yet. - # Calling y.item() would force the computation to happen at - # this point but we can also choose not to do that and let the - # user choose when to start the computation. - yield y - - # Now we parsed the prompt and generated the first token we - # need to feed it back into the model and loop to generate the - # rest. - while True: - # Unsqueezing the last dimension to add a sequence length - # dimension of 1 - x = y[:, None] - - x = self.tok_embeddings(x) - for i in range(len(cache)): - # We are overwriting the arrays in the cache list. When - # the computation will happen, MLX will be discarding the - # old cache the moment it is not needed anymore. - x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) - x = self.norm(x) - y = sample(self.output(x[:, -1])) - - yield y - - -def tic(): - return time.time() - - -def toc(msg, start): - end = time.time() - return f"[INFO] {msg}: {end - start:.3f} s" - - -def generate(args): - input("Press enter to start generation") - print("------") - print(args.prompt) - x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) - skip = 0 - prompt_processing = None - tokens = [] - start = tic() - for token in model.generate(x, args.temp): - tokens.append(token) - - if len(tokens) == 1: - # Actually perform the computation to measure the prompt processing time - mx.eval(token) - prompt_processing = toc("Prompt processing", start) - - if len(tokens) >= args.max_tokens: - break - - elif (len(tokens) % args.write_every) == 0: - # It is perfectly ok to eval things we have already eval-ed. 
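(The attention hunk above swaps the old concatenate-and-reshape expansion of the key/value heads for mx.repeat. A minimal equivalence check, not part of the patch, with illustrative shapes: 1 batch, 4 kv heads repeated 8 times over 5 positions.)

    import mlx.core as mx

    B, n_kv_heads, L, head_dim, repeats = 1, 4, 5, 8, 8
    kv = mx.random.normal((B, n_kv_heads, L, head_dim))

    # old path: stack `repeats` copies of each kv head, then flatten to n_heads
    old = mx.concatenate([mx.expand_dims(kv, 2)] * repeats, axis=2)
    old = old.reshape([B, n_kv_heads * repeats, L, head_dim])

    # new path: repeat each kv head in place along the head axis
    new = mx.repeat(kv, repeats, axis=1)

    print(mx.allclose(old, new))  # True: both lay out each kv head's copies consecutively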
- mx.eval(tokens) - s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], end="", flush=True) - skip = len(s) - - mx.eval(tokens) - full_gen = toc("Full generation", start) - s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], flush=True) - print("------") - print(prompt_processing) - print(full_gen) - - -def few_shot_generate(args): - def possible_end(s): - word = "[Instruction]" - for i in range(len(word) - 1, 0, -1): - if s[-i:] == word[:i]: - return 0 - if s[-len(word):] == word: - return 1 - return -1 - - def generate(question): - x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)]) - skip = 0 - prompt_processing = None - tokens = [] - start = tic() - for token in model.generate(x, args.temp): - tokens.append(token) - - if len(tokens) == 1: - # Actually perform the computation to measure the prompt processing time - mx.eval(token) - prompt_processing = toc("Prompt processing", start) - - if len(tokens) >= args.max_tokens: - break - - mx.eval(tokens) - token_list = [t.item() for t in tokens] - s = tokenizer.decode(token_list) - - end = possible_end(s) - if end == 0: - continue - if end == 1: - skip = len(s) - break - - print(s[skip:], end="", flush=True) - skip = len(s) - if token_list[-1] == tokenizer.eos_id(): - break - - mx.eval(tokens) - full_gen = toc("Full generation", start) - s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], end="", flush=True) - - print("[INFO] Loading few-shot examples from: {}".format(args.few_shot)) - prompt = open(args.few_shot).read().strip() - while True: - question = input("Ask a question: ") - generate(prompt.replace("{}", question)) - print() - - -def sanitize_config(config, weights): - config.pop("model_type", None) - n_heads = config["n_heads"] - if "n_kv_heads" not in config: - config["n_kv_heads"] = n_heads - if "head_dim" not in config: - config["head_dim"] = config["dim"] // n_heads - if "hidden_dim" not in config: - config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] - if config.get("vocab_size", -1) < 0: - config["vocab_size"] = weights["output.weight"].shape[-1] - if "rope_theta" not in config: - config["rope_theta"] = 10000 - unused = ["multiple_of", "ffn_dim_multiplier"] - for k in unused: - config.pop(k, None) - return config - - -def load_model(model_path): - model_path = Path(model_path) - - unsharded_weights_path = Path(model_path / "weights.npz") - if unsharded_weights_path.is_file(): - print("[INFO] Loading model from {}.".format(unsharded_weights_path)) - weights = mx.load(str(unsharded_weights_path)) - else: - sharded_weights_glob = str(model_path / "weights.*.npz") - weight_files = glob.glob(sharded_weights_glob) - print("[INFO] Loading model from {}.".format(sharded_weights_glob)) - - if len(weight_files) == 0: - raise FileNotFoundError( - "No weights found in {}".format(model_path)) - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf).items()) - - with open(model_path / "config.json", "r") as f: - config = sanitize_config(json.loads(f.read()), weights) - quantization = config.pop("quantization", None) - model = Llama(ModelArgs(**config)) - if quantization is not None: - nn.QuantizedLinear.quantize_module(model, **quantization) - model.update(tree_unflatten(list(weights.items()))) - tokenizer = SentencePieceProcessor( - model_file=str(model_path / "tokenizer.model")) - return model, tokenizer - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Llama inference script") - parser.add_argument( - 
"--model-path", - help="Path to the model weights and tokenizer", - default="mlx_model", - ) - parser.add_argument( - "--prompt", - help="The message to be processed by the model. Ignored when --few-shot is provided.", - default="In the beginning the Universe was created.", - ) - parser.add_argument( - "--few-shot", - help="Read a few shot prompt from a file (as in `sample_prompt.txt`).", - ) - parser.add_argument( - "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate" - ) - parser.add_argument( - "--write-every", type=int, default=1, help="After how many tokens to detokenize" - ) - parser.add_argument( - "--temp", type=float, default=0.0, help="The sampling temperature" - ) - parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") - - args = parser.parse_args() - - mx.random.seed(args.seed) - - model, tokenizer = load_model(args.model_path) - if args.few_shot: - few_shot_generate(args) - else: - generate(args) + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask( + h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = Llama(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + return self.lm_head(out), cache + + @staticmethod + def sanitize(weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers diff --git a/llava/llava.py b/llava/llava.py index 63c70db59..c661f9a18 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -1,10 +1,10 @@ from clip import CLIPVisionModel -from llama import Llama +from llama import LlamaModel from pathlib import Path import json import mlx.nn as nn import mlx.core as mx -from typing import Any, Optional +from typing import Any, Optional, Dict, Union from dataclasses import dataclass @@ -23,16 +23,17 @@ class VisionConfig: @dataclass class LLMConfig: - dim: int - n_layers: int - head_dim: int - hidden_dim: int - n_heads: int - n_kv_heads: int - norm_eps: float + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float vocab_size: int - rope_theta: float - rope_traditional: bool = True + num_key_value_heads: int + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None @dataclass @@ -65,7 +66,7 @@ def forward(self, x: mx.array) -> mx.array: class LlavaModel(nn.Module): def __init__(self, config: LlaVAConfig): self.vision_tower = CLIPVisionModel(config=config.vision_config) - self.language_model = Llama(args=config.llm_config) + self.language_model = LlamaModel(args=config.llm_config) self.multi_modal_projector = LlavaMultiModalProjector( config=config.projection_config) diff --git a/llava/utils.py b/llava/utils.py index 27d8f6f58..6f0bf7251 100644 --- a/llava/utils.py +++ b/llava/utils.py @@ -51,22 +51,22 @@ def map_vision_tower_weights(key: str, value: 
torch.Tensor) -> Tuple[str, torch. def map_language_model_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: - key = key.replace('language_model.model.', 'language_model.') - key = key.replace('mlp.', 'feed_forward.') - key = key.replace("down_proj", "w2") - key = key.replace("up_proj", "w3") - key = key.replace("gate_proj", "w1") - key = key.replace("input_layernorm", "attention_norm") - key = key.replace("post_attention_layernorm", "ffn_norm") - key = key.replace("lm_head", "output") - - key = key.replace("embed_tokens", "tok_embeddings") - key = key.replace("self_attn", "attention") - - key = key.replace("q_proj", "wq") - key = key.replace("k_proj", "wk") - key = key.replace("v_proj", "wv") - key = key.replace("o_proj", "wo") + # key = key.replace('language_model.model.', 'language_model.') + # key = key.replace('mlp.', 'feed_forward.') + # key = key.replace("down_proj", "w2") + # key = key.replace("up_proj", "w3") + # key = key.replace("gate_proj", "w1") + # key = key.replace("input_layernorm", "attention_norm") + # key = key.replace("post_attention_layernorm", "ffn_norm") + # key = key.replace("lm_head", "output") + + # key = key.replace("embed_tokens", "tok_embeddings") + # key = key.replace("self_attn", "attention") + + # key = key.replace("q_proj", "wq") + # key = key.replace("k_proj", "wk") + # key = key.replace("v_proj", "wv") + # key = key.replace("o_proj", "wo") return (key, value) From b83b1e5c6f242b1c604a378101b610a677a8d538 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Wed, 21 Feb 2024 19:22:21 -0500 Subject: [PATCH 05/34] delete base --- llava/base.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 llava/base.py diff --git a/llava/base.py b/llava/base.py deleted file mode 100644 index e69de29bb..000000000 From 6e238470b929e2f31c19f511c7f831d6cbd6b537 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Wed, 21 Feb 2024 20:12:54 -0500 Subject: [PATCH 06/34] latest --- llava/Local LLava.ipynb | 256 ++++++++-------------------------------- llava/llava.py | 15 ++- 2 files changed, 62 insertions(+), 209 deletions(-) diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb index 0622420da..086f634d6 100644 --- a/llava/Local LLava.ipynb +++ b/llava/Local LLava.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -48,7 +48,7 @@ "text": [ "/Users/noahkasmanoff/anaconda3/envs/mlx/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 202950.19it/s]\n", + "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 214177.23it/s]\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } @@ -64,12 +64,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] Converting\n", + "[INFO] Saving\n" + ] + } + ], "source": [ "from utils import map_weights, should_keep_weight\n", - "do_convert = False\n", + "do_convert = True\n", "if do_convert:\n", "\n", " print(\"[INFO] Converting\")\n", @@ -87,14 +96,7 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -136,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -145,7 +147,7 @@ "2.6875" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -154,261 +156,107 @@ "11008 / 4096" ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "ename": "ValueError", - "evalue": "Received parameters not in model: language_model.layers.20.feed_forward.w3.weight language_model.layers.18.feed_forward.w2.weight language_model.layers.20.feed_forward.w1.weight language_model.layers.0.feed_forward.w1.weight language_model.layers.23.attention.wq.weight language_model.layers.23.ffn_norm.weight language_model.layers.16.feed_forward.w1.weight language_model.layers.31.attention_norm.weight language_model.output.weight language_model.layers.27.attention_norm.weight language_model.layers.25.attention_norm.weight language_model.layers.28.attention_norm.weight language_model.layers.20.feed_forward.w2.weight language_model.layers.17.attention.wo.weight language_model.layers.1.attention.wq.weight language_model.layers.27.feed_forward.w3.weight language_model.layers.19.feed_forward.w1.weight language_model.layers.14.attention_norm.weight language_model.layers.21.feed_forward.w2.weight language_model.layers.16.attention.wo.weight language_model.layers.22.attention_norm.weight language_model.layers.4.attention.wk.weight language_model.layers.13.feed_forward.w1.weight language_model.layers.30.attention.wv.weight language_model.layers.5.feed_forward.w3.weight language_model.layers.20.attention_norm.weight language_model.layers.13.feed_forward.w2.weight language_model.layers.22.feed_forward.w2.weight language_model.layers.15.attention.wo.weight language_model.layers.26.attention.wo.weight language_model.layers.5.feed_forward.w1.weight language_model.layers.16.attention_norm.weight language_model.layers.4.attention.wq.weight language_model.layers.9.feed_forward.w1.weight language_model.layers.20.attention.wq.weight language_model.layers.9.attention.wv.weight language_model.layers.10.ffn_norm.weight language_model.layers.8.attention.wo.weight language_model.layers.3.attention.wv.weight language_model.layers.0.ffn_norm.weight language_model.layers.4.feed_forward.w3.weight language_model.layers.2.attention.wv.weight language_model.layers.7.attention.wv.weight 
language_model.layers.24.attention.wq.weight language_model.layers.11.feed_forward.w2.weight language_model.layers.0.attention.wo.weight language_model.layers.7.feed_forward.w3.weight language_model.layers.17.feed_forward.w1.weight language_model.layers.31.attention.wo.weight language_model.layers.26.attention.wk.weight language_model.layers.0.feed_forward.w3.weight language_model.layers.2.ffn_norm.weight language_model.layers.13.attention_norm.weight language_model.layers.19.attention.wk.weight language_model.layers.18.attention.wq.weight language_model.layers.10.attention.wo.weight language_model.layers.30.attention.wq.weight language_model.layers.5.feed_forward.w2.weight language_model.layers.5.attention.wv.weight language_model.layers.25.attention.wq.weight language_model.layers.3.feed_forward.w3.weight language_model.layers.9.attention.wo.weight language_model.layers.29.feed_forward.w1.weight language_model.layers.2.feed_forward.w3.weight language_model.layers.0.attention.wk.weight language_model.layers.11.attention.wv.weight language_model.layers.20.attention.wo.weight language_model.layers.16.ffn_norm.weight language_model.layers.2.feed_forward.w2.weight language_model.layers.27.attention.wk.weight language_model.tok_embeddings.weight language_model.layers.14.ffn_norm.weight language_model.layers.12.ffn_norm.weight language_model.layers.22.attention.wo.weight language_model.layers.12.attention.wq.weight language_model.layers.19.attention.wq.weight language_model.layers.11.attention.wq.weight language_model.layers.6.attention.wv.weight language_model.layers.26.feed_forward.w3.weight language_model.layers.26.feed_forward.w2.weight language_model.layers.17.attention.wq.weight language_model.layers.18.feed_forward.w3.weight language_model.layers.29.attention.wk.weight language_model.layers.29.feed_forward.w2.weight language_model.layers.8.attention.wq.weight language_model.layers.2.attention_norm.weight language_model.layers.5.attention.wo.weight language_model.layers.23.feed_forward.w1.weight language_model.layers.30.feed_forward.w1.weight language_model.layers.2.attention.wo.weight language_model.layers.18.attention.wk.weight language_model.layers.13.attention.wo.weight language_model.layers.3.ffn_norm.weight language_model.layers.23.attention_norm.weight language_model.layers.10.feed_forward.w2.weight language_model.layers.1.ffn_norm.weight language_model.layers.21.ffn_norm.weight language_model.layers.30.attention.wo.weight language_model.layers.11.attention.wk.weight language_model.layers.7.attention.wo.weight language_model.layers.17.feed_forward.w3.weight language_model.layers.13.ffn_norm.weight language_model.layers.3.feed_forward.w1.weight language_model.layers.18.attention.wo.weight language_model.layers.22.feed_forward.w1.weight language_model.layers.15.attention.wk.weight language_model.layers.15.attention.wv.weight language_model.layers.28.feed_forward.w2.weight language_model.layers.21.feed_forward.w1.weight language_model.layers.12.feed_forward.w2.weight language_model.layers.23.feed_forward.w3.weight language_model.layers.19.ffn_norm.weight language_model.layers.18.attention_norm.weight language_model.layers.22.feed_forward.w3.weight language_model.layers.14.attention.wo.weight language_model.layers.9.ffn_norm.weight language_model.layers.13.attention.wk.weight language_model.layers.28.attention.wo.weight language_model.layers.26.attention.wq.weight language_model.layers.24.ffn_norm.weight language_model.layers.23.attention.wo.weight 
language_model.layers.10.attention_norm.weight language_model.layers.16.feed_forward.w2.weight language_model.layers.19.feed_forward.w2.weight language_model.layers.23.attention.wk.weight language_model.layers.2.feed_forward.w1.weight language_model.layers.11.feed_forward.w1.weight language_model.layers.4.feed_forward.w2.weight language_model.layers.23.attention.wv.weight language_model.layers.27.feed_forward.w1.weight language_model.layers.17.feed_forward.w2.weight language_model.layers.12.attention_norm.weight language_model.layers.30.feed_forward.w2.weight language_model.layers.15.ffn_norm.weight language_model.layers.12.feed_forward.w1.weight language_model.layers.28.attention.wq.weight language_model.layers.10.attention.wq.weight language_model.layers.4.ffn_norm.weight language_model.layers.14.feed_forward.w1.weight language_model.layers.3.feed_forward.w2.weight language_model.layers.12.feed_forward.w3.weight language_model.layers.21.attention.wq.weight language_model.layers.10.attention.wv.weight language_model.layers.21.feed_forward.w3.weight language_model.layers.6.feed_forward.w2.weight language_model.layers.20.ffn_norm.weight language_model.layers.25.attention.wk.weight language_model.layers.17.attention.wk.weight language_model.layers.26.attention_norm.weight language_model.layers.25.feed_forward.w2.weight language_model.layers.1.attention_norm.weight language_model.layers.26.attention.wv.weight language_model.layers.19.attention.wo.weight language_model.layers.14.feed_forward.w3.weight language_model.layers.14.attention.wv.weight language_model.layers.29.ffn_norm.weight language_model.layers.14.feed_forward.w2.weight language_model.layers.1.attention.wk.weight language_model.layers.4.attention.wv.weight language_model.layers.22.attention.wq.weight language_model.layers.3.attention.wq.weight language_model.layers.16.attention.wv.weight language_model.layers.21.attention.wo.weight language_model.layers.26.ffn_norm.weight language_model.layers.29.attention.wq.weight language_model.layers.7.attention.wq.weight language_model.layers.21.attention_norm.weight language_model.layers.24.attention.wo.weight language_model.layers.5.attention_norm.weight language_model.layers.18.feed_forward.w1.weight language_model.layers.26.feed_forward.w1.weight language_model.layers.31.attention.wv.weight language_model.layers.25.feed_forward.w1.weight language_model.layers.27.ffn_norm.weight language_model.layers.6.feed_forward.w1.weight language_model.layers.28.feed_forward.w1.weight language_model.layers.1.feed_forward.w3.weight language_model.layers.8.feed_forward.w2.weight language_model.layers.20.attention.wk.weight language_model.layers.2.attention.wq.weight language_model.layers.4.feed_forward.w1.weight language_model.layers.9.attention.wq.weight language_model.layers.15.feed_forward.w1.weight language_model.layers.7.ffn_norm.weight language_model.layers.0.feed_forward.w2.weight language_model.layers.30.attention_norm.weight language_model.layers.13.attention.wv.weight language_model.layers.10.feed_forward.w1.weight language_model.layers.5.attention.wq.weight language_model.layers.16.feed_forward.w3.weight language_model.layers.28.ffn_norm.weight language_model.layers.31.feed_forward.w1.weight language_model.layers.12.attention.wo.weight language_model.layers.27.attention.wo.weight language_model.layers.15.feed_forward.w3.weight language_model.layers.29.attention.wo.weight language_model.layers.27.attention.wv.weight language_model.layers.14.attention.wq.weight 
language_model.layers.5.attention.wk.weight language_model.layers.1.feed_forward.w1.weight language_model.layers.20.attention.wv.weight language_model.layers.23.feed_forward.w2.weight language_model.layers.8.attention.wk.weight language_model.layers.5.ffn_norm.weight language_model.layers.21.attention.wv.weight language_model.layers.29.attention_norm.weight language_model.layers.10.feed_forward.w3.weight language_model.layers.1.feed_forward.w2.weight language_model.layers.24.feed_forward.w3.weight language_model.layers.11.ffn_norm.weight language_model.layers.9.attention_norm.weight language_model.layers.4.attention.wo.weight language_model.layers.25.attention.wo.weight language_model.layers.7.feed_forward.w2.weight language_model.layers.9.feed_forward.w2.weight language_model.layers.14.attention.wk.weight language_model.layers.27.feed_forward.w2.weight language_model.layers.13.attention.wq.weight language_model.layers.15.attention_norm.weight language_model.layers.28.attention.wv.weight language_model.layers.0.attention_norm.weight language_model.layers.0.attention.wv.weight language_model.layers.7.attention.wk.weight language_model.layers.29.feed_forward.w3.weight language_model.layers.3.attention.wk.weight language_model.layers.28.feed_forward.w3.weight language_model.layers.22.attention.wv.weight language_model.layers.22.attention.wk.weight language_model.layers.6.attention.wq.weight language_model.layers.1.attention.wo.weight language_model.layers.18.attention.wv.weight language_model.layers.8.attention.wv.weight language_model.layers.6.ffn_norm.weight language_model.layers.25.ffn_norm.weight language_model.layers.8.attention_norm.weight language_model.layers.6.attention.wk.weight language_model.layers.29.attention.wv.weight language_model.layers.19.attention_norm.weight language_model.layers.19.attention.wv.weight language_model.layers.6.attention.wo.weight language_model.layers.12.attention.wk.weight language_model.layers.9.feed_forward.w3.weight language_model.layers.8.feed_forward.w1.weight language_model.layers.10.attention.wk.weight language_model.layers.17.ffn_norm.weight language_model.layers.21.attention.wk.weight language_model.layers.15.attention.wq.weight language_model.layers.11.attention_norm.weight language_model.layers.24.attention.wk.weight language_model.layers.31.feed_forward.w2.weight language_model.layers.18.ffn_norm.weight language_model.layers.30.feed_forward.w3.weight language_model.layers.22.ffn_norm.weight language_model.layers.28.attention.wk.weight language_model.layers.9.attention.wk.weight language_model.layers.24.feed_forward.w2.weight language_model.layers.17.attention_norm.weight language_model.layers.17.attention.wv.weight language_model.layers.1.attention.wv.weight language_model.layers.31.ffn_norm.weight language_model.layers.31.attention.wk.weight language_model.layers.24.feed_forward.w1.weight language_model.layers.8.feed_forward.w3.weight language_model.layers.25.attention.wv.weight language_model.layers.7.feed_forward.w1.weight language_model.layers.31.attention.wq.weight language_model.layers.15.feed_forward.w2.weight language_model.layers.30.ffn_norm.weight language_model.layers.0.attention.wq.weight language_model.layers.31.feed_forward.w3.weight language_model.layers.13.feed_forward.w3.weight language_model.layers.19.feed_forward.w3.weight language_model.layers.6.attention_norm.weight language_model.layers.4.attention_norm.weight language_model.layers.12.attention.wv.weight language_model.layers.8.ffn_norm.weight 
language_model.layers.30.attention.wk.weight language_model.layers.3.attention.wo.weight language_model.layers.16.attention.wq.weight language_model.layers.11.feed_forward.w3.weight language_model.layers.25.feed_forward.w3.weight language_model.layers.3.attention_norm.weight language_model.layers.2.attention.wk.weight language_model.layers.16.attention.wk.weight language_model.layers.7.attention_norm.weight language_model.layers.27.attention.wq.weight language_model.layers.6.feed_forward.w3.weight language_model.layers.24.attention.wv.weight language_model.layers.11.attention.wo.weight language_model.layers.24.attention_norm.weight.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmlx_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_weights\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmlx_model/weights.npz\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/anaconda3/envs/mlx/lib/python3.10/site-packages/mlx/nn/layers/base.py:164\u001b[0m, in \u001b[0;36mModule.load_weights\u001b[0;34m(self, file_or_weights, strict)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m extras \u001b[38;5;241m:=\u001b[39m (new_weights\u001b[38;5;241m.\u001b[39mkeys() \u001b[38;5;241m-\u001b[39m curr_weights\u001b[38;5;241m.\u001b[39mkeys()):\n\u001b[1;32m 163\u001b[0m extras \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(extras)\n\u001b[0;32m--> 164\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mReceived parameters not in model: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mextras\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m missing \u001b[38;5;241m:=\u001b[39m (curr_weights\u001b[38;5;241m.\u001b[39mkeys() \u001b[38;5;241m-\u001b[39m new_weights\u001b[38;5;241m.\u001b[39mkeys()):\n\u001b[1;32m 166\u001b[0m missing \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(missing)\n", - "\u001b[0;31mValueError\u001b[0m: Received parameters not in model: language_model.layers.20.feed_forward.w3.weight language_model.layers.18.feed_forward.w2.weight language_model.layers.20.feed_forward.w1.weight language_model.layers.0.feed_forward.w1.weight language_model.layers.23.attention.wq.weight language_model.layers.23.ffn_norm.weight language_model.layers.16.feed_forward.w1.weight language_model.layers.31.attention_norm.weight language_model.output.weight language_model.layers.27.attention_norm.weight language_model.layers.25.attention_norm.weight language_model.layers.28.attention_norm.weight language_model.layers.20.feed_forward.w2.weight language_model.layers.17.attention.wo.weight language_model.layers.1.attention.wq.weight language_model.layers.27.feed_forward.w3.weight language_model.layers.19.feed_forward.w1.weight language_model.layers.14.attention_norm.weight language_model.layers.21.feed_forward.w2.weight language_model.layers.16.attention.wo.weight 
[... traceback output elided: it repeats verbatim the same list of language_model.* weight names already shown in the ValueError message above ...]
language_model.layers.28.attention.wk.weight language_model.layers.9.attention.wk.weight language_model.layers.24.feed_forward.w2.weight language_model.layers.17.attention_norm.weight language_model.layers.17.attention.wv.weight language_model.layers.1.attention.wv.weight language_model.layers.31.ffn_norm.weight language_model.layers.31.attention.wk.weight language_model.layers.24.feed_forward.w1.weight language_model.layers.8.feed_forward.w3.weight language_model.layers.25.attention.wv.weight language_model.layers.7.feed_forward.w1.weight language_model.layers.31.attention.wq.weight language_model.layers.15.feed_forward.w2.weight language_model.layers.30.ffn_norm.weight language_model.layers.0.attention.wq.weight language_model.layers.31.feed_forward.w3.weight language_model.layers.13.feed_forward.w3.weight language_model.layers.19.feed_forward.w3.weight language_model.layers.6.attention_norm.weight language_model.layers.4.attention_norm.weight language_model.layers.12.attention.wv.weight language_model.layers.8.ffn_norm.weight language_model.layers.30.attention.wk.weight language_model.layers.3.attention.wo.weight language_model.layers.16.attention.wq.weight language_model.layers.11.feed_forward.w3.weight language_model.layers.25.feed_forward.w3.weight language_model.layers.3.attention_norm.weight language_model.layers.2.attention.wk.weight language_model.layers.16.attention.wk.weight language_model.layers.7.attention_norm.weight language_model.layers.27.attention.wq.weight language_model.layers.6.feed_forward.w3.weight language_model.layers.24.attention.wv.weight language_model.layers.11.attention.wo.weight language_model.layers.24.attention_norm.weight." - ] - } - ], - "source": [ - "mlx_model.load_weights('mlx_model/weights.npz')\n" - ] - }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "# TODO: load images, and test generate " + "mlx_model.load_weights('mlx_model/weights.npz')\n" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", - "Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00, 2.48s/it]\n" + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ - "# TODO: compare with hf version's model weights as well \n", + "# Now that model weights are loaded in, now we can try and run inference code / set that up.\n", "\n", - "# Load model directly\n", - "from transformers import AutoProcessor, AutoModelForPreTraining\n", - "processor = AutoProcessor.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")\n", - "model = AutoModelForPreTraining.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-0.00692749, -0.0147705, -0.00254822, ..., 0.00500488, 0.00238037, -0.0027771],\n", - " [0.0155029, -0.00343323, 0.00121307, ..., -0.00964355, -0.0110474, 0.00744629],\n", - " [-0.0157471, 0.0144043, 0.000104904, ..., 0.00619507, 0.0189209, -0.00415039],\n", - " ...,\n", - " [1.54972e-06, 0.00866699, 0.000881195, ..., 0.00946045, -0.0301514, 0.0107422],\n", - " [0.0253906, 0.00994873, 0.00454712, ..., -0.0319824, -0.0148926, -0.0130005],\n", - " [-0.0108643, -0.00534058, 0.00102234, ..., 0.0164795, 0.0150146, 
-0.00811768]], dtype=float16)" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mlx_model.language_model.layers[0].attention.wq.weight" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Parameter containing:\n", - "tensor([[-6.9275e-03, -1.4771e-02, -2.5482e-03, ..., 5.0049e-03,\n", - " 2.3804e-03, -2.7771e-03],\n", - " [ 1.5503e-02, -3.4332e-03, 1.2131e-03, ..., -9.6436e-03,\n", - " -1.1047e-02, 7.4463e-03],\n", - " [-1.5747e-02, 1.4404e-02, 1.0490e-04, ..., 6.1951e-03,\n", - " 1.8921e-02, -4.1504e-03],\n", - " ...,\n", - " [ 1.5497e-06, 8.6670e-03, 8.8120e-04, ..., 9.4604e-03,\n", - " -3.0151e-02, 1.0742e-02],\n", - " [ 2.5391e-02, 9.9487e-03, 4.5471e-03, ..., -3.1982e-02,\n", - " -1.4893e-02, -1.3000e-02],\n", - " [-1.0864e-02, -5.3406e-03, 1.0223e-03, ..., 1.6479e-02,\n", - " 1.5015e-02, -8.1177e-03]], requires_grad=True)" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.language_model.model.layers[0].self_attn.q_proj.weight" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# They seem to be the same!" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ + "# load the processor\n", + "from transformers import AutoProcessor\n", "import requests\n", "from PIL import Image\n", + "processor = AutoProcessor.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")\n", "\n", - "image = Image.open(requests.get(\"https://llava-vl.github.io/static/images/view.jpg\", stream=True).raw)" + "prompt = \"\\nUSER: What's the content of the image?\\nASSISTANT:\"\n", + "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n", + "image = Image.open(requests.get(url, stream=True).raw)\n", + "\n", + "inputs = processor(text=prompt, images=image, return_tensors=\"pt\")\n", + "\n" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ - "prompts = ['USER: What are the things I should think aboutwhen I visit this place? 
ASSISTANT:'\n", - " ]" + "\n", + "input_ids = mx.array(inputs[\"input_ids\"].numpy())\n", + "pixel_values = mx.array(inputs[\"pixel_values\"].numpy())\n" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ - "inputs = processor(prompts, images=[image], padding=True, return_tensors=\"pt\")" + "vision_model_output = mlx_model.vision_tower(pixel_values.transpose(0,2,3,1))" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'input_ids': tensor([[ 1, 3148, 1001, 29901, 29871, 32000, 29871, 1724, 526, 278,\n", - " 2712, 306, 881, 1348, 1048, 8256, 306, 6493, 445, 2058,\n", - " 29973, 319, 1799, 9047, 13566, 29901]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1]]), 'pixel_values': tensor([[[[ 1.2734, 1.2734, 1.2734, ..., 1.1274, 1.1274, 1.0982],\n", - " [ 1.2734, 1.2734, 1.2880, ..., 1.1274, 1.1274, 1.1128],\n", - " [ 1.2880, 1.2880, 1.2880, ..., 1.1274, 1.1274, 1.1274],\n", - " ...,\n", - " [-0.9456, -0.9164, -0.9164, ..., -1.0769, -1.0769, -1.0769],\n", - " [-0.9602, -0.9310, -0.9018, ..., -1.0915, -1.0915, -1.0915],\n", - " [-0.9602, -0.9748, -0.2448, ..., -1.1061, -1.1061, -1.1207]],\n", - "\n", - " [[ 1.6397, 1.6397, 1.6397, ..., 1.5196, 1.5196, 1.5196],\n", - " [ 1.6397, 1.6397, 1.6547, ..., 1.5196, 1.5196, 1.5196],\n", - " [ 1.6547, 1.6547, 1.6547, ..., 1.5196, 1.5196, 1.5196],\n", - " ...,\n", - " [-0.5065, -0.5065, -0.5215, ..., -0.6715, -0.6715, -0.6715],\n", - " [-0.5215, -0.5215, -0.5065, ..., -0.6865, -0.6865, -0.6865],\n", - " [-0.5215, -0.5665, 0.1689, ..., -0.7016, -0.7016, -0.7166]],\n", - "\n", - " [[ 1.9610, 1.9610, 1.9610, ..., 1.9042, 1.9042, 1.8899],\n", - " [ 1.9610, 1.9610, 1.9753, ..., 1.9042, 1.9042, 1.8899],\n", - " [ 1.9753, 1.9753, 1.9753, ..., 1.9042, 1.9042, 1.9042],\n", - " ...,\n", - " [-0.1009, -0.0724, -0.0867, ..., -0.2573, -0.2573, -0.2573],\n", - " [-0.1009, -0.1009, -0.0867, ..., -0.2715, -0.2715, -0.2715],\n", - " [-0.1009, -0.1578, 0.5390, ..., -0.2857, -0.2857, -0.3000]]]])}" + "(1, 577, 1024)" ] }, - "execution_count": 18, + "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "\n", - "inputs" - ] + "source": [] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "LlavaConfig {\n", - " \"_name_or_path\": \"/Users/noahkasmanoff/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/05ae2434cbb430be33edcba0c5203e7023f785b7\",\n", - " \"architectures\": [\n", - " \"LlavaForConditionalGeneration\"\n", - " ],\n", - " \"ignore_index\": -100,\n", - " \"image_token_index\": 32000,\n", - " \"model_type\": \"llava\",\n", - " \"pad_token_id\": 32001,\n", - " \"projector_hidden_act\": \"gelu\",\n", - " \"text_config\": {\n", - " \"_name_or_path\": \"lmsys/vicuna-7b-v1.5\",\n", - " \"architectures\": [\n", - " \"LlamaForCausalLM\"\n", - " ],\n", - " \"max_position_embeddings\": 4096,\n", - " \"model_type\": \"llama\",\n", - " \"rms_norm_eps\": 1e-05,\n", - " \"torch_dtype\": \"float16\",\n", - " \"vocab_size\": 32064\n", - " },\n", - " \"tie_word_embeddings\": false,\n", - " \"torch_dtype\": \"float16\",\n", - " \"transformers_version\": \"4.37.2\",\n", - " \"vision_config\": {\n", - " \"hidden_size\": 1024,\n", - " \"image_size\": 336,\n", - " \"intermediate_size\": 4096,\n", - " \"model_type\": 
\"clip_vision_model\",\n", - " \"num_attention_heads\": 16,\n", - " \"num_hidden_layers\": 24,\n", - " \"patch_size\": 14,\n", - " \"projection_dim\": 768,\n", - " \"vocab_size\": 32000\n", - " },\n", - " \"vision_feature_layer\": -2,\n", - " \"vision_feature_select_strategy\": \"default\",\n", - " \"vocab_size\": 32064\n", - "}" + "CLIPVisionOutput(pooler_output=array([[-0.721487, -0.476275, 0.0173661, ..., 0.190072, -1.71528, 1.36224]], dtype=float32), last_hidden_state=array([[[-0.333623, -0.269844, 0.025435, ..., -0.0516554, -0.729696, 0.542679],\n", + " [0.208684, 0.92752, 0.0233985, ..., 1.59934, -0.024813, 0.879629],\n", + " [0.550235, 0.45201, 0.80935, ..., 1.63056, -0.37727, 0.699322],\n", + " ...,\n", + " [0.740987, 0.445616, 0.893172, ..., 0.523529, 0.0230118, -0.457155],\n", + " [0.49297, 0.0680847, 0.79401, ..., 0.476083, 0.274526, -0.284749],\n", + " [-0.0411091, 0.290756, 0.518906, ..., 0.242572, 0.40785, 0.420446]]], dtype=float32))" ] }, - "execution_count": 17, + "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "config" + "vision_model_output" ] }, { diff --git a/llava/llava.py b/llava/llava.py index c661f9a18..a078775f7 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -72,20 +72,25 @@ def __init__(self, config: LlaVAConfig): def __call__(self, input_ids: Optional[mx.array] = None, - pixel_values: Optional[mx.array] = None, - attention_mask: Optional[mx.array] = None,): + pixel_values: Optional[mx.array] = None): # TODO: add the forward pass if pixel_values is not None and input_ids.shape[1] != 1: image_outputs = self.vision_tower(pixel_values) + # TODO: this is not the correct output layer, but it's a placeholder + selected_image_feature = image_outputs.pooler_output + image_features = self.multi_modal_projector( - image_outputs) - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels) + selected_image_feature) def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): # TODO: https://github.com/huggingface/transformers/blob/4f09d0fd888dbf2660313f9715992822acfb99ce/src/transformers/models/llava/modeling_llava.py#L279 + + special_image_token_mask = input_ids == self.config.special_tokens.image + + num_image_tokens = special_image_token_mask.sum() + pass @staticmethod From bb5b89831dda37ac150e59b7f9ec480959361270 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Thu, 22 Feb 2024 07:29:02 -0500 Subject: [PATCH 07/34] adding config --- llava/Local LLava.ipynb | 41 +++++++++++++++++++++-------------------- llava/clip.py | 5 ++++- llava/config.py | 37 +++++++++++++++++++++++++++++++++++++ llava/utils.py | 17 ----------------- 4 files changed, 62 insertions(+), 38 deletions(-) create mode 100644 llava/config.py diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb index 086f634d6..e93ff1be8 100644 --- a/llava/Local LLava.ipynb +++ b/llava/Local LLava.ipynb @@ -101,37 +101,38 @@ "outputs": [], "source": [ "from llava import LlaVAConfig, LLMConfig, VisionConfig, ProjectionConfig, LlavaModel\n", - "\n", + "from config import model_config # based on https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json\n", "llava_mlx_config = LlaVAConfig(\n", " llm_config=LLMConfig(\n", " model_type='vicuna',\n", - " hidden_size=4096,\n", - " num_hidden_layers=32,\n", - " intermediate_size=11008,\n", - " num_attention_heads=32,\n", - " rms_norm_eps=1e-5,\n", - " 
vocab_size=32064,\n", - " num_key_value_heads=32,\n", - " rope_theta=0,\n", - " rope_traditional=False,\n", - " rope_scaling=None\n", + " hidden_size=model_config['language_model']['hidden_size'],\n", + " num_hidden_layers=model_config['language_model']['num_hidden_layers'],\n", + " intermediate_size=model_config['language_model']['intermediate_size'],\n", + " num_attention_heads=model_config['language_model']['num_attention_heads'],\n", + " rms_norm_eps=model_config['language_model']['rms_norm_eps'],\n", + " vocab_size=model_config['language_model']['vocab_size'],\n", + " num_key_value_heads=model_config['language_model']['num_key_value_heads'],\n", + " rope_theta=model_config['language_model']['rope_theta'],\n", + " rope_traditional=model_config['language_model']['rope_traditional'],\n", + " rope_scaling=model_config['language_model']['rope_scaling'],\n", " ),\n", " vision_config=VisionConfig(\n", - " num_hidden_layers=24,\n", - " hidden_size=1024,\n", - " intermediate_size=4096,\n", - " num_attention_heads=16,\n", - " num_channels=3,\n", - " image_size=336,\n", - " patch_size=14\n", + " num_hidden_layers=model_config['vision_tower']['num_hidden_layers'],\n", + " hidden_size=model_config['vision_tower']['hidden_size'],\n", + " intermediate_size=model_config['vision_tower']['intermediate_size'],\n", + " num_attention_heads=model_config['vision_tower']['num_attention_heads'],\n", + " num_channels=model_config['vision_tower']['num_channels'],\n", + " image_size=model_config['vision_tower']['image_size'],\n", + " patch_size=model_config['vision_tower']['patch_size'],\n", " ),\n", " projection_config=ProjectionConfig(\n", - " in_features=1024,\n", - " out_features=4096\n", + " in_features=model_config['multi_modal_projector']['in_features'],\n", + " out_features=model_config['multi_modal_projector']['out_features'],\n", " )\n", ")\n", "\n", "\n", + "\n", "mlx_model = LlavaModel(llava_mlx_config)\n", "\n" ] diff --git a/llava/clip.py b/llava/clip.py index 6c46088d0..1568858a7 100644 --- a/llava/clip.py +++ b/llava/clip.py @@ -16,6 +16,7 @@ class CLIPVisionOutput: pooler_output: mx.array last_hidden_state: mx.array + llava_hidden_state: mx.array @dataclass @@ -185,7 +186,9 @@ def __call__(self, x: mx.array) -> CLIPVisionOutput: # Extract token embedding pooler_output = self.post_layernorm(x[:, 0, :]) - return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x) + + llava_hidden_state = x + return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x, llava_hidden_state=llava_hidden_state) class CLIPModel(nn.Module): diff --git a/llava/config.py b/llava/config.py new file mode 100644 index 000000000..f8617ac14 --- /dev/null +++ b/llava/config.py @@ -0,0 +1,37 @@ +model_config = { + 'language_model': { + 'hidden_size': 4096, + 'num_hidden_layers': 32, + 'intermediate_size': 11008, + 'num_attention_heads': 32, + 'rms_norm_eps': 1e-5, + 'vocab_size': 32000, + 'num_key_value_heads': 32, + 'rope_theta': 0, + 'rope_traditional': False, + 'rope_scaling': None}, + + 'vision_tower': { + 'num_hidden_layers': 24, + 'hidden_size': 1024, + 'intermediate_size': 4096, + 'num_attention_heads': 16, + 'num_channels': 3, + 'image_size': 336, + 'patch_size': 14 + }, + + 'multi_modal_projector': { + 'in_features': 1024, + 'out_features': 4096 + }, + + 'vision_feature_layer': -2, + 'vision_feature_selection_strategy': 'default', + 'image_token_index': 32000, + 'pad_token_id': 32001, + 'tie_word_embeddings': False, + 'vocab_size': 32064, # TODO: confirm this value + + +} diff --git 
a/llava/utils.py b/llava/utils.py index 6f0bf7251..62c5d0ef1 100644 --- a/llava/utils.py +++ b/llava/utils.py @@ -51,23 +51,6 @@ def map_vision_tower_weights(key: str, value: torch.Tensor) -> Tuple[str, torch. def map_language_model_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: - # key = key.replace('language_model.model.', 'language_model.') - # key = key.replace('mlp.', 'feed_forward.') - # key = key.replace("down_proj", "w2") - # key = key.replace("up_proj", "w3") - # key = key.replace("gate_proj", "w1") - # key = key.replace("input_layernorm", "attention_norm") - # key = key.replace("post_attention_layernorm", "ffn_norm") - # key = key.replace("lm_head", "output") - - # key = key.replace("embed_tokens", "tok_embeddings") - # key = key.replace("self_attn", "attention") - - # key = key.replace("q_proj", "wq") - # key = key.replace("k_proj", "wk") - # key = key.replace("v_proj", "wv") - # key = key.replace("o_proj", "wo") - return (key, value) From 95f9df1034b61cbf9245889ae2ad06948870e616 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Thu, 22 Feb 2024 08:16:10 -0500 Subject: [PATCH 08/34] fix: use config --- llava/Local LLava.ipynb | 935 +++++++++++++++++++++++++++++++++++++--- llava/llava.py | 37 +- 2 files changed, 915 insertions(+), 57 deletions(-) diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb index e93ff1be8..69c760103 100644 --- a/llava/Local LLava.ipynb +++ b/llava/Local LLava.ipynb @@ -18,13 +18,6 @@ "import os\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 2, @@ -96,74 +89,910 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from llava import LlaVAConfig, LLMConfig, VisionConfig, ProjectionConfig, LlavaModel\n", - "from config import model_config # based on https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json\n", - "llava_mlx_config = LlaVAConfig(\n", - " llm_config=LLMConfig(\n", - " model_type='vicuna',\n", - " hidden_size=model_config['language_model']['hidden_size'],\n", - " num_hidden_layers=model_config['language_model']['num_hidden_layers'],\n", - " intermediate_size=model_config['language_model']['intermediate_size'],\n", - " num_attention_heads=model_config['language_model']['num_attention_heads'],\n", - " rms_norm_eps=model_config['language_model']['rms_norm_eps'],\n", - " vocab_size=model_config['language_model']['vocab_size'],\n", - " num_key_value_heads=model_config['language_model']['num_key_value_heads'],\n", - " rope_theta=model_config['language_model']['rope_theta'],\n", - " rope_traditional=model_config['language_model']['rope_traditional'],\n", - " rope_scaling=model_config['language_model']['rope_scaling'],\n", - " ),\n", - " vision_config=VisionConfig(\n", - " num_hidden_layers=model_config['vision_tower']['num_hidden_layers'],\n", - " hidden_size=model_config['vision_tower']['hidden_size'],\n", - " intermediate_size=model_config['vision_tower']['intermediate_size'],\n", - " num_attention_heads=model_config['vision_tower']['num_attention_heads'],\n", - " num_channels=model_config['vision_tower']['num_channels'],\n", - " image_size=model_config['vision_tower']['image_size'],\n", - " patch_size=model_config['vision_tower']['patch_size'],\n", - " ),\n", - " projection_config=ProjectionConfig(\n", - " in_features=model_config['multi_modal_projector']['in_features'],\n", - " 
out_features=model_config['multi_modal_projector']['out_features'],\n", - " )\n", - ")\n", - "\n", + "from llava import LlavaModel\n", + "mlx_model = LlavaModel.from_pretrained(path='mlx_model')\n", "\n", "\n", - "mlx_model = LlavaModel(llava_mlx_config)\n", "\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "mlx_model = LlavaModel.from_pretrained(path='mlx_model')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "2.6875" + "LlavaModel(\n", + " (vision_tower): CLIPVisionModel(\n", + " (patch_embedding): Conv2d(3, 1024, kernel_size=(14,), stride=(14, 14), padding=(0, 0), bias=False)\n", + " (pre_layernorm): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (layers.0): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (layers.1): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (layers.2): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (layers.3): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): 
Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (layers.4): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (layers.5): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (layers.6): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (layers.7): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (layers.8): CLIPEncoderLayer(\n", + " (attention): MultiHeadAttention(\n", + " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", + " )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, 
output_dims=4096, bias=True)\n",
 [... repeated CLIPEncoderLayer listings for layers.8 through layers.23 elided: all 24 vision encoder layers share the structure shown above; the printout resumes at the close of layers.23's attention block ...]
" )\n", + " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", + " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", + " (dropout1): Dropout(p=0.0)\n", + " (dropout2): Dropout(p=0.0)\n", + " )\n", + " (post_layernorm): LayerNorm(1024, eps=1e-05, affine=True)\n", + " )\n", + " (language_model): LlamaModel(\n", + " (model): Llama(\n", + " (embed_tokens): Embedding(32064, 4096)\n", + " (layers.0): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.1): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.2): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.3): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): 
RMSNorm()\n", + " )\n", + " (layers.4): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.5): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.6): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.7): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.8): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, 
output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.9): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.10): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.11): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.12): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.13): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", 
+ " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.14): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.15): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.16): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.17): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " 
(post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.18): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.19): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.20): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.21): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.22): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): 
Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.23): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.24): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.25): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.26): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.27): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, 
output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.28): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.29): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.30): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (layers.31): TransformerBlock(\n", + " (self_attn): Attention(\n", + " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", + " (rope): RoPE(128, traditional=False)\n", + " )\n", + " (mlp): MLP(\n", + " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", + " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", + " )\n", + " 
(input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " )\n", + " (norm): RMSNorm()\n", + " )\n", + " (lm_head): Linear(input_dims=4096, output_dims=32064, bias=False)\n", + " )\n", + " (multi_modal_projector): LlavaMultiModalProjector(\n", + " (linear_1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", + " (gelu): GELU()\n", + " (linear_2): Linear(input_dims=4096, output_dims=4096, bias=True)\n", + " )\n", + ")" ] }, - "execution_count": 6, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "11008 / 4096" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "mlx_model.load_weights('mlx_model/weights.npz')\n" + "mlx_model" ] }, { diff --git a/llava/llava.py b/llava/llava.py index a078775f7..34eabb9c6 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -97,10 +97,39 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in def from_pretrained(path: str): path = Path(path) - with open(path / "config.json", "r") as f: - config = json.load(f) - - model = LlavaModel(config) + with open(path / "mlx_config.json", "r") as f: + model_config = json.load(f) + + llava_mlx_config = LlaVAConfig( + llm_config=LLMConfig( + model_type='vicuna', + hidden_size=model_config['language_model']['hidden_size'], + num_hidden_layers=model_config['language_model']['num_hidden_layers'], + intermediate_size=model_config['language_model']['intermediate_size'], + num_attention_heads=model_config['language_model']['num_attention_heads'], + rms_norm_eps=model_config['language_model']['rms_norm_eps'], + vocab_size=model_config['language_model']['vocab_size'], + num_key_value_heads=model_config['language_model']['num_key_value_heads'], + rope_theta=model_config['language_model']['rope_theta'], + rope_traditional=model_config['language_model']['rope_traditional'], + rope_scaling=model_config['language_model']['rope_scaling'], + ), + vision_config=VisionConfig( + num_hidden_layers=model_config['vision_tower']['num_hidden_layers'], + hidden_size=model_config['vision_tower']['hidden_size'], + intermediate_size=model_config['vision_tower']['intermediate_size'], + num_attention_heads=model_config['vision_tower']['num_attention_heads'], + num_channels=model_config['vision_tower']['num_channels'], + image_size=model_config['vision_tower']['image_size'], + patch_size=model_config['vision_tower']['patch_size'], + ), + projection_config=ProjectionConfig( + in_features=model_config['multi_modal_projector']['in_features'], + out_features=model_config['multi_modal_projector']['out_features'], + ) + ) + + model = LlavaModel(llava_mlx_config) model.load_weights(str(path / "weights.npz")) return model From a1c6fe6468e18cd910f0de3a3da6ba59a39cf339 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Thu, 22 Feb 2024 08:16:45 -0500 Subject: [PATCH 09/34] add mlx config --- llava/mlx_model/mlx_config.json | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 llava/mlx_model/mlx_config.json diff --git a/llava/mlx_model/mlx_config.json b/llava/mlx_model/mlx_config.json new file mode 100644 index 000000000..482ec26ca --- /dev/null +++ b/llava/mlx_model/mlx_config.json @@ -0,0 +1,36 @@ +{ + "language_model": { + "hidden_size": 4096, + "num_hidden_layers": 32, + "intermediate_size": 11008, + "num_attention_heads": 32, + "rms_norm_eps": 1e-5, + "vocab_size": 32064, + "num_key_value_heads": 32, + "rope_theta": 0, + "rope_traditional": false, + 
"rope_scaling": null + }, + + "vision_tower": { + "num_hidden_layers": 24, + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_channels": 3, + "image_size": 336, + "patch_size": 14 + }, + + "multi_modal_projector": { + "in_features": 1024, + "out_features": 4096 + }, + + "vision_feature_layer": -2, + "vision_feature_selection_strategy": "default", + "image_token_index": 32000, + "pad_token_id": 32001, + "tie_word_embeddings": false, + "vocab_size": 32064 +} \ No newline at end of file From cec0639edabc889d0c42ccea35178fb5e82c627e Mon Sep 17 00:00:00 2001 From: anchen Date: Fri, 23 Feb 2024 20:43:40 +1100 Subject: [PATCH 10/34] feat: add image processor for llava processor --- llava/.gitignore | 163 +++++++++++++++++++++++++++++++- llava/download.py | 54 +++++++++++ llava/image_processor.py | 93 ++++++++++++++++++ llava/mlx_model/mlx_config.json | 36 ------- llava/processing_llava.py | 23 +++++ llava/test.py | 50 ++++++++++ 6 files changed, 382 insertions(+), 37 deletions(-) create mode 100644 llava/download.py create mode 100644 llava/image_processor.py delete mode 100644 llava/mlx_model/mlx_config.json create mode 100644 llava/processing_llava.py create mode 100644 llava/test.py diff --git a/llava/.gitignore b/llava/.gitignore index 857540df8..bc0a54fe8 100644 --- a/llava/.gitignore +++ b/llava/.gitignore @@ -1 +1,162 @@ -**mlx_model \ No newline at end of file +**mlx_model# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +models \ No newline at end of file diff --git a/llava/download.py b/llava/download.py new file mode 100644 index 000000000..e755896bb --- /dev/null +++ b/llava/download.py @@ -0,0 +1,54 @@ +import argparse +import os + +import requests +from tqdm import tqdm + + +def download_file(url, path): + response = requests.get(url, stream=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 Kbyte + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + + with open(path, "wb") as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + + progress_bar.close() + + +def download_model(model_name, destination_folder="models"): + # Define the base URL and headers for the Hugging Face API + base_url = f"https://huggingface.co/{model_name}/resolve/main" + headers = {"User-Agent": "Hugging Face Python"} + + # Send a GET request to the Hugging Face API to get a list of all files + response = requests.get( + f"https://huggingface.co/api/models/{model_name}", headers=headers + ) + response.raise_for_status() + + # Extract the list of files from the response JSON + files_to_download = [ + file["rfilename"] + for file in response.json()["siblings"] + if not file["rfilename"].endswith(".bin") + ] + + # Ensure the directory exists + os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True) + + # Download each file + for file in files_to_download: + print(f"Downloading {file}...") + download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_name", type=str, help="Name of the model to download.") + args = parser.parse_args() + + download_model(args.model_name) diff --git a/llava/image_processor.py b/llava/image_processor.py new file mode 100644 index 000000000..5f5be8484 --- /dev/null +++ b/llava/image_processor.py @@ -0,0 +1,93 @@ +# Copyright © 2023-2024 Apple Inc. + +import json +from pathlib import Path +from typing import List, Tuple + +import mlx.core as mx +import numpy as np +from PIL.Image import Image + + +class CLIPImageProcessor: + """ + A simple port of + https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py. 
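+    Only the resize, center-crop, rescale and normalize steps used for LLaVA image inputs are implemented here.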
+ """ + + def __init__( + self, + crop_size: int = 336, + do_center_crop: bool = True, + do_normalize: bool = True, + do_resize: bool = True, + image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], + image_std: List[float] = [0.26862954, 0.26130258, 0.27577711], + size: int = 336, + **kwargs + ) -> None: + self.crop_size = crop_size + self.do_center_crop = do_center_crop + self.do_normalize = do_normalize + self.do_resize = do_resize + self.image_mean = mx.array(image_mean) + self.image_std = mx.array(image_std) + self.size = size + + def __call__(self, images: List[Image]) -> mx.array: + return mx.concatenate( + [self._preprocess(image)[None] for image in images], axis=0 + ) + + def _preprocess(self, image: Image) -> mx.array: + if self.do_resize: + image = resize(image, self.size) + if self.do_center_crop: + image = center_crop(image, (self.crop_size, self.crop_size)) + image = mx.array(np.array(image)) + image = rescale(image) + if self.do_normalize: + image = normalize(image, self.image_mean, self.image_std) + return image + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + with open(path / "preprocessor_config.json", encoding="utf-8") as f: + config = json.load(f) + return CLIPImageProcessor(**config) + + +def resize(image: Image, short_size: int) -> Image: + """ + Resize so small size to short_size + """ + width, height = image.size + short = min(width, height) + long = max(width, height) + if short == short_size: + return image + new_short = short_size + new_long = int(short_size * long / short) + new_size = (new_short, new_long) if width <= height else (new_long, new_short) + return image.resize(new_size) + + +def center_crop(image: Image, size: Tuple[int, int]) -> Image: + if size[0] % 2 != 0 or size[1] % 2 != 0: + raise ValueError("Only even crop sizes supported.") + original_width, original_height = image.size + crop_height, crop_width = size + top = (original_height - crop_height) // 2 + bottom = top + crop_height + left = (original_width - crop_width) // 2 + right = left + crop_width + return image.crop((left, top, right, bottom)) + + +def rescale(image: mx.array) -> mx.array: + return image.astype(mx.float32) * (1 / 255.0) + + +def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array: + return (image - mean) / std diff --git a/llava/mlx_model/mlx_config.json b/llava/mlx_model/mlx_config.json deleted file mode 100644 index 482ec26ca..000000000 --- a/llava/mlx_model/mlx_config.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "language_model": { - "hidden_size": 4096, - "num_hidden_layers": 32, - "intermediate_size": 11008, - "num_attention_heads": 32, - "rms_norm_eps": 1e-5, - "vocab_size": 32064, - "num_key_value_heads": 32, - "rope_theta": 0, - "rope_traditional": false, - "rope_scaling": null - }, - - "vision_tower": { - "num_hidden_layers": 24, - "hidden_size": 1024, - "intermediate_size": 4096, - "num_attention_heads": 16, - "num_channels": 3, - "image_size": 336, - "patch_size": 14 - }, - - "multi_modal_projector": { - "in_features": 1024, - "out_features": 4096 - }, - - "vision_feature_layer": -2, - "vision_feature_selection_strategy": "default", - "image_token_index": 32000, - "pad_token_id": 32001, - "tie_word_embeddings": false, - "vocab_size": 32064 -} \ No newline at end of file diff --git a/llava/processing_llava.py b/llava/processing_llava.py new file mode 100644 index 000000000..705d1ccf6 --- /dev/null +++ b/llava/processing_llava.py @@ -0,0 +1,23 @@ +from image_processor import CLIPImageProcessor + + +class LlavaProcessor: 
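+    """Minimal LLaVA processor: only the image-preprocessing path is implemented; the tokenizer is stored but not yet used."""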
+ def __init__(self, image_processor=None, tokenizer=None): + self.image_processor = CLIPImageProcessor() + self.tokenizer = tokenizer + + def __call__( + self, + text=None, + images=None, + padding=False, + truncation=None, + max_length=None, + return_tensors=None, + ): + if images is not None: + pixel_values = self.image_processor(images) + else: + pixel_values = None + + return {"pixel_values": pixel_values} diff --git a/llava/test.py b/llava/test.py new file mode 100644 index 000000000..2ad128f9f --- /dev/null +++ b/llava/test.py @@ -0,0 +1,50 @@ +import unittest + +import mlx.core as mx +import numpy as np +import requests +import torch +from PIL import Image +from processing_llava import LlavaProcessor +from transformers import AutoProcessor, LlavaForConditionalGeneration + +MLX_PATH = "models/llava-hf/llava-1.5-7b-hf" +HF_PATH = "models/llava-hf/llava-1.5-7b-hf" + + +def load_mlx_models(path): + processor = LlavaProcessor() + return processor, None + + +def load_hf_models(path): + processor = AutoProcessor.from_pretrained(path) + model = LlavaForConditionalGeneration.from_pretrained(path) + + return processor, model + + +class TestCLIP(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.mx_proc, cls.mx_llava = load_mlx_models(MLX_PATH) + cls.hf_proc, cls.hf_llava = load_hf_models(HF_PATH) + + def test_processor(self): + prompt = "USER: \nWhat are these?\nASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + + hf_data = mx.array( + np.array( + self.hf_proc(prompt, raw_image, return_tensors="pt")["pixel_values"] + ) + ).transpose(0, 2, 3, 1) + + mx_data = self.mx_proc(prompt, [raw_image])["pixel_values"] + + self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5)) + + +if __name__ == "__main__": + unittest.main() From 4dd8bca0279a6548fe495df0eba34fa36f429c30 Mon Sep 17 00:00:00 2001 From: anchen Date: Sat, 24 Feb 2024 15:19:15 +1100 Subject: [PATCH 11/34] wip --- llava/clip.py | 409 +++++++++++++++++++++---------------------------- llava/llama.py | 58 +++---- llava/llava.py | 134 +++++++--------- llava/test.py | 29 ++-- 4 files changed, 267 insertions(+), 363 deletions(-) diff --git a/llava/clip.py b/llava/clip.py index 1568858a7..736307f31 100644 --- a/llava/clip.py +++ b/llava/clip.py @@ -1,65 +1,39 @@ -# Copyright © 2023-2024 Apple Inc. 
- +import glob +import inspect import json +import logging +import math from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional +from typing import Optional import mlx.core as mx import mlx.nn as nn -from mlx.core import linalg as LA -from mlx.nn.losses import cross_entropy -from mlx.utils import tree_flatten - - -@dataclass -class CLIPVisionOutput: - pooler_output: mx.array - last_hidden_state: mx.array - llava_hidden_state: mx.array - - -@dataclass -class CLIPTextOutput: - pooler_output: mx.array - last_hidden_state: mx.array - - -@dataclass -class CLIPModelOutput: - loss: Optional[mx.array] - text_embeds: Optional[mx.array] - image_embeds: Optional[mx.array] - text_model_output: CLIPTextOutput - vision_model_output: CLIPVisionOutput - - -@dataclass -class CLIPTextConfig: - num_hidden_layers: int - hidden_size: int - intermediate_size: int - num_attention_heads: int - max_position_embeddings: int - vocab_size: int - - -@dataclass -class CLIPVisionConfig: - num_hidden_layers: int - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_channels: int - image_size: int - patch_size: int @dataclass -class CLIPConfig: - text_config: CLIPTextConfig - vision_config: CLIPVisionConfig - projection_dim: int +class VisionConfig: + model_type: str + num_hidden_layers: int = 24 + hidden_size: int = 1024 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + layer_norm_eps: float = 1e-5 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) def quick_gelu(x: mx.array) -> mx.array: @@ -69,227 +43,196 @@ def quick_gelu(x: mx.array) -> mx.array: return x * mx.sigmoid(1.702 * x) -def clip_loss(logits: mx.array) -> mx.array: - N, M = logits.shape - caption_loss = cross_entropy(logits, mx.arange(N), reduction="mean") - image_loss = cross_entropy(logits.T, mx.arange(M), reduction="mean") - return (caption_loss + image_loss) / 2.0 - +class Attention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() -class CLIPEncoderLayer(nn.TransformerEncoderLayer): - """The transformer encoder layer from CLIP.""" + if (dims % num_heads) != 0: + raise ValueError( + "The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0" + ) - def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int): - super().__init__( - dims=hidden_dim, - mlp_dims=intermediate_dim, - num_heads=num_heads, - activation=quick_gelu, - norm_first=True, - ) - # Add biases to the attention projections - self.attention = nn.MultiHeadAttention( - hidden_dim, num_heads, bias=True) + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads + self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) + self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) + self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = 
nn.Linear(value_dims, value_output_dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat) + + +class MLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.activation_fn = quick_gelu + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + def __call__(self, x: mx.array) -> mx.array: + x = self.activation_fn(self.fc1(x)) + x = self.fc2(x) + return x -class CLIPTextModel(nn.Module): - """Implements the text encoder transformer from CLIP.""" - def __init__(self, config: CLIPTextConfig): +class EncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): super().__init__() - - self.token_embedding = nn.Embedding( - config.vocab_size, config.hidden_size) - self.position_embedding = mx.zeros( - (config.max_position_embeddings, config.hidden_size) + self.embed_dim = config.hidden_size + self.self_attn = Attention( + config.hidden_size, config.num_attention_heads, bias=True ) - self.layers = [ - CLIPEncoderLayer( - config.hidden_size, config.intermediate_size, config.num_attention_heads - ) - for _ in range(config.num_hidden_layers) - ] - self.final_layer_norm = nn.LayerNorm(config.hidden_size) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - def _embed(self, x: mx.array) -> mx.array: - embeddings = self.token_embedding(x) - embeddings += self.position_embedding[: x.shape[1]] - return embeddings + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + y = self.layer_norm1(x) + y = self.self_attn(y, y, y, mask) + x = x + y + y = self.layer_norm2(x) + y = self.mlp(y) + return x + y - def __call__(self, x: mx.array) -> CLIPTextOutput: - B, N = x.shape - eot_tokens = mx.argmax(x, axis=-1) - x = self._embed(x) - mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype) - for l in self.layers: - x = l(x, mask) - last_hidden_state = self.final_layer_norm(x) - pooler_output = last_hidden_state[mx.arange(B), eot_tokens] - - return CLIPTextOutput( - pooler_output=pooler_output, last_hidden_state=last_hidden_state - ) +class Encoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] -class CLIPVisionModel(nn.Module): - """Implements the vision encoder transformer from CLIP.""" - def __init__(self, config: CLIPVisionConfig): +class VisionEmbeddings(nn.Module): + def __init__(self, config: VisionConfig): super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size self.class_embedding = mx.zeros((config.hidden_size,)) + self.patch_embedding = 
nn.Conv2d( in_channels=config.num_channels, - out_channels=config.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, bias=False, ) - num_patches = (config.image_size // config.patch_size) ** 2 - num_positions = num_patches + 1 - self.position_embedding = mx.zeros((num_positions, config.hidden_size)) - self.pre_layernorm = nn.LayerNorm(config.hidden_size) - self.layers = [ - CLIPEncoderLayer( - config.hidden_size, config.intermediate_size, config.num_attention_heads - ) - for _ in range(config.num_hidden_layers) - ] - self.post_layernorm = nn.LayerNorm(config.hidden_size) - def _embed(self, x: mx.array) -> mx.array: + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] - # Patchify using conv: - # [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim] patch_embeddings = self.patch_embedding(x) - # [batch_size, num_patches, embed_dim] - patch_embeddings = mx.flatten( - patch_embeddings, start_axis=1, end_axis=2) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) embed_dim = patch_embeddings.shape[-1] - # Prepend embeddings - # [batch_size, 1, embed_dim] cls_embeddings = mx.broadcast_to( self.class_embedding, (batch_size, 1, embed_dim) ) - # [batch_size, num_patches + 1, embed_dim] embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) - # Add positional encoding - embeddings += self.position_embedding + embeddings += self.position_embedding.weight return embeddings - def __call__(self, x: mx.array) -> CLIPVisionOutput: - x = self._embed(x) - x = self.pre_layernorm(x) - - for l in self.layers: - x = l(x, mask=None) - - # Extract token embedding - pooler_output = self.post_layernorm(x[:, 0, :]) - - llava_hidden_state = x - return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x, llava_hidden_state=llava_hidden_state) - -class CLIPModel(nn.Module): - def __init__(self, config: CLIPConfig): - self.text_model = CLIPTextModel(config.text_config) - self.vision_model = CLIPVisionModel(config.vision_config) - - text_embed_dim = config.text_config.hidden_size - vision_embed_dim = config.vision_config.hidden_size - projection_dim = config.projection_dim +class ClipVisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embeddings = VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + self.encoder = Encoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) - self.visual_projection = nn.Linear( - vision_embed_dim, projection_dim, bias=False) - self.text_projection = nn.Linear( - text_embed_dim, projection_dim, bias=False) - self.logit_scale = mx.array(0.0) + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + x = self.embeddings(x) + x = self.pre_layrnorm(x) - def get_text_features(self, x: mx.array) -> mx.array: - return self.text_projection(self.text_model(x).pooler_output) + encoder_states = (x,) if output_hidden_states else None - def get_image_features(self, x: mx.array) -> mx.array: - return self.visual_projection(self.vision_model(x).pooler_output) + for l in self.encoder.layers: + x = l(x, mask=None) + if output_hidden_states: + encoder_states = encoder_states + (x,) - def __call__( - self, 
- input_ids: Optional[mx.array] = None, - pixel_values: Optional[mx.array] = None, - return_loss=False, - ) -> CLIPModelOutput: - if input_ids is not None: - text_model_output = self.text_model(input_ids) - text_embeds = self.text_projection(text_model_output.pooler_output) - text_embeds = text_embeds / \ - LA.norm(text_embeds, axis=-1, keepdims=True) - else: - text_embeds = None - text_model_output = None - - if pixel_values is not None: - vision_model_output = self.vision_model(pixel_values) - image_embeds = self.visual_projection( - vision_model_output.pooler_output) - image_embeds = image_embeds / \ - LA.norm(image_embeds, axis=-1, keepdims=True) - else: - image_embeds = None - vision_model_output = None - - if return_loss and (input_ids is None or pixel_values is None): - raise ValueError( - "Must provide text and image inputs to compute loss.") - - if return_loss: - logit_scale = mx.exp(self.logit_scale) - logits = (text_embeds @ image_embeds.T) * logit_scale - loss = clip_loss(logits) - else: - loss = None - - return CLIPModelOutput( - loss=loss, - text_embeds=text_embeds, - image_embeds=image_embeds, - vision_model_output=vision_model_output, - text_model_output=text_model_output, - ) + pooler_output = self.post_layernorm(x[:, 0, :]) + return pooler_output, x, encoder_states @staticmethod def from_pretrained(path: str): path = Path(path) with open(path / "config.json", "r") as fid: - config = json.load(fid) - - text_config = config["text_config"] - text_config = CLIPTextConfig( - num_hidden_layers=text_config["num_hidden_layers"], - hidden_size=text_config["hidden_size"], - intermediate_size=text_config["intermediate_size"], - num_attention_heads=text_config["num_attention_heads"], - max_position_embeddings=text_config["max_position_embeddings"], - vocab_size=text_config["vocab_size"], - ) + config_dict = json.load(fid) + vision_config = VisionConfig(**config_dict["vision_config"]) - vision_config = config["vision_config"] + model = ClipVisionModel(vision_config) - vision_config = CLIPVisionConfig( - num_hidden_layers=vision_config["num_hidden_layers"], - hidden_size=vision_config["hidden_size"], - intermediate_size=vision_config["intermediate_size"], - num_attention_heads=vision_config["num_attention_heads"], - num_channels=3, - image_size=vision_config["image_size"], - patch_size=vision_config["patch_size"], - ) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + logging.error(f"No safetensors found in {path}") + raise FileNotFoundError(f"No safetensors found in {path}") - config = CLIPConfig( - text_config=text_config, - vision_config=vision_config, - projection_dim=config["projection_dim"], - ) - model = CLIPModel(config) - model.load_weights(str(path / "weights.npz")) + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + weights = model.sanitize(weights) + model.load_weights(list(weights.items())) + model.load_weights(weights) return model + + @staticmethod + def sanitize(weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + # Remove unused position_ids + continue + elif "patch_embedding.weight" in k: + # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, KW] + # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels] + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights diff --git a/llava/llama.py b/llava/llama.py index ba5baeda5..242251a39 100644 --- 
a/llava/llama.py +++ b/llava/llama.py @@ -1,14 +1,25 @@ +import inspect from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -import inspect - @dataclass -class BaseModelArgs: +class TextConfig: + model_type: str + hidden_size: int = 4096 + num_hidden_layers: int = 32 + intermediate_size: int = 11008 + num_attention_heads: int = 32 + rms_norm_eps: float = 1e-6 + vocab_size: int = 32000 + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + @classmethod def from_dict(cls, params): return cls( @@ -19,21 +30,6 @@ def from_dict(cls, params): } ) - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int = None - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - def __post_init__(self): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads @@ -41,12 +37,10 @@ def __post_init__(self): if self.rope_scaling: required_keys = {"factor", "type"} if not all(key in self.rope_scaling for key in required_keys): - raise ValueError( - f"rope_scaling must contain keys {required_keys}") + raise ValueError(f"rope_scaling must contain keys {required_keys}") if self.rope_scaling["type"] != "linear": - raise ValueError( - "rope_scaling 'type' currently only supports 'linear'") + raise ValueError("rope_scaling 'type' currently only supports 'linear'") class RMSNorm(nn.Module): @@ -64,7 +58,7 @@ def __call__(self, x): class Attention(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TextConfig): super().__init__() dim = args.hidden_size @@ -106,8 +100,7 @@ def __call__( # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape( - B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if self.repeats > 1: keys = mx.repeat(keys, self.repeats, axis=1) @@ -126,8 +119,7 @@ def __call__( scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores += mask - scores = mx.softmax(scores.astype(mx.float32), - axis=-1).astype(scores.dtype) + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output), (keys, values) @@ -144,15 +136,14 @@ def __call__(self, x) -> mx.array: class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TextConfig): super().__init__() self.num_attention_heads = args.num_attention_heads self.hidden_size = args.hidden_size self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( - args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.args = args def __call__( @@ -169,7 +160,7 @@ def __call__( class Llama(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: 
TextConfig): super().__init__() self.args = args self.vocab_size = args.vocab_size @@ -190,8 +181,7 @@ def __call__( mask = None if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask( - h.shape[1]) + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) mask = mask.astype(h.dtype) if cache is None: @@ -204,7 +194,7 @@ def __call__( class LlamaModel(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TextConfig): super().__init__() self.model_type = args.model_type self.model = Llama(args) diff --git a/llava/llava.py b/llava/llava.py index 34eabb9c6..4edfbeaca 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -1,56 +1,41 @@ -from clip import CLIPVisionModel -from llama import LlamaModel -from pathlib import Path +import glob +import inspect import json -import mlx.nn as nn -import mlx.core as mx -from typing import Any, Optional, Dict, Union - - +import logging from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional +import mlx.core as mx +import mlx.nn as nn +from llama import LlamaModel, TextConfig -@dataclass -class VisionConfig: - num_hidden_layers: int - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_channels: int - image_size: int - patch_size: int - - -@dataclass -class LLMConfig: - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - -@dataclass -class ProjectionConfig: - in_features: int - out_features: int +from clip import ClipVisionModel, VisionConfig @dataclass class LlaVAConfig: - llm_config: Any + text_config: TextConfig vision_config: VisionConfig - projection_config: ProjectionConfig + ignore_index: int = -100 + image_token_index: int = 32000 + vision_feature_select_strategy: str = "default" + vision_feature_layer: int = -2 + vocab_size: int = 32000 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) class LlavaMultiModalProjector(nn.Module): - def __init__(self, config: Any): + def __init__(self, config: LlaVAConfig): super().__init__() self.linear_1 = nn.Linear(config.in_features, config.out_features) self.gelu = nn.GELU() @@ -65,14 +50,19 @@ def forward(self, x: mx.array) -> mx.array: class LlavaModel(nn.Module): def __init__(self, config: LlaVAConfig): - self.vision_tower = CLIPVisionModel(config=config.vision_config) - self.language_model = LlamaModel(args=config.llm_config) + self.vision_tower = ClipVisionModel( + config=VisionConfig.from_dict(config.vision_config) + ) + self.language_model = LlamaModel(args=TextConfig.from_dict(config.text_config)) self.multi_modal_projector = LlavaMultiModalProjector( - config=config.projection_config) + config=config.projection_config + ) - def __call__(self, - input_ids: Optional[mx.array] = None, - pixel_values: Optional[mx.array] = None): + def __call__( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + ): # TODO: add the forward pass if pixel_values is not None and input_ids.shape[1] != 1: @@ -81,10 +71,11 @@ def __call__(self, # TODO: this is not the correct output layer, but it's a placeholder selected_image_feature = image_outputs.pooler_output - image_features = self.multi_modal_projector( - 
selected_image_feature) + image_features = self.multi_modal_projector(selected_image_feature) - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids, attention_mask, labels + ): # TODO: https://github.com/huggingface/transformers/blob/4f09d0fd888dbf2660313f9715992822acfb99ce/src/transformers/models/llava/modeling_llava.py#L279 special_image_token_mask = input_ids == self.config.special_tokens.image @@ -97,39 +88,20 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in def from_pretrained(path: str): path = Path(path) - with open(path / "mlx_config.json", "r") as f: + with open(path / "config.json", "r") as f: model_config = json.load(f) - llava_mlx_config = LlaVAConfig( - llm_config=LLMConfig( - model_type='vicuna', - hidden_size=model_config['language_model']['hidden_size'], - num_hidden_layers=model_config['language_model']['num_hidden_layers'], - intermediate_size=model_config['language_model']['intermediate_size'], - num_attention_heads=model_config['language_model']['num_attention_heads'], - rms_norm_eps=model_config['language_model']['rms_norm_eps'], - vocab_size=model_config['language_model']['vocab_size'], - num_key_value_heads=model_config['language_model']['num_key_value_heads'], - rope_theta=model_config['language_model']['rope_theta'], - rope_traditional=model_config['language_model']['rope_traditional'], - rope_scaling=model_config['language_model']['rope_scaling'], - ), - vision_config=VisionConfig( - num_hidden_layers=model_config['vision_tower']['num_hidden_layers'], - hidden_size=model_config['vision_tower']['hidden_size'], - intermediate_size=model_config['vision_tower']['intermediate_size'], - num_attention_heads=model_config['vision_tower']['num_attention_heads'], - num_channels=model_config['vision_tower']['num_channels'], - image_size=model_config['vision_tower']['image_size'], - patch_size=model_config['vision_tower']['patch_size'], - ), - projection_config=ProjectionConfig( - in_features=model_config['multi_modal_projector']['in_features'], - out_features=model_config['multi_modal_projector']['out_features'], - ) - ) + model_config = LlaVAConfig.from_dict(model_config) + model = LlavaModel(model_config) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + logging.error(f"No safetensors found in {path}") + raise FileNotFoundError(f"No safetensors found in {path}") - model = LlavaModel(llava_mlx_config) - model.load_weights(str(path / "weights.npz")) + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + weights = ClipVisionModel.sanitize(weights) + model.load_weights(list(weights.items())) return model diff --git a/llava/test.py b/llava/test.py index 2ad128f9f..0d693677e 100644 --- a/llava/test.py +++ b/llava/test.py @@ -3,32 +3,31 @@ import mlx.core as mx import numpy as np import requests -import torch from PIL import Image from processing_llava import LlavaProcessor from transformers import AutoProcessor, LlavaForConditionalGeneration -MLX_PATH = "models/llava-hf/llava-1.5-7b-hf" -HF_PATH = "models/llava-hf/llava-1.5-7b-hf" +from llava import LlavaModel + +MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf" def load_mlx_models(path): - processor = LlavaProcessor() - return processor, None + model = LlavaModel.from_pretrained(path) + return model def load_hf_models(path): - processor = AutoProcessor.from_pretrained(path) model = 
LlavaForConditionalGeneration.from_pretrained(path) - - return processor, model + return model class TestCLIP(unittest.TestCase): @classmethod def setUpClass(cls): - cls.mx_proc, cls.mx_llava = load_mlx_models(MLX_PATH) - cls.hf_proc, cls.hf_llava = load_hf_models(HF_PATH) + cls.mx_llava = load_mlx_models(MODEL_PATH) + cls.hf_llava = load_hf_models(MODEL_PATH) + cls.proc = AutoProcessor.from_pretrained(MODEL_PATH) def test_processor(self): prompt = "USER: \nWhat are these?\nASSISTANT:" @@ -36,12 +35,12 @@ def test_processor(self): raw_image = Image.open(requests.get(image_file, stream=True).raw) hf_data = mx.array( - np.array( - self.hf_proc(prompt, raw_image, return_tensors="pt")["pixel_values"] - ) - ).transpose(0, 2, 3, 1) + self.proc(prompt, raw_image, return_tensors="np")["pixel_values"] + ) - mx_data = self.mx_proc(prompt, [raw_image])["pixel_values"] + mx_data = mx.array( + self.proc(prompt, raw_image, return_tensors="np")["pixel_values"] + ) self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5)) From c4ea94fdec3ee62a38b51f7408def4e129172561 Mon Sep 17 00:00:00 2001 From: anchen Date: Sun, 25 Feb 2024 00:17:13 +1100 Subject: [PATCH 12/34] feat: llava working example --- llava/Local LLava.ipynb | 1121 ------------------------------- llava/config.py | 37 - llava/convert.py | 87 --- llava/download.py | 54 -- llava/generate.py | 58 ++ llava/image_processor.py | 93 --- llava/{llama.py => language.py} | 12 +- llava/llava.py | 112 ++- llava/processing_llava.py | 23 - llava/test.py | 103 ++- llava/utils.py | 70 -- llava/{clip.py => vision.py} | 30 +- 12 files changed, 255 insertions(+), 1545 deletions(-) delete mode 100644 llava/Local LLava.ipynb delete mode 100644 llava/config.py delete mode 100644 llava/convert.py delete mode 100644 llava/download.py create mode 100644 llava/generate.py delete mode 100644 llava/image_processor.py rename llava/{llama.py => language.py} (95%) delete mode 100644 llava/processing_llava.py delete mode 100644 llava/utils.py rename llava/{clip.py => vision.py} (90%) diff --git a/llava/Local LLava.ipynb b/llava/Local LLava.ipynb deleted file mode 100644 index 69c760103..000000000 --- a/llava/Local LLava.ipynb +++ /dev/null @@ -1,1121 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Goal: Download and convert the weights of LlaVA into MLX, and test the forward pass of this model on example data" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import shutil\n", - "from pathlib import Path\n", - "import os\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "mlx_path = Path('mlx_model')\n", - "\n", - "if not os.path.exists(mlx_path):\n", - " os.makedirs(mlx_path)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/noahkasmanoff/anaconda3/envs/mlx/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 214177.23it/s]\n", - "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" - ] - } - ], - "source": [ - "import mlx.core as mx\n", - "from convert import get_model_path, fetch_from_hub, hf_repo\n", - "\n", - "\n", - "model_path = get_model_path(hf_repo)\n", - "model_config, model_weights, model_weight_files, config, tokenizer = fetch_from_hub(model_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[INFO] Converting\n", - "[INFO] Saving\n" - ] - } - ], - "source": [ - "from utils import map_weights, should_keep_weight\n", - "do_convert = True\n", - "if do_convert:\n", - "\n", - " print(\"[INFO] Converting\")\n", - " mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())\n", - " mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}\n", - " print(\"[INFO] Saving\")\n", - " mx.savez(str(mlx_path / \"weights.npz\"), **mlx_weights)\n", - " for fn in [\"config.json\", \"merges.txt\", \"vocab.json\", \"preprocessor_config.json\"]:\n", - " if fn in os.listdir(model_path):\n", - " shutil.copyfile(\n", - " str(model_path / f\"{fn}\"),\n", - " str(mlx_path / f\"{fn}\"),\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from llava import LlavaModel\n", - "mlx_model = LlavaModel.from_pretrained(path='mlx_model')\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "mlx_model = LlavaModel.from_pretrained(path='mlx_model')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "LlavaModel(\n", - " (vision_tower): CLIPVisionModel(\n", - " (patch_embedding): Conv2d(3, 1024, kernel_size=(14,), stride=(14, 14), padding=(0, 0), bias=False)\n", - " (pre_layernorm): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (layers.0): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.1): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " 
(linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.2): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.3): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.4): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.5): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.6): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): 
Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.7): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.8): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.9): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.10): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.11): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, 
eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.12): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.13): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.14): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.15): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.16): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, 
eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.17): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.18): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.19): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.20): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.21): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, 
bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.22): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (layers.23): CLIPEncoderLayer(\n", - " (attention): MultiHeadAttention(\n", - " (query_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (key_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (value_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " (out_proj): Linear(input_dims=1024, output_dims=1024, bias=True)\n", - " )\n", - " (ln1): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (ln2): LayerNorm(1024, eps=1e-05, affine=True)\n", - " (linear1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (linear2): Linear(input_dims=4096, output_dims=1024, bias=True)\n", - " (dropout1): Dropout(p=0.0)\n", - " (dropout2): Dropout(p=0.0)\n", - " )\n", - " (post_layernorm): LayerNorm(1024, eps=1e-05, affine=True)\n", - " )\n", - " (language_model): LlamaModel(\n", - " (model): Llama(\n", - " (embed_tokens): Embedding(32064, 4096)\n", - " (layers.0): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.1): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.2): TransformerBlock(\n", - " 
(self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.3): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.4): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.5): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.6): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): 
Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.7): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.8): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.9): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.10): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.11): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, 
output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.12): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.13): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.14): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.15): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - 
" )\n", - " (layers.16): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.17): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.18): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.19): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.20): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, 
bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.21): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.22): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.23): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.24): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.25): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): 
Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.26): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.27): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.28): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.29): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " 
(post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.30): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (layers.31): TransformerBlock(\n", - " (self_attn): Attention(\n", - " (q_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (k_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (v_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (o_proj): Linear(input_dims=4096, output_dims=4096, bias=False)\n", - " (rope): RoPE(128, traditional=False)\n", - " )\n", - " (mlp): MLP(\n", - " (gate_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " (down_proj): Linear(input_dims=11008, output_dims=4096, bias=False)\n", - " (up_proj): Linear(input_dims=4096, output_dims=11008, bias=False)\n", - " )\n", - " (input_layernorm): RMSNorm()\n", - " (post_attention_layernorm): RMSNorm()\n", - " )\n", - " (norm): RMSNorm()\n", - " )\n", - " (lm_head): Linear(input_dims=4096, output_dims=32064, bias=False)\n", - " )\n", - " (multi_modal_projector): LlavaMultiModalProjector(\n", - " (linear_1): Linear(input_dims=1024, output_dims=4096, bias=True)\n", - " (gelu): GELU()\n", - " (linear_2): Linear(input_dims=4096, output_dims=4096, bias=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mlx_model" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" - ] - } - ], - "source": [ - "# Now that model weights are loaded in, now we can try and run inference code / set that up.\n", - "\n", - "# load the processor\n", - "from transformers import AutoProcessor\n", - "import requests\n", - "from PIL import Image\n", - "processor = AutoProcessor.from_pretrained(\"llava-hf/llava-1.5-7b-hf\")\n", - "\n", - "prompt = \"\\nUSER: What's the content of the image?\\nASSISTANT:\"\n", - "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n", - "image = Image.open(requests.get(url, stream=True).raw)\n", - "\n", - "inputs = processor(text=prompt, images=image, return_tensors=\"pt\")\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "input_ids = mx.array(inputs[\"input_ids\"].numpy())\n", - "pixel_values = mx.array(inputs[\"pixel_values\"].numpy())\n" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [], - "source": [ - "vision_model_output = mlx_model.vision_tower(pixel_values.transpose(0,2,3,1))" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - 
"(1, 577, 1024)" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "CLIPVisionOutput(pooler_output=array([[-0.721487, -0.476275, 0.0173661, ..., 0.190072, -1.71528, 1.36224]], dtype=float32), last_hidden_state=array([[[-0.333623, -0.269844, 0.025435, ..., -0.0516554, -0.729696, 0.542679],\n", - " [0.208684, 0.92752, 0.0233985, ..., 1.59934, -0.024813, 0.879629],\n", - " [0.550235, 0.45201, 0.80935, ..., 1.63056, -0.37727, 0.699322],\n", - " ...,\n", - " [0.740987, 0.445616, 0.893172, ..., 0.523529, 0.0230118, -0.457155],\n", - " [0.49297, 0.0680847, 0.79401, ..., 0.476083, 0.274526, -0.284749],\n", - " [-0.0411091, 0.290756, 0.518906, ..., 0.242572, 0.40785, 0.420446]]], dtype=float32))" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "vision_model_output" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mlx", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/llava/config.py b/llava/config.py deleted file mode 100644 index f8617ac14..000000000 --- a/llava/config.py +++ /dev/null @@ -1,37 +0,0 @@ -model_config = { - 'language_model': { - 'hidden_size': 4096, - 'num_hidden_layers': 32, - 'intermediate_size': 11008, - 'num_attention_heads': 32, - 'rms_norm_eps': 1e-5, - 'vocab_size': 32000, - 'num_key_value_heads': 32, - 'rope_theta': 0, - 'rope_traditional': False, - 'rope_scaling': None}, - - 'vision_tower': { - 'num_hidden_layers': 24, - 'hidden_size': 1024, - 'intermediate_size': 4096, - 'num_attention_heads': 16, - 'num_channels': 3, - 'image_size': 336, - 'patch_size': 14 - }, - - 'multi_modal_projector': { - 'in_features': 1024, - 'out_features': 4096 - }, - - 'vision_feature_layer': -2, - 'vision_feature_selection_strategy': 'default', - 'image_token_index': 32000, - 'pad_token_id': 32001, - 'tie_word_embeddings': False, - 'vocab_size': 32064, # TODO: confirm this value - - -} diff --git a/llava/convert.py b/llava/convert.py deleted file mode 100644 index af9973850..000000000 --- a/llava/convert.py +++ /dev/null @@ -1,87 +0,0 @@ - -from safetensors.torch import load_file -from pathlib import Path -import glob -import json -import logging -import mlx.nn as nn -from huggingface_hub import snapshot_download -from typing import Dict, Tuple -from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer - - -hf_repo = "llava-hf/llava-1.5-7b-hf" - - -def get_model_path(path_or_hf_repo: str) -> Path: - """ - Ensures the model is available locally. If the path does not exist locally, - it is downloaded from the Hugging Face Hub. - - Args: - path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. - - Returns: - Path: The path to the model. 
- """ - model_path = Path(path_or_hf_repo) - if not model_path.exists(): - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - return model_path - - -def load_model(model_path: Path) -> nn.Module: - """ - Load and initialize the model from a given path. - - Args: - model_path (Path): The path to load the model from. - - Returns: - nn.Module: The loaded and initialized model. - - Raises: - FileNotFoundError: If the weight files (.safetensors) are not found. - ValueError: If the model class or args class are not found or cannot be instantiated. - """ - try: - with open(model_path / "config.json", "r") as f: - config = json.load(f) - except FileNotFoundError: - logging.error(f"Config file not found in {model_path}") - raise - - weight_files = glob.glob(str(model_path / "*.safetensors")) - if not weight_files: - logging.error(f"No safetensors found in {model_path}") - raise FileNotFoundError(f"No safetensors found in {model_path}") - - weights = {} - for wf in weight_files: - weights.update(load_file(wf)) - - return config, weights, weight_files - - -def fetch_from_hub( - model_path: Path, -) -> Tuple[Dict, dict, PreTrainedTokenizer]: - model_config, model_weights, model_weight_files = load_model(model_path) - - config = AutoConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained( - model_path) # TODO: should this be the processor? - - # TODO: replace outputs with the model alone once conversion is complete - return model_config, model_weights, model_weight_files, config, tokenizer diff --git a/llava/download.py b/llava/download.py deleted file mode 100644 index e755896bb..000000000 --- a/llava/download.py +++ /dev/null @@ -1,54 +0,0 @@ -import argparse -import os - -import requests -from tqdm import tqdm - - -def download_file(url, path): - response = requests.get(url, stream=True) - total_size_in_bytes = int(response.headers.get("content-length", 0)) - block_size = 1024 # 1 Kbyte - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - - with open(path, "wb") as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - - progress_bar.close() - - -def download_model(model_name, destination_folder="models"): - # Define the base URL and headers for the Hugging Face API - base_url = f"https://huggingface.co/{model_name}/resolve/main" - headers = {"User-Agent": "Hugging Face Python"} - - # Send a GET request to the Hugging Face API to get a list of all files - response = requests.get( - f"https://huggingface.co/api/models/{model_name}", headers=headers - ) - response.raise_for_status() - - # Extract the list of files from the response JSON - files_to_download = [ - file["rfilename"] - for file in response.json()["siblings"] - if not file["rfilename"].endswith(".bin") - ] - - # Ensure the directory exists - os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True) - - # Download each file - for file in files_to_download: - print(f"Downloading {file}...") - download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("model_name", type=str, help="Name of the model to download.") - args = parser.parse_args() - - download_model(args.model_name) diff --git a/llava/generate.py b/llava/generate.py new file mode 100644 index 000000000..a51fe2cc9 --- /dev/null +++ 
b/llava/generate.py @@ -0,0 +1,58 @@ +import mlx.core as mx +import mlx.nn as nn +import requests +from PIL import Image +from transformers import AutoProcessor + +from llava import LlavaModel + +MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf" + +prompt = "USER: \nWhat are these?\nASSISTANT:" +image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" +raw_image = Image.open(requests.get(image_file, stream=True).raw) + + +processor = AutoProcessor.from_pretrained(MODEL_PATH) +model = LlavaModel.from_pretrained(MODEL_PATH) + +values = processor(prompt, raw_image, return_tensors="np") +pixel_values = mx.array(values["pixel_values"]) +input_ids = mx.array(values["input_ids"]) + +input_embeds = model(input_ids, pixel_values) +max_tokens = 100 +temperature = 0.3 + + +def sample(logits, temp=0.0): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + +def generate(y: mx.array, model: nn.Module, temp: float = 0.0, cache=None): + while True: + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + + y = sample(logits, temp=temp) + token = y.item() + + yield token + + +logits, cache = model.language_model(input_ids, cache=None, inputs_embeds=input_embeds) +logits = logits[:, -1, :] +y = sample(logits, temp=temperature) +tokens = [y.item()] +for token, _ in zip( + generate(y, model.language_model, temperature, cache=cache), + range(max_tokens), +): + if token == processor.tokenizer.eos_token_id: + break + tokens.append(token) + +print(processor.tokenizer.decode(tokens)) diff --git a/llava/image_processor.py b/llava/image_processor.py deleted file mode 100644 index 5f5be8484..000000000 --- a/llava/image_processor.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import json -from pathlib import Path -from typing import List, Tuple - -import mlx.core as mx -import numpy as np -from PIL.Image import Image - - -class CLIPImageProcessor: - """ - A simple port of - https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py. 
- """ - - def __init__( - self, - crop_size: int = 336, - do_center_crop: bool = True, - do_normalize: bool = True, - do_resize: bool = True, - image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], - image_std: List[float] = [0.26862954, 0.26130258, 0.27577711], - size: int = 336, - **kwargs - ) -> None: - self.crop_size = crop_size - self.do_center_crop = do_center_crop - self.do_normalize = do_normalize - self.do_resize = do_resize - self.image_mean = mx.array(image_mean) - self.image_std = mx.array(image_std) - self.size = size - - def __call__(self, images: List[Image]) -> mx.array: - return mx.concatenate( - [self._preprocess(image)[None] for image in images], axis=0 - ) - - def _preprocess(self, image: Image) -> mx.array: - if self.do_resize: - image = resize(image, self.size) - if self.do_center_crop: - image = center_crop(image, (self.crop_size, self.crop_size)) - image = mx.array(np.array(image)) - image = rescale(image) - if self.do_normalize: - image = normalize(image, self.image_mean, self.image_std) - return image - - @staticmethod - def from_pretrained(path: str): - path = Path(path) - with open(path / "preprocessor_config.json", encoding="utf-8") as f: - config = json.load(f) - return CLIPImageProcessor(**config) - - -def resize(image: Image, short_size: int) -> Image: - """ - Resize so small size to short_size - """ - width, height = image.size - short = min(width, height) - long = max(width, height) - if short == short_size: - return image - new_short = short_size - new_long = int(short_size * long / short) - new_size = (new_short, new_long) if width <= height else (new_long, new_short) - return image.resize(new_size) - - -def center_crop(image: Image, size: Tuple[int, int]) -> Image: - if size[0] % 2 != 0 or size[1] % 2 != 0: - raise ValueError("Only even crop sizes supported.") - original_width, original_height = image.size - crop_height, crop_width = size - top = (original_height - crop_height) // 2 - bottom = top + crop_height - left = (original_width - crop_width) // 2 - right = left + crop_width - return image.crop((left, top, right, bottom)) - - -def rescale(image: mx.array) -> mx.array: - return image.astype(mx.float32) * (1 / 255.0) - - -def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array: - return (image - mean) / std diff --git a/llava/llama.py b/llava/language.py similarity index 95% rename from llava/llama.py rename to llava/language.py index 242251a39..bb9078155 100644 --- a/llava/llama.py +++ b/llava/language.py @@ -176,8 +176,13 @@ def __call__( self, inputs: mx.array, cache=None, + inputs_embeds=None, ): - h = self.embed_tokens(inputs) + # for passing merged input embeddings + if inputs_embeds is None: + h = self.embed_tokens(inputs) + else: + h = inputs_embeds mask = None if h.shape[1] > 1: @@ -193,7 +198,7 @@ def __call__( return self.norm(h), cache -class LlamaModel(nn.Module): +class LanguageModel(nn.Module): def __init__(self, args: TextConfig): super().__init__() self.model_type = args.model_type @@ -204,8 +209,9 @@ def __call__( self, inputs: mx.array, cache=None, + inputs_embeds=None, ): - out, cache = self.model(inputs, cache) + out, cache = self.model(inputs, cache, inputs_embeds) return self.lm_head(out), cache @staticmethod diff --git a/llava/llava.py b/llava/llava.py index 4edfbeaca..01e76122a 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -8,9 +8,9 @@ import mlx.core as mx import mlx.nn as nn -from llama import LlamaModel, TextConfig - -from clip import ClipVisionModel, VisionConfig +import numpy as np 
+from language import LanguageModel, TextConfig +from vision import VisionConfig, VisionModel @dataclass @@ -37,11 +37,15 @@ def from_dict(cls, params): class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlaVAConfig): super().__init__() - self.linear_1 = nn.Linear(config.in_features, config.out_features) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=True + ) self.gelu = nn.GELU() - self.linear_2 = nn.Linear(config.out_features, config.out_features) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=True + ) - def forward(self, x: mx.array) -> mx.array: + def __call__(self, x: mx.array) -> mx.array: x = self.linear_1(x) x = self.gelu(x) x = self.linear_2(x) @@ -50,39 +54,81 @@ def forward(self, x: mx.array) -> mx.array: class LlavaModel(nn.Module): def __init__(self, config: LlaVAConfig): - self.vision_tower = ClipVisionModel( - config=VisionConfig.from_dict(config.vision_config) - ) - self.language_model = LlamaModel(args=TextConfig.from_dict(config.text_config)) - self.multi_modal_projector = LlavaMultiModalProjector( - config=config.projection_config - ) + self.config = config + self.vision_tower = VisionModel(config.vision_config) + self.language_model = LanguageModel(config.text_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy def __call__( self, input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, ): - # TODO: add the forward pass - - if pixel_values is not None and input_ids.shape[1] != 1: - image_outputs = self.vision_tower(pixel_values) + if pixel_values is None: + return self.language_model(input_ids) - # TODO: this is not the correct output layer, but it's a placeholder - selected_image_feature = image_outputs.pooler_output + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + _, _, hidden_states = self.vision_tower( + pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + ) + selected_image_feature = hidden_states[self.vision_feature_layer] + + if self.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.vision_feature_select_strategy}" + ) + + image_features = self.multi_modal_projector(selected_image_feature) + final_inputs_embeds = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids + ) - image_features = self.multi_modal_projector(selected_image_feature) + return final_inputs_embeds def _merge_input_ids_with_image_features( - self, image_features, inputs_embeds, input_ids, attention_mask, labels + self, image_features, inputs_embeds, input_ids ): - # TODO: https://github.com/huggingface/transformers/blob/4f09d0fd888dbf2660313f9715992822acfb99ce/src/transformers/models/llava/modeling_llava.py#L279 + image_features = np.array(image_features) + inputs_embeds = np.array(inputs_embeds) + input_ids = np.array(input_ids) + + _, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = np.sum(special_image_token_mask, axis=-1) + max_embed_dim = ( + 
np.max(num_special_image_tokens) * (num_image_patches - 1) + ) + sequence_length + + non_image_indices = np.where(input_ids != self.config.image_token_index) + + new_token_positions = ( + np.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), axis=-1) + - 1 + ) + text_to_overwrite = new_token_positions[non_image_indices] + + final_embedding = np.zeros( + (batch_size, max_embed_dim, embed_dim), dtype=inputs_embeds.dtype + ) - special_image_token_mask = input_ids == self.config.special_tokens.image + final_embedding[non_image_indices[0], text_to_overwrite, :] = inputs_embeds[ + non_image_indices + ] - num_image_tokens = special_image_token_mask.sum() + image_to_overwrite = np.all(final_embedding == 0, axis=-1) + reshaped_image_features = image_features.reshape(-1, embed_dim) + final_embedding[image_to_overwrite, :] = reshaped_image_features[ + : np.sum(image_to_overwrite) + ] - pass + return mx.array(final_embedding) @staticmethod def from_pretrained(path: str): @@ -92,6 +138,15 @@ def from_pretrained(path: str): model_config = json.load(f) model_config = LlaVAConfig.from_dict(model_config) + + if isinstance(model_config.vision_config, dict): + model_config.vision_config = VisionConfig.from_dict( + model_config.vision_config + ) + + if isinstance(model_config.text_config, dict): + model_config.text_config = TextConfig.from_dict(model_config.text_config) + model = LlavaModel(model_config) weight_files = glob.glob(str(path / "*.safetensors")) if not weight_files: @@ -102,6 +157,11 @@ def from_pretrained(path: str): for wf in weight_files: weights.update(mx.load(wf)) - weights = ClipVisionModel.sanitize(weights) + if hasattr(VisionModel, "sanitize"): + weights = VisionModel.sanitize(weights) + + if hasattr(VisionModel, "sanitize"): + weights = LanguageModel.sanitize(weights) + model.load_weights(list(weights.items())) return model diff --git a/llava/processing_llava.py b/llava/processing_llava.py deleted file mode 100644 index 705d1ccf6..000000000 --- a/llava/processing_llava.py +++ /dev/null @@ -1,23 +0,0 @@ -from image_processor import CLIPImageProcessor - - -class LlavaProcessor: - def __init__(self, image_processor=None, tokenizer=None): - self.image_processor = CLIPImageProcessor() - self.tokenizer = tokenizer - - def __call__( - self, - text=None, - images=None, - padding=False, - truncation=None, - max_length=None, - return_tensors=None, - ): - if images is not None: - pixel_values = self.image_processor(images) - else: - pixel_values = None - - return {"pixel_values": pixel_values} diff --git a/llava/test.py b/llava/test.py index 0d693677e..4cca0c9af 100644 --- a/llava/test.py +++ b/llava/test.py @@ -3,8 +3,8 @@ import mlx.core as mx import numpy as np import requests +import torch from PIL import Image -from processing_llava import LlavaProcessor from transformers import AutoProcessor, LlavaForConditionalGeneration from llava import LlavaModel @@ -14,11 +14,13 @@ def load_mlx_models(path): model = LlavaModel.from_pretrained(path) + model.eval() return model def load_hf_models(path): model = LlavaForConditionalGeneration.from_pretrained(path) + model.eval() return model @@ -29,20 +31,103 @@ def setUpClass(cls): cls.hf_llava = load_hf_models(MODEL_PATH) cls.proc = AutoProcessor.from_pretrained(MODEL_PATH) - def test_processor(self): + def test_image_features(self): prompt = "USER: \nWhat are these?\nASSISTANT:" image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) + vision_feature_layer = -2 + 
with torch.no_grad(): + pixel_values = self.proc(prompt, raw_image, return_tensors="pt")[ + "pixel_values" + ] - hf_data = mx.array( - self.proc(prompt, raw_image, return_tensors="np")["pixel_values"] - ) + hf_pixel_values = pixel_values + mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1) - mx_data = mx.array( - self.proc(prompt, raw_image, return_tensors="np")["pixel_values"] - ) + _, _, hidden_states = self.mx_llava.vision_tower( + mx_pixel_values, + output_hidden_states=True, + ) - self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5)) + mx_elected_image_feature = hidden_states[vision_feature_layer] + mx_image_features = self.mx_llava.multi_modal_projector( + mx_elected_image_feature + ) + + hf_image_outputs = self.hf_llava.vision_tower( + hf_pixel_values, output_hidden_states=True + ) + hf_elected_image_feature = hf_image_outputs.hidden_states[ + vision_feature_layer + ] + hf_image_features = self.hf_llava.multi_modal_projector( + hf_elected_image_feature + ) + + self.assertTrue( + mx.allclose( + mx_image_features, + mx.array(hf_image_features.numpy()), + atol=1e-2, + ) + ) + + def test_merge_input_ids_with_image_features(self): + prompt = "USER: \nWhat are these?\nASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + vision_feature_layer = -2 + with torch.no_grad(): + values = self.proc(prompt, raw_image, return_tensors="pt") + pixel_values = values["pixel_values"] + input_ids = values["input_ids"] + + hf_pixel_values = pixel_values + mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1) + + _, _, hidden_states = self.mx_llava.vision_tower( + mx_pixel_values, + output_hidden_states=True, + ) + mx_input_ids = mx.array(input_ids.numpy()) + mx_elected_image_feature = hidden_states[vision_feature_layer] + mx_image_features = self.mx_llava.multi_modal_projector( + mx_elected_image_feature + ) + mx_inputs_embeds = self.mx_llava.language_model.model.embed_tokens( + mx_input_ids + ) + mx_final_embedding = self.mx_llava._merge_input_ids_with_image_features( + mx_image_features, mx_inputs_embeds, mx_input_ids + ) + + hf_image_outputs = self.hf_llava.vision_tower( + hf_pixel_values, output_hidden_states=True + ) + hf_elected_image_feature = hf_image_outputs.hidden_states[ + vision_feature_layer + ] + hf_image_features = self.hf_llava.multi_modal_projector( + hf_elected_image_feature + ) + hf_inputs_embeds = self.hf_llava.get_input_embeddings()(input_ids) + hf_final_embedding, _, _, _ = ( + self.hf_llava._merge_input_ids_with_image_features( + hf_image_features, + hf_inputs_embeds, + input_ids, + attention_mask=input_ids, + labels=torch.ones_like(input_ids), + ) + ) + + self.assertTrue( + mx.allclose( + mx_final_embedding, + mx.array(hf_final_embedding.numpy()), + atol=1e-1, + ) + ) if __name__ == "__main__": diff --git a/llava/utils.py b/llava/utils.py deleted file mode 100644 index 62c5d0ef1..000000000 --- a/llava/utils.py +++ /dev/null @@ -1,70 +0,0 @@ -import mlx.core as mx -import torch -from typing import Tuple - - -def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array: - # bfloat16 is not numpy convertible. 
Upcast to float32 to avoid precision loss - a = a.to(torch.float32) if dtype == "bfloat16" else a.to( - getattr(torch, dtype)) - return mx.array(a.numpy(), getattr(mx, dtype)) - - -def should_keep_weight(key: str): - return not ("position_ids" in key) - - -def map_vision_tower_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: - key = key.replace("embeddings.", "") - key = key.replace("encoder.", "") - key = key.replace("position_embedding.weight", "position_embedding") - - key = key.replace('vision_model.', '') - - # Map attention layers - if "self_attn." in key: - key = key.replace("self_attn.", "attention.") - if "q_proj." in key: - key = key.replace("q_proj.", "query_proj.") - if "k_proj." in key: - key = key.replace("k_proj.", "key_proj.") - if "v_proj." in key: - key = key.replace("v_proj.", "value_proj.") - if "layer_norm1." in key: - key = key.replace("layer_norm1.", "ln1.") - if "layer_norm2." in key: - key = key.replace("layer_norm2.", "ln2.") - # Map ffn layers - if "mlp.fc1" in key: - key = key.replace("mlp.fc1", "linear1") - if "mlp.fc2" in key: - key = key.replace("mlp.fc2", "linear2") - # Fix layernorm typo - if "pre_layrnorm" in key: - # Fix typo in weights :) - key = key.replace("pre_layrnorm", "pre_layernorm") - if "patch_embedding.weight" in key: - # Initially, value: [out_channels, in_channels, kH, KW]. - # We want [out_channels, kH, KW, in_channels] - value = value.permute(0, 2, 3, 1) - return (key, value) - - -def map_language_model_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: - return (key, value) - - -def map_multi_modal_projector_weights(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]: - return (key, value) - - -def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]: - - if 'vision_tower' in key: - key, value = map_vision_tower_weights(key, value) - elif 'language_model' in key: - key, value = map_language_model_weights(key, value) - elif 'multi_modal_projector' in key: - key, value = map_multi_modal_projector_weights(key, value) - - return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", ""))) diff --git a/llava/clip.py b/llava/vision.py similarity index 90% rename from llava/clip.py rename to llava/vision.py index 736307f31..ed7f7a46e 100644 --- a/llava/clip.py +++ b/llava/vision.py @@ -4,7 +4,6 @@ import logging import math from dataclasses import dataclass -from pathlib import Path from typing import Optional import mlx.core as mx @@ -197,29 +196,16 @@ def __call__( pooler_output = self.post_layernorm(x[:, 0, :]) return pooler_output, x, encoder_states - @staticmethod - def from_pretrained(path: str): - path = Path(path) - - with open(path / "config.json", "r") as fid: - config_dict = json.load(fid) - vision_config = VisionConfig(**config_dict["vision_config"]) - - model = ClipVisionModel(vision_config) - - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - logging.error(f"No safetensors found in {path}") - raise FileNotFoundError(f"No safetensors found in {path}") - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) +class VisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.vision_model = ClipVisionModel(config) - weights = model.sanitize(weights) - model.load_weights(list(weights.items())) - model.load_weights(weights) - return model + def __call__( + self, x: mx.array, output_hidden_states: Optional[bool] = None + ) -> mx.array: + return self.vision_model(x, output_hidden_states) 
@staticmethod def sanitize(weights): From b9aeadea28929f712213fa9db8327d05c84da42a Mon Sep 17 00:00:00 2001 From: anchen Date: Sun, 25 Feb 2024 00:40:40 +1100 Subject: [PATCH 13/34] chore: refactor generate script --- llava/generate.py | 132 ++++++++++++++++++++++++++++++++++------------ llava/llava.py | 2 +- 2 files changed, 100 insertions(+), 34 deletions(-) diff --git a/llava/generate.py b/llava/generate.py index a51fe2cc9..df92c97f7 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -1,3 +1,6 @@ +import argparse +import os + import mlx.core as mx import mlx.nn as nn import requests @@ -6,53 +9,116 @@ from llava import LlavaModel -MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf" -prompt = "USER: \nWhat are these?\nASSISTANT:" -image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" -raw_image = Image.open(requests.get(image_file, stream=True).raw) +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Generate text from an image using a model." + ) + parser.add_argument( + "--model", + type=str, + default="models/llava-hf/llava-1.5-7b-hf", + help="Path to the model directory.", + ) + parser.add_argument( + "--image", + type=str, + default="http://images.cocodataset.org/val2017/000000039769.jpg", + help="URL or path of the image to process.", + ) + parser.add_argument( + "--prompt", + type=str, + default="USER: \nWhat are these?\nASSISTANT:", + help="Prompt to use for the model.", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum number of tokens to generate.", + ) + parser.add_argument( + "--temperature", type=float, default=0.3, help="Temperature for sampling." + ) + return parser.parse_args() + + +def load_image(image_source): + if image_source.startswith(("http://", "https://")): + try: + response = requests.get(image_source, stream=True) + response.raise_for_status() + return Image.open(response.raw) + except requests.HTTPError as e: + print(f"Failed to load image from URL: {e}") + return None + elif os.path.isfile(image_source): + try: + return Image.open(image_source) + except IOError as e: + print(f"Failed to load image from path: {e}") + return None + else: + print("The image source is neither a valid URL nor a file path.") + return None -processor = AutoProcessor.from_pretrained(MODEL_PATH) -model = LlavaModel.from_pretrained(MODEL_PATH) +def initialize_model(model_path): + processor = AutoProcessor.from_pretrained(model_path) + model = LlavaModel.from_pretrained(model_path) + return processor, model -values = processor(prompt, raw_image, return_tensors="np") -pixel_values = mx.array(values["pixel_values"]) -input_ids = mx.array(values["input_ids"]) -input_embeds = model(input_ids, pixel_values) -max_tokens = 100 -temperature = 0.3 +def prepare_inputs(processor, image, prompt): + inputs = processor(prompt, image, return_tensors="np") + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + return input_ids, pixel_values -def sample(logits, temp=0.0): - if temp == 0: +def sample(logits, temperature=0.0): + if temperature == 0: return mx.argmax(logits, axis=-1) else: - return mx.random.categorical(logits * (1 / temp)) + return mx.random.categorical(logits * (1 / temperature)) -def generate(y: mx.array, model: nn.Module, temp: float = 0.0, cache=None): - while True: - logits, cache = model(y[None], cache=cache) - logits = logits[:, -1, :] +def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): + input_embeds = 
model.get_input_embeddings(input_ids, pixel_values) + logits, cache = model.language_model( + input_ids, cache=None, inputs_embeds=input_embeds + ) + logits = logits[:, -1, :] + y = sample(logits, temperature=temperature) + tokens = [y.item()] - y = sample(logits, temp=temp) + for _ in range(max_tokens): + logits, cache = model.language_model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits, temperature) token = y.item() + if token == processor.tokenizer.eos_token_id: + break + tokens.append(token) + + return processor.tokenizer.decode(tokens) + - yield token +def main(): + args = parse_arguments() + raw_image = load_image(args.image) + if raw_image is None: + return + processor, model = initialize_model(args.model) + input_ids, pixel_values = prepare_inputs(processor, raw_image, args.prompt) + print(args.prompt) + generated_text = generate_text( + input_ids, pixel_values, model, processor, args.max_tokens, args.temperature + ) + print(generated_text) -logits, cache = model.language_model(input_ids, cache=None, inputs_embeds=input_embeds) -logits = logits[:, -1, :] -y = sample(logits, temp=temperature) -tokens = [y.item()] -for token, _ in zip( - generate(y, model.language_model, temperature, cache=cache), - range(max_tokens), -): - if token == processor.tokenizer.eos_token_id: - break - tokens.append(token) -print(processor.tokenizer.decode(tokens)) +if __name__ == "__main__": + main() diff --git a/llava/llava.py b/llava/llava.py index 01e76122a..4af5783d7 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -61,7 +61,7 @@ def __init__(self, config: LlaVAConfig): self.vision_feature_layer = config.vision_feature_layer self.vision_feature_select_strategy = config.vision_feature_select_strategy - def __call__( + def get_input_embeddings( self, input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, From d8f7b895e1bf197cbfe965e209d5b559da593f37 Mon Sep 17 00:00:00 2001 From: anchen Date: Sun, 25 Feb 2024 01:00:56 +1100 Subject: [PATCH 14/34] chore: clean up --- llava/.gitignore | 162 ---------------------------------------------- llava/generate.py | 14 ++-- llava/test.py | 20 +++--- llava/utils.py | 31 +++++++++ 4 files changed, 49 insertions(+), 178 deletions(-) delete mode 100644 llava/.gitignore create mode 100644 llava/utils.py diff --git a/llava/.gitignore b/llava/.gitignore deleted file mode 100644 index bc0a54fe8..000000000 --- a/llava/.gitignore +++ /dev/null @@ -1,162 +0,0 @@ -**mlx_model# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. 
-*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
-#.idea/ - -models \ No newline at end of file diff --git a/llava/generate.py b/llava/generate.py index df92c97f7..c52645818 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -6,6 +6,7 @@ import requests from PIL import Image from transformers import AutoProcessor +from utils import get_model_path from llava import LlavaModel @@ -17,8 +18,8 @@ def parse_arguments(): parser.add_argument( "--model", type=str, - default="models/llava-hf/llava-1.5-7b-hf", - help="Path to the model directory.", + default="llava-hf/llava-1.5-7b-hf", + help="The path to the local model directory or Hugging Face repo.", ) parser.add_argument( "--image", @@ -30,7 +31,7 @@ def parse_arguments(): "--prompt", type=str, default="USER: \nWhat are these?\nASSISTANT:", - help="Prompt to use for the model.", + help="Message to be processed by the model.", ) parser.add_argument( "--max-tokens", @@ -39,7 +40,7 @@ def parse_arguments(): help="Maximum number of tokens to generate.", ) parser.add_argument( - "--temperature", type=float, default=0.3, help="Temperature for sampling." + "--temp", type=float, default=0.3, help="Temperature for sampling." ) return parser.parse_args() @@ -66,7 +67,8 @@ def load_image(image_source): def initialize_model(model_path): processor = AutoProcessor.from_pretrained(model_path) - model = LlavaModel.from_pretrained(model_path) + + model = LlavaModel.from_pretrained(get_model_path(model_path)) return processor, model @@ -115,7 +117,7 @@ def main(): input_ids, pixel_values = prepare_inputs(processor, raw_image, args.prompt) print(args.prompt) generated_text = generate_text( - input_ids, pixel_values, model, processor, args.max_tokens, args.temperature + input_ids, pixel_values, model, processor, args.max_tokens, args.temp ) print(generated_text) diff --git a/llava/test.py b/llava/test.py index 4cca0c9af..64652324f 100644 --- a/llava/test.py +++ b/llava/test.py @@ -6,14 +6,18 @@ import torch from PIL import Image from transformers import AutoProcessor, LlavaForConditionalGeneration +from utils import get_model_path from llava import LlavaModel -MODEL_PATH = "models/llava-hf/llava-1.5-7b-hf" +MODEL_PATH = "llava-hf/llava-1.5-7b-hf" +PROMPT = "USER: \nWhat are these?\nASSISTANT:" +IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg" def load_mlx_models(path): - model = LlavaModel.from_pretrained(path) + model_path = get_model_path(path) + model = LlavaModel.from_pretrained(model_path) model.eval() return model @@ -32,12 +36,10 @@ def setUpClass(cls): cls.proc = AutoProcessor.from_pretrained(MODEL_PATH) def test_image_features(self): - prompt = "USER: \nWhat are these?\nASSISTANT:" - image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) + raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw) vision_feature_layer = -2 with torch.no_grad(): - pixel_values = self.proc(prompt, raw_image, return_tensors="pt")[ + pixel_values = self.proc(PROMPT, raw_image, return_tensors="pt")[ "pixel_values" ] @@ -73,12 +75,10 @@ def test_image_features(self): ) def test_merge_input_ids_with_image_features(self): - prompt = "USER: \nWhat are these?\nASSISTANT:" - image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) + raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw) vision_feature_layer = -2 with torch.no_grad(): - values = self.proc(prompt, raw_image, return_tensors="pt") + values = self.proc(PROMPT, 
raw_image, return_tensors="pt") pixel_values = values["pixel_values"] input_ids = values["input_ids"] diff --git a/llava/utils.py b/llava/utils.py new file mode 100644 index 000000000..0514b12d5 --- /dev/null +++ b/llava/utils.py @@ -0,0 +1,31 @@ +from pathlib import Path + +from huggingface_hub import snapshot_download + + +def get_model_path(path_or_hf_repo: str) -> Path: + """ + Ensures the model is available locally. If the path does not exist locally, + it is downloaded from the Hugging Face Hub. + + Args: + path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. + + Returns: + Path: The path to the model. + """ + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], + ) + ) + return model_path From 371a8071c8e6f4c1156b8e87fed386737bbaae41 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Sat, 24 Feb 2024 14:24:23 -0500 Subject: [PATCH 15/34] add: warning to user if no token despite using one --- llava/llava.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/llava/llava.py b/llava/llava.py index 4af5783d7..0166e9f50 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -102,14 +102,22 @@ def _merge_input_ids_with_image_features( batch_size, sequence_length = input_ids.shape special_image_token_mask = input_ids == self.config.image_token_index num_special_image_tokens = np.sum(special_image_token_mask, axis=-1) + # if no special image tokens found, return a warning + if np.all(num_special_image_tokens == 0): + logging.warning( + "No special image tokens found in the input. Please make sure to include in your prompt." 
+ ) + max_embed_dim = ( np.max(num_special_image_tokens) * (num_image_patches - 1) ) + sequence_length - non_image_indices = np.where(input_ids != self.config.image_token_index) + non_image_indices = np.where( + input_ids != self.config.image_token_index) new_token_positions = ( - np.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), axis=-1) + np.cumsum((special_image_token_mask * + (num_image_patches - 1) + 1), axis=-1) - 1 ) text_to_overwrite = new_token_positions[non_image_indices] @@ -145,7 +153,8 @@ def from_pretrained(path: str): ) if isinstance(model_config.text_config, dict): - model_config.text_config = TextConfig.from_dict(model_config.text_config) + model_config.text_config = TextConfig.from_dict( + model_config.text_config) model = LlavaModel(model_config) weight_files = glob.glob(str(path / "*.safetensors")) From 449f7d0e126fc9b4b137275e08299099629148a5 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Sat, 24 Feb 2024 14:35:40 -0500 Subject: [PATCH 16/34] add: __call__ to LlavaModel --- llava/llava.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/llava/llava.py b/llava/llava.py index 0166e9f50..aaa3987eb 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -138,6 +138,14 @@ def _merge_input_ids_with_image_features( return mx.array(final_embedding) + def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): + input_ids, pixel_values = self.get_input_embeddings( + input_ids, pixel_values) + logits, cache = self.language_model( + input_ids, cache=cache, inputs_embeds=input_ids + ) + return logits, cache + @staticmethod def from_pretrained(path: str): path = Path(path) From a1cab2b6f2a5935e6dcb3596bde9f5982e103706 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Sat, 24 Feb 2024 14:36:00 -0500 Subject: [PATCH 17/34] add: call to LlavaModel --- llava/.gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 llava/.gitignore diff --git a/llava/.gitignore b/llava/.gitignore new file mode 100644 index 000000000..ab6d270d8 --- /dev/null +++ b/llava/.gitignore @@ -0,0 +1 @@ +**.ipynb \ No newline at end of file From 8e6b2f5f435eb1d88e12e8c47fb55635d8c90103 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 26 Feb 2024 08:02:20 -0500 Subject: [PATCH 18/34] update fp --- llava/llava.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/llava/llava.py b/llava/llava.py index aaa3987eb..48c839a68 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -69,13 +69,17 @@ def get_input_embeddings( if pixel_values is None: return self.language_model(input_ids) + # Get the input embeddings from the language model inputs_embeds = self.language_model.model.embed_tokens(input_ids) + # get the ouptut hidden states from the vision model _, _, hidden_states = self.vision_tower( pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True ) + # select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] if self.vision_feature_select_strategy == "default": + # default strategy is to select all the hidden states except the first one (CLS token?) 
selected_image_feature = selected_image_feature[:, 1:] elif self.vision_feature_select_strategy == "full": selected_image_feature = selected_image_feature @@ -83,8 +87,9 @@ def get_input_embeddings( raise ValueError( f"Unexpected select feature strategy: {self.vision_feature_select_strategy}" ) - + # pass image features through the multi-modal projector image_features = self.multi_modal_projector(selected_image_feature) + # insert special image tokens in the input_ids final_inputs_embeds = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids ) @@ -94,21 +99,25 @@ def get_input_embeddings( def _merge_input_ids_with_image_features( self, image_features, inputs_embeds, input_ids ): + image_features = np.array(image_features) inputs_embeds = np.array(inputs_embeds) input_ids = np.array(input_ids) _, num_image_patches, embed_dim = image_features.shape batch_size, sequence_length = input_ids.shape + special_image_token_mask = input_ids == self.config.image_token_index num_special_image_tokens = np.sum(special_image_token_mask, axis=-1) + # if no special image tokens found, return a warning if np.all(num_special_image_tokens == 0): logging.warning( "No special image tokens found in the input. Please make sure to include in your prompt." ) - max_embed_dim = ( + # calculate the final sequence length. Will be the original sequence length + the # of image tokens to be inserted in. + final_sequence_length = ( np.max(num_special_image_tokens) * (num_image_patches - 1) ) + sequence_length @@ -123,7 +132,8 @@ def _merge_input_ids_with_image_features( text_to_overwrite = new_token_positions[non_image_indices] final_embedding = np.zeros( - (batch_size, max_embed_dim, embed_dim), dtype=inputs_embeds.dtype + (batch_size, final_sequence_length, + embed_dim), dtype=inputs_embeds.dtype ) final_embedding[non_image_indices[0], text_to_overwrite, :] = inputs_embeds[ @@ -139,10 +149,10 @@ def _merge_input_ids_with_image_features( return mx.array(final_embedding) def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): - input_ids, pixel_values = self.get_input_embeddings( + input_embddings = self.get_input_embeddings( input_ids, pixel_values) logits, cache = self.language_model( - input_ids, cache=cache, inputs_embeds=input_ids + input_ids, cache=cache, inputs_embeds=input_embddings ) return logits, cache From 823411cd316d313240b092ea6a0a64270689e821 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 26 Feb 2024 08:14:54 -0500 Subject: [PATCH 19/34] clean up var names --- llava/language.py | 78 +++++++++++++++++++++++++++-------------------- llava/llava.py | 9 +++++- llava/vision.py | 20 +++++++++--- 3 files changed, 68 insertions(+), 39 deletions(-) diff --git a/llava/language.py b/llava/language.py index bb9078155..56369163d 100644 --- a/llava/language.py +++ b/llava/language.py @@ -37,10 +37,12 @@ def __post_init__(self): if self.rope_scaling: required_keys = {"factor", "type"} if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") + raise ValueError( + f"rope_scaling must contain keys {required_keys}") if self.rope_scaling["type"] != "linear": - raise ValueError("rope_scaling 'type' currently only supports 'linear'") + raise ValueError( + "rope_scaling 'type' currently only supports 'linear'") class RMSNorm(nn.Module): @@ -58,16 +60,16 @@ def __call__(self, x): class Attention(nn.Module): - def __init__(self, args: TextConfig): + def __init__(self, config: TextConfig): 
super().__init__() - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads + dim = config.hidden_size + self.n_heads = n_heads = config.num_attention_heads + self.n_kv_heads = n_kv_heads = config.num_key_value_heads self.repeats = n_heads // n_kv_heads - head_dim = args.hidden_size // n_heads + head_dim = config.hidden_size // n_heads self.scale = head_dim**-0.5 self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) @@ -76,14 +78,14 @@ def __init__(self, args: TextConfig): self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" + 1 / config.rope_scaling["factor"] + if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1 ) self.rope = nn.RoPE( head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, + traditional=config.rope_traditional, + base=config.rope_theta, scale=rope_scale, ) @@ -100,7 +102,8 @@ def __call__( # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape( + B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if self.repeats > 1: keys = mx.repeat(keys, self.repeats, axis=1) @@ -119,7 +122,8 @@ def __call__( scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + scores = mx.softmax(scores.astype(mx.float32), + axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output), (keys, values) @@ -136,15 +140,17 @@ def __call__(self, x) -> mx.array: class TransformerBlock(nn.Module): - def __init__(self, args: TextConfig): + def __init__(self, config: TextConfig): super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.args = args + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.self_attn = Attention(config) + self.mlp = MLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + self.config = config def __call__( self, @@ -160,17 +166,17 @@ def __call__( class Llama(nn.Module): - def __init__(self, args: TextConfig): + def __init__(self, config: TextConfig): super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers + self.config = config + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + TransformerBlock(config=config) for 
_ in range(config.num_hidden_layers) ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def __call__( self, @@ -186,7 +192,8 @@ def __call__( mask = None if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = nn.MultiHeadAttention.create_additive_causal_mask( + h.shape[1]) mask = mask.astype(h.dtype) if cache is None: @@ -199,11 +206,16 @@ def __call__( class LanguageModel(nn.Module): - def __init__(self, args: TextConfig): + def __init__(self, config: TextConfig): super().__init__() - self.model_type = args.model_type - self.model = Llama(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.model_type = config.model_type + if self.model_type != "llama": + raise ValueError( + f"Model type {self.model_type} not supported. Currently only 'llama' is supported" + ) + self.model = Llama(config) + self.lm_head = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) def __call__( self, diff --git a/llava/llava.py b/llava/llava.py index 48c839a68..b17be3ce3 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -34,13 +34,20 @@ def from_dict(cls, params): ) +def quick_gelu(x: mx.array) -> mx.array: + """ + A fast GELU approximation https://github.com/hendrycks/GELUs + """ + return x * mx.sigmoid(1.702 * x) + + class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlaVAConfig): super().__init__() self.linear_1 = nn.Linear( config.vision_config.hidden_size, config.text_config.hidden_size, bias=True ) - self.gelu = nn.GELU() + self.gelu = quick_gelu self.linear_2 = nn.Linear( config.text_config.hidden_size, config.text_config.hidden_size, bias=True ) diff --git a/llava/vision.py b/llava/vision.py index ed7f7a46e..339bc3994 100644 --- a/llava/vision.py +++ b/llava/vision.py @@ -116,9 +116,11 @@ def __init__(self, config: VisionConfig): self.self_attn = Attention( config.hidden_size, config.num_attention_heads, bias=True ) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm( + self.embed_dim, eps=config.layer_norm_eps) self.mlp = MLP(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm( + self.embed_dim, eps=config.layer_norm_eps) def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: y = self.layer_norm1(x) @@ -132,7 +134,8 @@ def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: class Encoder(nn.Module): def __init__(self, config: VisionConfig): super().__init__() - self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] + self.layers = [EncoderLayer(config) + for _ in range(config.num_hidden_layers)] class VisionEmbeddings(nn.Module): @@ -155,12 +158,14 @@ def __init__(self, config: VisionConfig): self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.position_embedding = nn.Embedding( + self.num_positions, self.embed_dim) def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] patch_embeddings = self.patch_embedding(x) - patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) + patch_embeddings = mx.flatten( + patch_embeddings, start_axis=1, end_axis=2) embed_dim = patch_embeddings.shape[-1] cls_embeddings = mx.broadcast_to( self.class_embedding, (batch_size, 1, 
embed_dim) @@ -200,6 +205,11 @@ def __call__( class VisionModel(nn.Module): def __init__(self, config: VisionConfig): super().__init__() + + self.model_type = config.model_type + if self.model_type != "clip_vision_model": + raise ValueError(f"Unsupported model type: {self.model_type}") + self.vision_model = ClipVisionModel(config) def __call__( From 6bc06c8168e25fb645b2ce5d33490da1466de165 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 26 Feb 2024 08:26:26 -0500 Subject: [PATCH 20/34] update: native GeLU --- llava/llava.py | 11 ++--------- llava/vision.py | 9 +-------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/llava/llava.py b/llava/llava.py index b17be3ce3..a6e7c6f16 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional +from typing import Optional import mlx.core as mx import mlx.nn as nn @@ -34,20 +34,13 @@ def from_dict(cls, params): ) -def quick_gelu(x: mx.array) -> mx.array: - """ - A fast GELU approximation https://github.com/hendrycks/GELUs - """ - return x * mx.sigmoid(1.702 * x) - - class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlaVAConfig): super().__init__() self.linear_1 = nn.Linear( config.vision_config.hidden_size, config.text_config.hidden_size, bias=True ) - self.gelu = quick_gelu + self.gelu = nn.GELU(approx='fast') self.linear_2 = nn.Linear( config.text_config.hidden_size, config.text_config.hidden_size, bias=True ) diff --git a/llava/vision.py b/llava/vision.py index 339bc3994..a61d99645 100644 --- a/llava/vision.py +++ b/llava/vision.py @@ -35,13 +35,6 @@ def from_dict(cls, params): ) -def quick_gelu(x: mx.array) -> mx.array: - """ - A fast GELU approximation https://github.com/hendrycks/GELUs - """ - return x * mx.sigmoid(1.702 * x) - - class Attention(nn.Module): def __init__( self, @@ -99,7 +92,7 @@ def __call__(self, queries, keys, values, mask=None): class MLP(nn.Module): def __init__(self, config: VisionConfig): super().__init__() - self.activation_fn = quick_gelu + self.activation_fn = nn.GELU(approx='fast') self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) From feec5ec30c8370ce634b3bfbf7cc343095e09f42 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Tue, 27 Feb 2024 19:16:21 -0500 Subject: [PATCH 21/34] Cleanup --- llava/generate.py | 7 +++---- llava/llava.py | 3 ++- llava/test.py | 38 ++++++++++++++++++++++++++++++++++---- llava/utils.py | 1 - 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/llava/generate.py b/llava/generate.py index c52645818..36170151c 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -2,7 +2,6 @@ import os import mlx.core as mx -import mlx.nn as nn import requests from PIL import Image from transformers import AutoProcessor @@ -87,9 +86,9 @@ def sample(logits, temperature=0.0): def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): - input_embeds = model.get_input_embeddings(input_ids, pixel_values) - logits, cache = model.language_model( - input_ids, cache=None, inputs_embeds=input_embeds + + logits, cache = model( + input_ids, pixel_values ) logits = logits[:, -1, :] y = sample(logits, temperature=temperature) diff --git a/llava/llava.py b/llava/llava.py index a6e7c6f16..49821239c 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -40,7 +40,8 @@ def __init__(self, config: LlaVAConfig): self.linear_1 = nn.Linear( 
config.vision_config.hidden_size, config.text_config.hidden_size, bias=True ) - self.gelu = nn.GELU(approx='fast') + + self.gelu = nn.GELU() self.linear_2 = nn.Linear( config.text_config.hidden_size, config.text_config.hidden_size, bias=True ) diff --git a/llava/test.py b/llava/test.py index 64652324f..b2911507c 100644 --- a/llava/test.py +++ b/llava/test.py @@ -1,7 +1,6 @@ import unittest import mlx.core as mx -import numpy as np import requests import torch from PIL import Image @@ -28,7 +27,7 @@ def load_hf_models(path): return model -class TestCLIP(unittest.TestCase): +class TestVisionTower(unittest.TestCase): @classmethod def setUpClass(cls): cls.mx_llava = load_mlx_models(MODEL_PATH) @@ -44,7 +43,8 @@ def test_image_features(self): ] hf_pixel_values = pixel_values - mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1) + mx_pixel_values = mx.array( + pixel_values.numpy()).transpose(0, 2, 3, 1) _, _, hidden_states = self.mx_llava.vision_tower( mx_pixel_values, @@ -83,7 +83,8 @@ def test_merge_input_ids_with_image_features(self): input_ids = values["input_ids"] hf_pixel_values = pixel_values - mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1) + mx_pixel_values = mx.array( + pixel_values.numpy()).transpose(0, 2, 3, 1) _, _, hidden_states = self.mx_llava.vision_tower( mx_pixel_values, @@ -130,5 +131,34 @@ def test_merge_input_ids_with_image_features(self): ) +class TestLlaVa(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.mx_llava = load_mlx_models(MODEL_PATH) + cls.hf_llava = load_hf_models(MODEL_PATH) + cls.proc = AutoProcessor.from_pretrained(MODEL_PATH) + + def test_generated_token(self): + raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw) + with torch.no_grad(): + hf_inputs = self.proc(PROMPT, raw_image, return_tensors="pt") + hf_outputs = self.hf_llava(**hf_inputs) + hf_logits = hf_outputs.logits + + mx_inputs = self.proc(PROMPT, raw_image, return_tensors="np") + pixel_values = mx.array(mx_inputs["pixel_values"]) + input_ids = mx.array(mx_inputs["input_ids"]) + + mx_logits, _ = self.mx_llava(input_ids, pixel_values) + + self.assertTrue( + mx.allclose( + mx_logits[:, -1, :].argmax(axis=-1), + mx.array(hf_logits.numpy())[:, -1, :].argmax(axis=-1), + atol=1e-2, + ) + ) + + if __name__ == "__main__": unittest.main() diff --git a/llava/utils.py b/llava/utils.py index 0514b12d5..af8b47ac8 100644 --- a/llava/utils.py +++ b/llava/utils.py @@ -1,5 +1,4 @@ from pathlib import Path - from huggingface_hub import snapshot_download From d76fd40bb3739c580e282b97d7cea4444b972f01 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Tue, 27 Feb 2024 19:44:38 -0500 Subject: [PATCH 22/34] update generate and readme --- llava/README.md | 40 ++++++++++++++++++++++++++++++++++++++++ llava/generate.py | 33 ++------------------------------- llava/llava.py | 2 ++ llava/utils.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 31 deletions(-) diff --git a/llava/README.md b/llava/README.md index c10506421..2751c3deb 100644 --- a/llava/README.md +++ b/llava/README.md @@ -1 +1,41 @@ # LLaVA + +An example of LLaVA: Large Language and Vission Assistant in MLX. LLlava is a multi-modal model that can generate text from images and text prompts. [^1] + +## Setup: + +Install the dependencies: + +```bash +pip install -r requirements.txt +``` + +## Run + +You can use LlaVA model to ask questions about images. + +The python snippet below shows how to use the model to ask questions about an image. 
+ +```python +from llava import LlavaModel +from transformers import AutoProcessor +from utils import load_image, prepare_inputs +from generate import generate_text +model_path = 'llava-hf/llava-1.5-7b-hf' + +processor = AutoProcessor.from_pretrained(model_path) +model = LlavaModel.from_pretrained(model_path) + +max_tokens, temperature = 128, 0. + +prompt = "USER: \nWhat are these?\nASSISTANT:" +image = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = load_image(image) +input_ids, pixel_values = prepare_inputs(prompt, image, processor) + +reply = generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature) + +print(reply) +``` + +[^1]: Please refer to original LlaVA library for more details: [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA) diff --git a/llava/generate.py b/llava/generate.py index 36170151c..2264115ee 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -1,11 +1,9 @@ import argparse -import os + import mlx.core as mx -import requests -from PIL import Image from transformers import AutoProcessor -from utils import get_model_path +from utils import get_model_path, load_image, prepare_inputs from llava import LlavaModel @@ -44,26 +42,6 @@ def parse_arguments(): return parser.parse_args() -def load_image(image_source): - if image_source.startswith(("http://", "https://")): - try: - response = requests.get(image_source, stream=True) - response.raise_for_status() - return Image.open(response.raw) - except requests.HTTPError as e: - print(f"Failed to load image from URL: {e}") - return None - elif os.path.isfile(image_source): - try: - return Image.open(image_source) - except IOError as e: - print(f"Failed to load image from path: {e}") - return None - else: - print("The image source is neither a valid URL nor a file path.") - return None - - def initialize_model(model_path): processor = AutoProcessor.from_pretrained(model_path) @@ -71,13 +49,6 @@ def initialize_model(model_path): return processor, model -def prepare_inputs(processor, image, prompt): - inputs = processor(prompt, image, return_tensors="np") - pixel_values = mx.array(inputs["pixel_values"]) - input_ids = mx.array(inputs["input_ids"]) - return input_ids, pixel_values - - def sample(logits, temperature=0.0): if temperature == 0: return mx.argmax(logits, axis=-1) diff --git a/llava/llava.py b/llava/llava.py index 49821239c..9594666fa 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -11,6 +11,7 @@ import numpy as np from language import LanguageModel, TextConfig from vision import VisionConfig, VisionModel +from utils import get_model_path @dataclass @@ -159,6 +160,7 @@ def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): @staticmethod def from_pretrained(path: str): + path = get_model_path(path) path = Path(path) with open(path / "config.json", "r") as f: diff --git a/llava/utils.py b/llava/utils.py index af8b47ac8..cc4180806 100644 --- a/llava/utils.py +++ b/llava/utils.py @@ -1,5 +1,9 @@ from pathlib import Path from huggingface_hub import snapshot_download +import os +import requests +from PIL import Image +import mlx.core as mx def get_model_path(path_or_hf_repo: str) -> Path: @@ -28,3 +32,30 @@ def get_model_path(path_or_hf_repo: str) -> Path: ) ) return model_path + + +def load_image(image_source): + if image_source.startswith(("http://", "https://")): + try: + response = requests.get(image_source, stream=True) + response.raise_for_status() + return Image.open(response.raw) + except requests.HTTPError as e: + 
print(f"Failed to load image from URL: {e}") + return None + elif os.path.isfile(image_source): + try: + return Image.open(image_source) + except IOError as e: + print(f"Failed to load image from path: {e}") + return None + else: + print("The image source is neither a valid URL nor a file path.") + return None + + +def prepare_inputs(processor, image, prompt): + inputs = processor(prompt, image, return_tensors="np") + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + return input_ids, pixel_values From 49f928a23fd7eb7c30d37ad22c26bd52e32b6fab Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Tue, 27 Feb 2024 19:46:22 -0500 Subject: [PATCH 23/34] remove todo comment --- llava/llava.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llava/llava.py b/llava/llava.py index 9594666fa..1679e50a3 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -81,7 +81,6 @@ def get_input_embeddings( selected_image_feature = hidden_states[self.vision_feature_layer] if self.vision_feature_select_strategy == "default": - # default strategy is to select all the hidden states except the first one (CLS token?) selected_image_feature = selected_image_feature[:, 1:] elif self.vision_feature_select_strategy == "full": selected_image_feature = selected_image_feature From c2b8463d569336025c3013a5cdefad47b4c957b0 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Tue, 27 Feb 2024 19:55:21 -0500 Subject: [PATCH 24/34] rearrange tests --- llava/test.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llava/test.py b/llava/test.py index b2911507c..2f0d8be1a 100644 --- a/llava/test.py +++ b/llava/test.py @@ -74,6 +74,14 @@ def test_image_features(self): ) ) + +class TestLlaVA(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.mx_llava = load_mlx_models(MODEL_PATH) + cls.hf_llava = load_hf_models(MODEL_PATH) + cls.proc = AutoProcessor.from_pretrained(MODEL_PATH) + def test_merge_input_ids_with_image_features(self): raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw) vision_feature_layer = -2 @@ -130,15 +138,7 @@ def test_merge_input_ids_with_image_features(self): ) ) - -class TestLlaVa(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.mx_llava = load_mlx_models(MODEL_PATH) - cls.hf_llava = load_hf_models(MODEL_PATH) - cls.proc = AutoProcessor.from_pretrained(MODEL_PATH) - - def test_generated_token(self): + def test_generated_tokens(self): raw_image = Image.open(requests.get(IMAGE_FILE, stream=True).raw) with torch.no_grad(): hf_inputs = self.proc(PROMPT, raw_image, return_tensors="pt") From 25a65cf2261df9c7e52317fb2ad5ec69ee0b5db0 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Tue, 27 Feb 2024 19:57:33 -0500 Subject: [PATCH 25/34] fix example code --- llava/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llava/README.md b/llava/README.md index 2751c3deb..a44e88b82 100644 --- a/llava/README.md +++ b/llava/README.md @@ -31,7 +31,7 @@ max_tokens, temperature = 128, 0. 
prompt = "USER: \nWhat are these?\nASSISTANT:" image = "http://images.cocodataset.org/val2017/000000039769.jpg" image = load_image(image) -input_ids, pixel_values = prepare_inputs(prompt, image, processor) +input_ids, pixel_values = prepare_inputs(processor, image, prompt) reply = generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature) From c2c94112389883e65f7437d9bdf1cd62add2ce4f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 27 Feb 2024 21:19:35 -0800 Subject: [PATCH 26/34] nits in README --- llava/README.md | 10 +++++----- llava/requirements.txt | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llava/README.md b/llava/README.md index a44e88b82..abe7d3a4c 100644 --- a/llava/README.md +++ b/llava/README.md @@ -1,6 +1,7 @@ # LLaVA -An example of LLaVA: Large Language and Vission Assistant in MLX. LLlava is a multi-modal model that can generate text from images and text prompts. [^1] +An example of LLaVA: Large Language and Vission Assistant in MLX. LLlava is a +multi-modal model that can generate text from images and text prompts.[^1] ## Setup: @@ -12,9 +13,7 @@ pip install -r requirements.txt ## Run -You can use LlaVA model to ask questions about images. - -The python snippet below shows how to use the model to ask questions about an image. +You can use LlaVA model to ask questions about images. For example: ```python from llava import LlavaModel @@ -38,4 +37,5 @@ reply = generate_text(input_ids, pixel_values, model, processor, max_tokens, tem print(reply) ``` -[^1]: Please refer to original LlaVA library for more details: [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA) +[^1]: Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more + information. diff --git a/llava/requirements.txt b/llava/requirements.txt index a6f2f1ca7..74f826ea6 100644 --- a/llava/requirements.txt +++ b/llava/requirements.txt @@ -3,4 +3,4 @@ numpy transformers torch huggingface_hub -Pillow \ No newline at end of file +Pillow From 8301c43749113a7aa8c3898fddd3b5d6769064e2 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Wed, 28 Feb 2024 08:30:12 -0500 Subject: [PATCH 27/34] update readme --- llava/README.md | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/llava/README.md b/llava/README.md index abe7d3a4c..4cb0a99d8 100644 --- a/llava/README.md +++ b/llava/README.md @@ -13,7 +13,15 @@ pip install -r requirements.txt ## Run -You can use LlaVA model to ask questions about images. For example: +You can use LlaVA model to ask questions about images. + +For example using the command line: + +```bash +python generate.py --model_path llava-hf/llava-1.5-7b-hf --image "http://images.cocodataset.org/val2017/000000039769.jpg" --prompt "USER: \nWhat are these?\nASSISTANT:" --max_tokens 128 --temperature 0 +``` + +Or directly in Python: ```python from llava import LlavaModel @@ -37,5 +45,6 @@ reply = generate_text(input_ids, pixel_values, model, processor, max_tokens, tem print(reply) ``` -[^1]: Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more - information. +[^1]: + Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more + information. 
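Both the command-line and Python paths above end in the same decoding loop: the model is evaluated once on the merged image and text embeddings, and then tokens are drawn one at a time. For reference, a minimal temperature-based sampler in MLX can be sketched as follows. The greedy branch mirrors the `sample` helper in `generate.py`; the categorical branch is an illustrative assumption about how a non-zero temperature would be handled, not a copy of the repository code.

```python
import mlx.core as mx


def sample(logits: mx.array, temperature: float = 0.0) -> mx.array:
    # temperature == 0 reduces to greedy decoding: take the arg-max token.
    if temperature == 0:
        return mx.argmax(logits, axis=-1)
    # Otherwise scale the logits and draw from the resulting categorical
    # distribution (equivalent to sampling from softmax(logits / temperature)).
    return mx.random.categorical(logits * (1 / temperature))
```

With a temperature of 0, as in the CLI example, the generated reply is deterministic for a given prompt and image.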
From 5c8f67d1204263283bb4bcf57fea14d8fff2106a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 28 Feb 2024 11:13:10 -0800 Subject: [PATCH 28/34] nit in readme --- llava/README.md | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/llava/README.md b/llava/README.md index 4cb0a99d8..a36735da4 100644 --- a/llava/README.md +++ b/llava/README.md @@ -13,15 +13,34 @@ pip install -r requirements.txt ## Run -You can use LlaVA model to ask questions about images. +You can use LLaVA to ask questions about images. -For example using the command line: +For example, using the command line: ```bash -python generate.py --model_path llava-hf/llava-1.5-7b-hf --image "http://images.cocodataset.org/val2017/000000039769.jpg" --prompt "USER: \nWhat are these?\nASSISTANT:" --max_tokens 128 --temperature 0 +python generate.py \ + --model llava-hf/llava-1.5-7b-hf \ + --image "http://images.cocodataset.org/val2017/000000039769.jpg" \ + --prompt "USER: \nWhat are these?\nASSISTANT:" \ + --max-tokens 128 ``` -Or directly in Python: +This uses the following image: + +![alt text](http://images.cocodataset.org/val2017/000000039769.jpg) + +And generates output similar to: + +```shell +These are two cats, one of which is sleeping and the other one is awake. + +The sleeping cat is lying on a couch, while the awake cat is also on the couch, +positioned near the sleeping cat. The couch appears to be red, and there is a +remote control placed nearby. The cats are comfortably resting on the couch, +enjoying each other's company. +``` + +You can also use LLaVA in Python: ```python from llava import LlavaModel @@ -45,6 +64,13 @@ reply = generate_text(input_ids, pixel_values, model, processor, max_tokens, tem print(reply) ``` + + +The model output: + +``` +``` + [^1]: Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more information. From cd77bcf4c41bd52a878914c0fa578e7cb53b083d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 28 Feb 2024 11:13:27 -0800 Subject: [PATCH 29/34] nits in README --- llava/README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/llava/README.md b/llava/README.md index a36735da4..143803e44 100644 --- a/llava/README.md +++ b/llava/README.md @@ -64,13 +64,6 @@ reply = generate_text(input_ids, pixel_values, model, processor, max_tokens, tem print(reply) ``` - - -The model output: - -``` -``` - [^1]: Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more information. 
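The prompt template used in these examples places an `<image>` placeholder token between `USER:` and the question; at generation time that single token is replaced by the projected CLIP patch features. One way to see where the placeholder lands is to tokenize the prompt on its own, as in the sketch below. This assumes the checkpoint registers `<image>` as a special token, which the `llava-hf` processors do; the exact id comes from the tokenizer rather than being hard-coded here.

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
input_ids = processor.tokenizer(prompt)["input_ids"]

# There should be exactly one <image> entry; its id corresponds to
# config.image_token_index in llava.py.
image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")
print(image_token_id, input_ids.count(image_token_id))
```

The embedding-merge logic in `llava.py` relies on this invariant: one placeholder token per image passed to the model, otherwise it raises a `ValueError`.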
From b39c251ba19af4ffdf97550450b73628f22b55ad Mon Sep 17 00:00:00 2001 From: anchen Date: Thu, 29 Feb 2024 08:59:14 +1100 Subject: [PATCH 30/34] chore(llava): refactor image embedding merging logic --- llava/llava.py | 80 +++++++++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 44 deletions(-) diff --git a/llava/llava.py b/llava/llava.py index 1679e50a3..0bff0f8a0 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -10,8 +10,8 @@ import mlx.nn as nn import numpy as np from language import LanguageModel, TextConfig -from vision import VisionConfig, VisionModel from utils import get_model_path +from vision import VisionConfig, VisionModel @dataclass @@ -100,58 +100,51 @@ def get_input_embeddings( def _merge_input_ids_with_image_features( self, image_features, inputs_embeds, input_ids ): + image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + + # Positions of tokens in `input_ids` and assume batch size is 1 + image_positions = [ + idx + for idx, token_id in enumerate(input_ids[0]) + if token_id == image_token_index + ] - image_features = np.array(image_features) - inputs_embeds = np.array(inputs_embeds) - input_ids = np.array(input_ids) - - _, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = np.sum(special_image_token_mask, axis=-1) - - # if no special image tokens found, return a warning - if np.all(num_special_image_tokens == 0): - logging.warning( - "No special image tokens found in the input. Please make sure to include in your prompt." + if len(image_positions) != num_images: + raise ValueError( + f"The input provided to the model is incorrect. The number of image tokens is {len(image_positions)}, but the number of images given to the model is {num_images}." ) - # calculate the final sequence length. Will be the original sequence length + the # of image tokens to be inserted in. - final_sequence_length = ( - np.max(num_special_image_tokens) * (num_image_patches - 1) - ) + sequence_length + text_segments = [] + start_idx = 0 - non_image_indices = np.where( - input_ids != self.config.image_token_index) + for position in image_positions: + text_segments.append(inputs_embeds[:, start_idx:position]) + start_idx = position + 1 - new_token_positions = ( - np.cumsum((special_image_token_mask * - (num_image_patches - 1) + 1), axis=-1) - - 1 - ) - text_to_overwrite = new_token_positions[non_image_indices] + if start_idx < inputs_embeds.shape[1]: + text_segments.append(inputs_embeds[:, start_idx:]) - final_embedding = np.zeros( - (batch_size, final_sequence_length, - embed_dim), dtype=inputs_embeds.dtype - ) + # Reshape image feature from (num_images, num_image_patches, embed_dim) to (num_images*num_image_patches, embed_dim) + image_embeddings = image_features.reshape(-1, image_features.shape[-1]) - final_embedding[non_image_indices[0], text_to_overwrite, :] = inputs_embeds[ - non_image_indices - ] + final_embeddings = [] + for i, text_segment in enumerate(text_segments): + final_embeddings.append(text_segment[0]) + if i < len(image_positions): + # Add a slice of image embeddings corresponding to the current position. + # This effectively replaces one token with its associated num_image_patches embeddings. 
+ final_embeddings.append(image_embeddings[i : i + num_image_patches + 1]) - image_to_overwrite = np.all(final_embedding == 0, axis=-1) - reshaped_image_features = image_features.reshape(-1, embed_dim) - final_embedding[image_to_overwrite, :] = reshaped_image_features[ - : np.sum(image_to_overwrite) - ] + # This creates a final embeding in shape (1, num_image_patches*num_images + sequence_len, embed_dim) representing the merged sequence of text and image embeddings. + final_embeddings = mx.concatenate(final_embeddings, axis=0).reshape( + 1, -1, embed_dim + ) - return mx.array(final_embedding) + return final_embeddings def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): - input_embddings = self.get_input_embeddings( - input_ids, pixel_values) + input_embddings = self.get_input_embeddings(input_ids, pixel_values) logits, cache = self.language_model( input_ids, cache=cache, inputs_embeds=input_embddings ) @@ -173,8 +166,7 @@ def from_pretrained(path: str): ) if isinstance(model_config.text_config, dict): - model_config.text_config = TextConfig.from_dict( - model_config.text_config) + model_config.text_config = TextConfig.from_dict(model_config.text_config) model = LlavaModel(model_config) weight_files = glob.glob(str(path / "*.safetensors")) From 935ebb574b151a986ccdf2da10d64ae06e029d6c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 29 Feb 2024 20:22:34 -0800 Subject: [PATCH 31/34] min mlx version --- llava/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llava/requirements.txt b/llava/requirements.txt index 74f826ea6..a11d91482 100644 --- a/llava/requirements.txt +++ b/llava/requirements.txt @@ -1,4 +1,4 @@ -mlx +mlx>=0.5.0 numpy transformers torch From 683b7c4dfdcfa0908f0601262b4f32095db48ef2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Mar 2024 07:16:18 -0800 Subject: [PATCH 32/34] nits in readmes --- README.md | 1 + llava/README.md | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 15c9cca1c..b3404038d 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Some more useful examples are listed below. ### Multimodal models - Joint text and image embeddings with [CLIP](clip). +- Text generation from image and text inputs with [LLaVA](llava). ### Other Models diff --git a/llava/README.md b/llava/README.md index 143803e44..6ce78a977 100644 --- a/llava/README.md +++ b/llava/README.md @@ -1,9 +1,9 @@ # LLaVA -An example of LLaVA: Large Language and Vission Assistant in MLX. LLlava is a -multi-modal model that can generate text from images and text prompts.[^1] +An example of LLaVA: Large Language and Vision Assistant in MLX.[^1] LLlava is +a multimodal model that can generate text given combined image and text inputs. -## Setup: +## Setup Install the dependencies: @@ -31,7 +31,7 @@ This uses the following image: And generates output similar to: -```shell +``` These are two cats, one of which is sleeping and the other one is awake. 
The sleeping cat is lying on a couch, while the awake cat is also on the couch, From b37891d17e2727dd6ab94cecb5870434a63c407a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Mar 2024 08:47:20 -0800 Subject: [PATCH 33/34] fix cli prompt, some nits --- llava/README.md | 26 ++++++++++---------------- llava/generate.py | 25 +++++++++++++------------ llava/language.py | 33 +++++++++++++-------------------- llava/test.py | 8 ++++---- llava/utils.py | 28 ++++++++++++++++++---------- llava/vision.py | 28 ++++++++++++---------------- 6 files changed, 70 insertions(+), 78 deletions(-) diff --git a/llava/README.md b/llava/README.md index 6ce78a977..873994db4 100644 --- a/llava/README.md +++ b/llava/README.md @@ -22,44 +22,38 @@ python generate.py \ --model llava-hf/llava-1.5-7b-hf \ --image "http://images.cocodataset.org/val2017/000000039769.jpg" \ --prompt "USER: \nWhat are these?\nASSISTANT:" \ - --max-tokens 128 + --max-tokens 128 \ + --temp 0 ``` This uses the following image: ![alt text](http://images.cocodataset.org/val2017/000000039769.jpg) -And generates output similar to: +And generates the output: ``` -These are two cats, one of which is sleeping and the other one is awake. - -The sleeping cat is lying on a couch, while the awake cat is also on the couch, -positioned near the sleeping cat. The couch appears to be red, and there is a -remote control placed nearby. The cats are comfortably resting on the couch, -enjoying each other's company. +These are two cats lying on a pink couch. ``` You can also use LLaVA in Python: ```python -from llava import LlavaModel -from transformers import AutoProcessor from utils import load_image, prepare_inputs -from generate import generate_text -model_path = 'llava-hf/llava-1.5-7b-hf' +from generate import load_model, generate_text -processor = AutoProcessor.from_pretrained(model_path) -model = LlavaModel.from_pretrained(model_path) +processor, model = load_model("llava-hf/llava-1.5-7b-hf") -max_tokens, temperature = 128, 0. +max_tokens, temperature = 128, 0.0 prompt = "USER: \nWhat are these?\nASSISTANT:" image = "http://images.cocodataset.org/val2017/000000039769.jpg" image = load_image(image) input_ids, pixel_values = prepare_inputs(processor, image, prompt) -reply = generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature) +reply = generate_text( + input_ids, pixel_values, model, processor, max_tokens, temperature +) print(reply) ``` diff --git a/llava/generate.py b/llava/generate.py index 2264115ee..40e2a9b9f 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -1,5 +1,7 @@ -import argparse +# Copyright © 2024 Apple Inc. 
+import argparse +import codecs import mlx.core as mx from transformers import AutoProcessor @@ -42,7 +44,7 @@ def parse_arguments(): return parser.parse_args() -def initialize_model(model_path): +def load_model(model_path): processor = AutoProcessor.from_pretrained(model_path) model = LlavaModel.from_pretrained(get_model_path(model_path)) @@ -58,14 +60,12 @@ def sample(logits, temperature=0.0): def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): - logits, cache = model( - input_ids, pixel_values - ) + logits, cache = model(input_ids, pixel_values) logits = logits[:, -1, :] y = sample(logits, temperature=temperature) tokens = [y.item()] - for _ in range(max_tokens): + for n in range(max_tokens - 1): logits, cache = model.language_model(y[None], cache=cache) logits = logits[:, -1, :] y = sample(logits, temperature) @@ -79,13 +79,14 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera def main(): args = parse_arguments() - raw_image = load_image(args.image) - if raw_image is None: - return + image = load_image(args.image) + processor, model = load_model(args.model) + + prompt = codecs.decode(args.prompt, "unicode_escape") + + input_ids, pixel_values = prepare_inputs(processor, image, prompt) - processor, model = initialize_model(args.model) - input_ids, pixel_values = prepare_inputs(processor, raw_image, args.prompt) - print(args.prompt) + print(prompt) generated_text = generate_text( input_ids, pixel_values, model, processor, args.max_tokens, args.temp ) diff --git a/llava/language.py b/llava/language.py index 56369163d..e9023b99c 100644 --- a/llava/language.py +++ b/llava/language.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + import inspect from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union @@ -37,12 +39,10 @@ def __post_init__(self): if self.rope_scaling: required_keys = {"factor", "type"} if not all(key in self.rope_scaling for key in required_keys): - raise ValueError( - f"rope_scaling must contain keys {required_keys}") + raise ValueError(f"rope_scaling must contain keys {required_keys}") if self.rope_scaling["type"] != "linear": - raise ValueError( - "rope_scaling 'type' currently only supports 'linear'") + raise ValueError("rope_scaling 'type' currently only supports 'linear'") class RMSNorm(nn.Module): @@ -79,7 +79,8 @@ def __init__(self, config: TextConfig): rope_scale = ( 1 / config.rope_scaling["factor"] - if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" + if config.rope_scaling is not None + and config.rope_scaling["type"] == "linear" else 1 ) self.rope = nn.RoPE( @@ -102,8 +103,7 @@ def __call__( # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape( - B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if self.repeats > 1: keys = mx.repeat(keys, self.repeats, axis=1) @@ -122,8 +122,7 @@ def __call__( scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores += mask - scores = mx.softmax(scores.astype(mx.float32), - axis=-1).astype(scores.dtype) + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output), (keys, values) @@ -146,10 +145,10 @@ def 
__init__(self, config: TextConfig): self.hidden_size = config.hidden_size self.self_attn = Attention(config) self.mlp = MLP(config.hidden_size, config.intermediate_size) - self.input_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps) + config.hidden_size, eps=config.rms_norm_eps + ) self.config = config def __call__( @@ -192,8 +191,7 @@ def __call__( mask = None if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask( - h.shape[1]) + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) mask = mask.astype(h.dtype) if cache is None: @@ -214,8 +212,7 @@ def __init__(self, config: TextConfig): f"Model type {self.model_type} not supported. Currently only 'llama' is supported" ) self.model = Llama(config) - self.lm_head = nn.Linear( - config.hidden_size, config.vocab_size, bias=False) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__( self, @@ -232,7 +229,3 @@ def sanitize(weights): return { k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k } - - @property - def layers(self): - return self.model.layers diff --git a/llava/test.py b/llava/test.py index 2f0d8be1a..278ada200 100644 --- a/llava/test.py +++ b/llava/test.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + import unittest import mlx.core as mx @@ -43,8 +45,7 @@ def test_image_features(self): ] hf_pixel_values = pixel_values - mx_pixel_values = mx.array( - pixel_values.numpy()).transpose(0, 2, 3, 1) + mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1) _, _, hidden_states = self.mx_llava.vision_tower( mx_pixel_values, @@ -91,8 +92,7 @@ def test_merge_input_ids_with_image_features(self): input_ids = values["input_ids"] hf_pixel_values = pixel_values - mx_pixel_values = mx.array( - pixel_values.numpy()).transpose(0, 2, 3, 1) + mx_pixel_values = mx.array(pixel_values.numpy()).transpose(0, 2, 3, 1) _, _, hidden_states = self.mx_llava.vision_tower( mx_pixel_values, diff --git a/llava/utils.py b/llava/utils.py index cc4180806..265aa11f1 100644 --- a/llava/utils.py +++ b/llava/utils.py @@ -1,9 +1,12 @@ -from pathlib import Path -from huggingface_hub import snapshot_download +# Copyright © 2024 Apple Inc. + import os +from pathlib import Path + +import mlx.core as mx import requests +from huggingface_hub import snapshot_download from PIL import Image -import mlx.core as mx def get_model_path(path_or_hf_repo: str) -> Path: @@ -35,23 +38,28 @@ def get_model_path(path_or_hf_repo: str) -> Path: def load_image(image_source): + """ + Helper function to load an image from either + a URL or file. 
+ """ if image_source.startswith(("http://", "https://")): try: response = requests.get(image_source, stream=True) response.raise_for_status() return Image.open(response.raw) - except requests.HTTPError as e: - print(f"Failed to load image from URL: {e}") - return None + except Exception as e: + raise ValueError( + f"Failed to load image from URL: {image_source} with error {e}" + ) elif os.path.isfile(image_source): try: return Image.open(image_source) except IOError as e: - print(f"Failed to load image from path: {e}") - return None + raise ValueError(f"Failed to load image {image_source} with error: {e}") else: - print("The image source is neither a valid URL nor a file path.") - return None + raise ValueError( + f"The image {image_source} must be a valid URL or existing file." + ) def prepare_inputs(processor, image, prompt): diff --git a/llava/vision.py b/llava/vision.py index a61d99645..66287dee6 100644 --- a/llava/vision.py +++ b/llava/vision.py @@ -1,7 +1,6 @@ -import glob +# Copyright © 2024 Apple Inc. + import inspect -import json -import logging import math from dataclasses import dataclass from typing import Optional @@ -92,7 +91,7 @@ def __call__(self, queries, keys, values, mask=None): class MLP(nn.Module): def __init__(self, config: VisionConfig): super().__init__() - self.activation_fn = nn.GELU(approx='fast') + self.activation_fn = nn.GELU(approx="fast") self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) @@ -109,11 +108,9 @@ def __init__(self, config: VisionConfig): self.self_attn = Attention( config.hidden_size, config.num_attention_heads, bias=True ) - self.layer_norm1 = nn.LayerNorm( - self.embed_dim, eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = MLP(config) - self.layer_norm2 = nn.LayerNorm( - self.embed_dim, eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: y = self.layer_norm1(x) @@ -127,8 +124,7 @@ def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: class Encoder(nn.Module): def __init__(self, config: VisionConfig): super().__init__() - self.layers = [EncoderLayer(config) - for _ in range(config.num_hidden_layers)] + self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] class VisionEmbeddings(nn.Module): @@ -151,14 +147,12 @@ def __init__(self, config: VisionConfig): self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 - self.position_embedding = nn.Embedding( - self.num_positions, self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] patch_embeddings = self.patch_embedding(x) - patch_embeddings = mx.flatten( - patch_embeddings, start_axis=1, end_axis=2) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) embed_dim = patch_embeddings.shape[-1] cls_embeddings = mx.broadcast_to( self.class_embedding, (batch_size, 1, embed_dim) @@ -218,8 +212,10 @@ def sanitize(weights): # Remove unused position_ids continue elif "patch_embedding.weight" in k: - # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, KW] - # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels] + # PyTorch conv2d weight tensors 
have shape: + # [out_channels, in_channels, kH, KW] + # MLX conv2d expects the weight be of shape: + # [out_channels, kH, KW, in_channels] sanitized_weights[k] = v.transpose(0, 2, 3, 1) else: sanitized_weights[k] = v From 7ace6eabbe4f0ed4ee08934030352372ebc321d5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Mar 2024 10:26:17 -0800 Subject: [PATCH 34/34] updates, slight simplify --- llava/README.md | 4 +- llava/generate.py | 43 ++++++++++++++++++--- llava/llava.py | 95 +++++++++++++++++++++-------------------------- llava/test.py | 6 +-- llava/utils.py | 69 ---------------------------------- 5 files changed, 84 insertions(+), 133 deletions(-) delete mode 100644 llava/utils.py diff --git a/llava/README.md b/llava/README.md index 873994db4..ae58766e1 100644 --- a/llava/README.md +++ b/llava/README.md @@ -39,8 +39,7 @@ These are two cats lying on a pink couch. You can also use LLaVA in Python: ```python -from utils import load_image, prepare_inputs -from generate import load_model, generate_text +from generate import load_model, prepare_inputs, generate_text processor, model = load_model("llava-hf/llava-1.5-7b-hf") @@ -48,7 +47,6 @@ max_tokens, temperature = 128, 0.0 prompt = "USER: \nWhat are these?\nASSISTANT:" image = "http://images.cocodataset.org/val2017/000000039769.jpg" -image = load_image(image) input_ids, pixel_values = prepare_inputs(processor, image, prompt) reply = generate_text( diff --git a/llava/generate.py b/llava/generate.py index 40e2a9b9f..9535bab93 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -2,10 +2,12 @@ import argparse import codecs +from pathlib import Path import mlx.core as mx +import requests +from PIL import Image from transformers import AutoProcessor -from utils import get_model_path, load_image, prepare_inputs from llava import LlavaModel @@ -44,10 +46,42 @@ def parse_arguments(): return parser.parse_args() +def load_image(image_source): + """ + Helper function to load an image from either a URL or file. + """ + if image_source.startswith(("http://", "https://")): + try: + response = requests.get(image_source, stream=True) + response.raise_for_status() + return Image.open(response.raw) + except Exception as e: + raise ValueError( + f"Failed to load image from URL: {image_source} with error {e}" + ) + elif Path(image_source).is_file(): + try: + return Image.open(image_source) + except IOError as e: + raise ValueError(f"Failed to load image {image_source} with error: {e}") + else: + raise ValueError( + f"The image {image_source} must be a valid URL or existing file." 
+ ) + + +def prepare_inputs(processor, image, prompt): + if isinstance(image, str): + image = load_image(image) + inputs = processor(prompt, image, return_tensors="np") + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + return input_ids, pixel_values + + def load_model(model_path): processor = AutoProcessor.from_pretrained(model_path) - - model = LlavaModel.from_pretrained(get_model_path(model_path)) + model = LlavaModel.from_pretrained(model_path) return processor, model @@ -79,12 +113,11 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera def main(): args = parse_arguments() - image = load_image(args.image) processor, model = load_model(args.model) prompt = codecs.decode(args.prompt, "unicode_escape") - input_ids, pixel_values = prepare_inputs(processor, image, prompt) + input_ids, pixel_values = prepare_inputs(processor, args.image, prompt) print(prompt) generated_text = generate_text( diff --git a/llava/llava.py b/llava/llava.py index 0bff0f8a0..06e560590 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -1,7 +1,8 @@ +# Copyright © 2024 Apple Inc. + import glob import inspect import json -import logging from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -9,8 +10,8 @@ import mlx.core as mx import mlx.nn as nn import numpy as np +from huggingface_hub import snapshot_download from language import LanguageModel, TextConfig -from utils import get_model_path from vision import VisionConfig, VisionModel @@ -41,7 +42,6 @@ def __init__(self, config: LlaVAConfig): self.linear_1 = nn.Linear( config.vision_config.hidden_size, config.text_config.hidden_size, bias=True ) - self.gelu = nn.GELU() self.linear_2 = nn.Linear( config.text_config.hidden_size, config.text_config.hidden_size, bias=True @@ -73,11 +73,13 @@ def get_input_embeddings( # Get the input embeddings from the language model inputs_embeds = self.language_model.model.embed_tokens(input_ids) - # get the ouptut hidden states from the vision model - _, _, hidden_states = self.vision_tower( + + # Get the ouptut hidden states from the vision model + *_, hidden_states = self.vision_tower( pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True ) - # select the hidden states from the desired layer + + # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] if self.vision_feature_select_strategy == "default": @@ -86,15 +88,17 @@ def get_input_embeddings( selected_image_feature = selected_image_feature else: raise ValueError( - f"Unexpected select feature strategy: {self.vision_feature_select_strategy}" + "Unexpected feature selection strategy: " + f"{self.vision_feature_select_strategy}" ) - # pass image features through the multi-modal projector + + # Pass image features through the multi-modal projector image_features = self.multi_modal_projector(selected_image_feature) - # insert special image tokens in the input_ids + + # Insert special image tokens in the input_ids final_inputs_embeds = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids ) - return final_inputs_embeds def _merge_input_ids_with_image_features( @@ -103,16 +107,13 @@ def _merge_input_ids_with_image_features( image_token_index = self.config.image_token_index num_images, num_image_patches, embed_dim = image_features.shape - # Positions of tokens in `input_ids` and assume batch size is 1 - image_positions = [ - idx - for idx, token_id in enumerate(input_ids[0]) - if 
token_id == image_token_index - ] + # Positions of tokens in input_ids, assuming batch size is 1 + image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() if len(image_positions) != num_images: raise ValueError( - f"The input provided to the model is incorrect. The number of image tokens is {len(image_positions)}, but the number of images given to the model is {num_images}." + f"The number of image tokens ({len(image_positions)}) does not " + f" match the number of image inputs ({num_images})." ) text_segments = [] @@ -122,26 +123,13 @@ def _merge_input_ids_with_image_features( text_segments.append(inputs_embeds[:, start_idx:position]) start_idx = position + 1 - if start_idx < inputs_embeds.shape[1]: - text_segments.append(inputs_embeds[:, start_idx:]) - - # Reshape image feature from (num_images, num_image_patches, embed_dim) to (num_images*num_image_patches, embed_dim) - image_embeddings = image_features.reshape(-1, image_features.shape[-1]) + image_embeddings = mx.split(image_features, image_features.shape[0]) + final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] + final_embeddings += [inputs_embeds[:, start_idx:]] - final_embeddings = [] - for i, text_segment in enumerate(text_segments): - final_embeddings.append(text_segment[0]) - if i < len(image_positions): - # Add a slice of image embeddings corresponding to the current position. - # This effectively replaces one token with its associated num_image_patches embeddings. - final_embeddings.append(image_embeddings[i : i + num_image_patches + 1]) - - # This creates a final embeding in shape (1, num_image_patches*num_images + sequence_len, embed_dim) representing the merged sequence of text and image embeddings. - final_embeddings = mx.concatenate(final_embeddings, axis=0).reshape( - 1, -1, embed_dim - ) - - return final_embeddings + # Create a final embedding of shape + # (1, num_image_patches*num_images + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=1) def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): input_embddings = self.get_input_embeddings(input_ids, pixel_values) @@ -151,38 +139,41 @@ def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): return logits, cache @staticmethod - def from_pretrained(path: str): - path = get_model_path(path) - path = Path(path) + def from_pretrained(path_or_hf_repo: str): + path = Path(path_or_hf_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], + ) + ) with open(path / "config.json", "r") as f: model_config = json.load(f) model_config = LlaVAConfig.from_dict(model_config) - if isinstance(model_config.vision_config, dict): - model_config.vision_config = VisionConfig.from_dict( - model_config.vision_config - ) - - if isinstance(model_config.text_config, dict): - model_config.text_config = TextConfig.from_dict(model_config.text_config) + model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) + model_config.text_config = TextConfig.from_dict(model_config.text_config) model = LlavaModel(model_config) weight_files = glob.glob(str(path / "*.safetensors")) if not weight_files: - logging.error(f"No safetensors found in {path}") raise FileNotFoundError(f"No safetensors found in {path}") weights = {} for wf in weight_files: weights.update(mx.load(wf)) - if hasattr(VisionModel, "sanitize"): - weights = 
VisionModel.sanitize(weights) - - if hasattr(VisionModel, "sanitize"): - weights = LanguageModel.sanitize(weights) + weights = VisionModel.sanitize(weights) + weights = LanguageModel.sanitize(weights) model.load_weights(list(weights.items())) return model diff --git a/llava/test.py b/llava/test.py index 278ada200..3cb2863c2 100644 --- a/llava/test.py +++ b/llava/test.py @@ -7,7 +7,6 @@ import torch from PIL import Image from transformers import AutoProcessor, LlavaForConditionalGeneration -from utils import get_model_path from llava import LlavaModel @@ -17,8 +16,7 @@ def load_mlx_models(path): - model_path = get_model_path(path) - model = LlavaModel.from_pretrained(model_path) + model = LlavaModel.from_pretrained(path) model.eval() return model @@ -76,7 +74,7 @@ def test_image_features(self): ) -class TestLlaVA(unittest.TestCase): +class TestLlava(unittest.TestCase): @classmethod def setUpClass(cls): cls.mx_llava = load_mlx_models(MODEL_PATH) diff --git a/llava/utils.py b/llava/utils.py deleted file mode 100644 index 265aa11f1..000000000 --- a/llava/utils.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import os -from pathlib import Path - -import mlx.core as mx -import requests -from huggingface_hub import snapshot_download -from PIL import Image - - -def get_model_path(path_or_hf_repo: str) -> Path: - """ - Ensures the model is available locally. If the path does not exist locally, - it is downloaded from the Hugging Face Hub. - - Args: - path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. - - Returns: - Path: The path to the model. - """ - model_path = Path(path_or_hf_repo) - if not model_path.exists(): - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - return model_path - - -def load_image(image_source): - """ - Helper function to load an image from either - a URL or file. - """ - if image_source.startswith(("http://", "https://")): - try: - response = requests.get(image_source, stream=True) - response.raise_for_status() - return Image.open(response.raw) - except Exception as e: - raise ValueError( - f"Failed to load image from URL: {image_source} with error {e}" - ) - elif os.path.isfile(image_source): - try: - return Image.open(image_source) - except IOError as e: - raise ValueError(f"Failed to load image {image_source} with error: {e}") - else: - raise ValueError( - f"The image {image_source} must be a valid URL or existing file." - ) - - -def prepare_inputs(processor, image, prompt): - inputs = processor(prompt, image, return_tensors="np") - pixel_values = mx.array(inputs["pixel_values"]) - input_ids = mx.array(inputs["input_ids"]) - return input_ids, pixel_values
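The simplified merge above interleaves text segments with per-image patch embeddings and concatenates along the sequence axis. A self-contained shape check, using random embeddings and an arbitrary placeholder id chosen only for illustration, looks roughly like this:

```python
import mlx.core as mx
import numpy as np

embed_dim, num_patches, image_token = 8, 4, 99  # toy sizes, made-up token id

# A 7-token prompt whose third position holds the image placeholder.
input_ids = mx.array([[1, 5, image_token, 7, 2, 9, 3]])
inputs_embeds = mx.random.normal((1, 7, embed_dim))             # text embeddings
image_features = mx.random.normal((1, num_patches, embed_dim))  # one image

# Split the text embeddings around each placeholder position.
positions = np.where(np.array(input_ids[0]) == image_token)[0].tolist()
segments, start = [], 0
for pos in positions:
    segments.append(inputs_embeds[:, start:pos])
    start = pos + 1

# Interleave text segments with per-image patch embeddings, then append
# the trailing text segment and concatenate along the sequence axis.
image_chunks = mx.split(image_features, image_features.shape[0])
merged = [part for pair in zip(segments, image_chunks) for part in pair]
merged.append(inputs_embeds[:, start:])

out = mx.concatenate(merged, axis=1)
print(out.shape)  # (1, 10, 8): 6 text tokens plus 4 patch embeddings
```

With a real checkpoint the same arithmetic holds at larger sizes: a 336x336 image with 14x14 patches contributes 576 patch embeddings after the first hidden state is dropped by the default feature-selection strategy.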