LlaVA in MLX #461

Merged
merged 35 commits into from
Mar 1, 2024

Commits
35 commits
4d964bd
add: llava mlx first draft
nkasmanoff Feb 19, 2024
0e2a054
add: weights comparision
nkasmanoff Feb 19, 2024
6e4a7ee
add forward pass skeleton
nkasmanoff Feb 19, 2024
ed9d376
update: now imports weights correctly
nkasmanoff Feb 22, 2024
b83b1e5
delete base
nkasmanoff Feb 22, 2024
6e23847
latest
nkasmanoff Feb 22, 2024
bb5b898
adding config
nkasmanoff Feb 22, 2024
95f9df1
fix: use config
nkasmanoff Feb 22, 2024
a1c6fe6
add mlx config
nkasmanoff Feb 22, 2024
cec0639
feat: add image processor for llava processor
mzbac Feb 23, 2024
4dd8bca
wip
mzbac Feb 24, 2024
c4ea94f
feat: llava working example
mzbac Feb 24, 2024
b9aeade
chore: refactor generate script
mzbac Feb 24, 2024
d8f7b89
chore: clean up
mzbac Feb 24, 2024
7fb1a39
Merge pull request #1 from mzbac/llava
nkasmanoff Feb 24, 2024
371a807
add: warning to user if no <image> token despite using one
nkasmanoff Feb 24, 2024
449f7d0
add: __call__ to LlavaModel
nkasmanoff Feb 24, 2024
a1cab2b
add: call to LlavaModel
nkasmanoff Feb 24, 2024
8e6b2f5
update fp
nkasmanoff Feb 26, 2024
823411c
clean up var names
nkasmanoff Feb 26, 2024
6bc06c8
update: native GeLU
nkasmanoff Feb 26, 2024
feec5ec
Cleanup
nkasmanoff Feb 28, 2024
d76fd40
update generate and readme
nkasmanoff Feb 28, 2024
49f928a
remove todo comment
nkasmanoff Feb 28, 2024
c2b8463
rearrange tests
nkasmanoff Feb 28, 2024
25a65cf
fix example code
nkasmanoff Feb 28, 2024
c2c9411
nits in README
awni Feb 28, 2024
8301c43
update readme
nkasmanoff Feb 28, 2024
5c8f67d
nit in readme
awni Feb 28, 2024
cd77bcf
nits in README
awni Feb 28, 2024
b39c251
chore(llava): refactor image embedding merging logic
mzbac Feb 28, 2024
935ebb5
min mlx version
awni Mar 1, 2024
683b7c4
nits in readmes
awni Mar 1, 2024
b37891d
fix cli prompt, some nits
awni Mar 1, 2024
7ace6ea
updates, slight simplify
awni Mar 1, 2024
1 change: 1 addition & 0 deletions llava/.gitignore
@@ -0,0 +1 @@
**.ipynb
41 changes: 41 additions & 0 deletions llava/README.md
@@ -0,0 +1,41 @@
# LLaVA

An example of LLaVA: Large Language and Vision Assistant in MLX. LLaVA is a multi-modal model that can generate text from images and text prompts. [^1]

## Setup

Install the dependencies:

```bash
pip install -r requirements.txt
```

## Run

You can use the LLaVA model to ask questions about images.

The Python snippet below shows how to use the model to ask questions about an image.

```python
from llava import LlavaModel
from transformers import AutoProcessor
from utils import load_image, prepare_inputs
from generate import generate_text
model_path = 'llava-hf/llava-1.5-7b-hf'

processor = AutoProcessor.from_pretrained(model_path)
model = LlavaModel.from_pretrained(model_path)

max_tokens, temperature = 128, 0.

prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = load_image(image)
input_ids, pixel_values = prepare_inputs(processor, image, prompt)

reply = generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature)

print(reply)
```

[^1]: Please refer to the original LLaVA repository for more details: [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA)
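
For command-line use, the `generate.py` script added in this PR wraps the same flow behind argparse flags. Below is a minimal invocation sketch based on the flags and defaults defined in `parse_arguments()` in that script; the `$'...'` quoting is only there so the shell expands `\n` into a real newline:

```bash
python generate.py \
  --model llava-hf/llava-1.5-7b-hf \
  --image "http://images.cocodataset.org/val2017/000000039769.jpg" \
  --prompt $'USER: <image>\nWhat are these?\nASSISTANT:' \
  --max-tokens 100 \
  --temp 0.3
```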
96 changes: 96 additions & 0 deletions llava/generate.py
@@ -0,0 +1,96 @@
import argparse

import mlx.core as mx
from transformers import AutoProcessor
from utils import get_model_path, load_image, prepare_inputs

from llava import LlavaModel


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Generate text from an image using a model."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="llava-hf/llava-1.5-7b-hf",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--image",
        type=str,
        default="http://images.cocodataset.org/val2017/000000039769.jpg",
        help="URL or path of the image to process.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="USER: <image>\nWhat are these?\nASSISTANT:",
        help="Message to be processed by the model.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="Maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temp", type=float, default=0.3, help="Temperature for sampling."
    )
    return parser.parse_args()


def initialize_model(model_path):
    processor = AutoProcessor.from_pretrained(model_path)

    model = LlavaModel.from_pretrained(get_model_path(model_path))
    return processor, model


def sample(logits, temperature=0.0):
    if temperature == 0:
        return mx.argmax(logits, axis=-1)
    else:
        return mx.random.categorical(logits * (1 / temperature))


def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
    logits, cache = model(input_ids, pixel_values)
    logits = logits[:, -1, :]
    y = sample(logits, temperature=temperature)
    tokens = [y.item()]

    for _ in range(max_tokens):
        logits, cache = model.language_model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y = sample(logits, temperature)
        token = y.item()
        if token == processor.tokenizer.eos_token_id:
            break
        tokens.append(token)

    return processor.tokenizer.decode(tokens)


def main():
    args = parse_arguments()
    raw_image = load_image(args.image)
    if raw_image is None:
        return

    processor, model = initialize_model(args.model)
    input_ids, pixel_values = prepare_inputs(processor, raw_image, args.prompt)
    print(args.prompt)
    generated_text = generate_text(
        input_ids, pixel_values, model, processor, args.max_tokens, args.temp
    )
    print(generated_text)


if __name__ == "__main__":
    main()
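
Note that `generate_text()` above evaluates the full image + text prompt through the whole model once to fill the KV cache, then decodes one token at a time through `model.language_model` only. For illustration, here is a self-contained sketch of what `sample()` computes (the vocabulary size and logit values are made up for the example):

```python
import mlx.core as mx

# Toy logits over a 4-token vocabulary (illustrative values only).
logits = mx.array([[2.0, 1.0, 0.5, -1.0]])

# temperature == 0: greedy decoding, i.e. argmax over the vocabulary axis.
greedy = mx.argmax(logits, axis=-1)  # -> index 0

# temperature > 0: sample from softmax(logits / temperature).
# mx.random.categorical takes unnormalized log-probabilities, so scaling the
# logits by 1/temperature is equivalent to dividing by the temperature.
temperature = 0.3
sampled = mx.random.categorical(logits * (1 / temperature))

print(greedy.item(), sampled.item())
```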
238 changes: 238 additions & 0 deletions llava/language.py
@@ -0,0 +1,238 @@
import inspect
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn


@dataclass
class TextConfig:
    model_type: str
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    intermediate_size: int = 11008
    num_attention_heads: int = 32
    rms_norm_eps: float = 1e-6
    vocab_size: int = 32000
    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output


class Attention(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()

        dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads

        self.repeats = n_heads // n_kv_heads

        head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = (
            1 / config.rope_scaling["factor"]
            if config.rope_scaling is not None
            and config.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=config.rope_traditional,
            base=config.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if self.repeats > 1:
            keys = mx.repeat(keys, self.repeats, axis=1)
            values = mx.repeat(values, self.repeats, axis=1)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.self_attn = Attention(config)
        self.mlp = MLP(config.hidden_size, config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.config = config

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


class Llama(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
        ]
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        inputs_embeds=None,
    ):
        # for passing merged input embeddings
        if inputs_embeds is None:
            h = self.embed_tokens(inputs)
        else:
            h = inputs_embeds

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache


class LanguageModel(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.model_type = config.model_type
        if self.model_type != "llama":
            raise ValueError(
                f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
            )
        self.model = Llama(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        inputs_embeds=None,
Contributor
Most of the implementation is copied from mlx-lm's llama, with only updates made to the forward pass to allow for directly passing inputs_embeds for the initial prompt evaluation.
    ):
        out, cache = self.model(inputs, cache, inputs_embeds)
        return self.lm_head(out), cache

    @staticmethod
    def sanitize(weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers
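
To make the `inputs_embeds` path described in the inline review comment above concrete, here is a minimal sketch using a toy configuration and untrained random weights (the sizes are illustrative, not the real llava-1.5-7b values, and it assumes it is run from the `llava/` directory so `language.py` is importable):

```python
import mlx.core as mx

from language import LanguageModel, TextConfig

# Toy config; llava-1.5-7b uses hidden_size=4096, 32 layers, 32 heads, etc.
config = TextConfig(
    model_type="llama",
    hidden_size=64,
    num_hidden_layers=2,
    intermediate_size=128,
    num_attention_heads=4,
    vocab_size=100,
)
model = LanguageModel(config)

# Prompt evaluation: pass already-merged text + image embeddings directly,
# bypassing embed_tokens. Shape is (batch, sequence_length, hidden_size).
embeds = mx.random.normal((1, 8, config.hidden_size))
logits, cache = model(inputs=None, inputs_embeds=embeds)

# Subsequent decoding steps feed token ids and reuse the per-layer KV cache,
# mirroring how generate_text() calls model.language_model(y[None], cache=cache).
next_token = mx.argmax(logits[:, -1, :], axis=-1)
logits, cache = model(next_token[None], cache=cache)
```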