Skip to content

Commit

Permalink
Merge branch 'main' into pc/llama3.2-vision
Browse files Browse the repository at this point in the history
  • Loading branch information
Blaizzy authored Oct 17, 2024
2 parents b6e6d51 + 295d6fc commit 46a1a02
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 21 deletions.
4 changes: 2 additions & 2 deletions mlx_vlm/models/paligemma/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def __call__(
h = inputs_embeds

h *= self.config.hidden_size**0.5

mask = create_attention_mask(h)
if mask is None or cache[0].offset > 0:
mask = create_attention_mask(h)

if cache is None:
cache = [None] * len(self.layers)
Expand Down
42 changes: 23 additions & 19 deletions mlx_vlm/models/paligemma/paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download

from .language import LanguageModel, TextConfig
Expand Down Expand Up @@ -93,7 +92,7 @@ def _prepare_inputs_for_multimodal(

batch_size, sequence_length = input_ids.shape
scaled_image_features = image_features / (self.config.hidden_size**0.5)
final_embedding = np.zeros((batch_size, sequence_length, embed_dim))
final_embedding = mx.zeros((batch_size, sequence_length, embed_dim))

text_mask = (input_ids != self.config.image_token_index) & (
input_ids != self.config.pad_token_id
Expand All @@ -102,32 +101,37 @@ def _prepare_inputs_for_multimodal(
pad_mask = input_ids == self.config.pad_token_id

# expand masks to match embedding dimension
text_mask_expanded = np.expand_dims(text_mask, -1).repeat(embed_dim, axis=-1)
pad_mask_expanded = np.expand_dims(pad_mask, -1).repeat(embed_dim, axis=-1)
text_mask_expanded = mx.expand_dims(text_mask, -1)
text_mask_expanded = mx.repeat(text_mask_expanded, embed_dim, axis=-1)
pad_mask_expanded = mx.expand_dims(pad_mask, -1)
pad_mask_expanded = mx.repeat(pad_mask_expanded, embed_dim, axis=-1)

# insert padding and text token embeddings
final_embedding = np.where(text_mask_expanded, inputs_embeds, final_embedding)
final_embedding = np.where(
pad_mask_expanded, np.zeros_like(final_embedding), final_embedding
final_embedding = mx.where(text_mask_expanded, inputs_embeds, final_embedding)
final_embedding = mx.where(
pad_mask_expanded, mx.zeros_like(final_embedding), final_embedding
)
pad_size = final_embedding.shape[1] - scaled_image_features.shape[1]
scaled_image_features = mx.pad(
scaled_image_features, ((0, 0), (0, pad_size), (0, 0))
)

# insert image embeddings - the image mask is always less or equal to the sentence in length
image_mask_expanded = np.expand_dims(image_mask, -1).repeat(embed_dim, axis=-1)
final_embedding[image_mask_expanded] = scaled_image_features.flatten()
image_mask_expanded = mx.expand_dims(image_mask, -1)
image_mask_expanded = mx.repeat(image_mask_expanded, embed_dim, axis=-1)
final_embedding = mx.where(
image_mask_expanded, scaled_image_features, final_embedding
)

final_embedding = np.where(
pad_mask_expanded, np.zeros_like(final_embedding), final_embedding
final_embedding = mx.where(
pad_mask_expanded, mx.zeros_like(final_embedding), final_embedding
)

attention_mask_expanded_1 = np.expand_dims(attention_mask, 1)
attention_mask_expanded_2 = np.expand_dims(attention_mask, 2)
attention_mask_expanded_1 = mx.expand_dims(attention_mask, 1)
attention_mask_expanded_2 = mx.expand_dims(attention_mask, 2)
final_attention_mask_4d = attention_mask_expanded_1 * attention_mask_expanded_2
final_attention_mask_4d = final_attention_mask_4d
final_attention_mask_4d = np.expand_dims(final_attention_mask_4d, 1).repeat(
self.config.text_config.num_key_value_heads, axis=1
)
final_attention_mask_4d = mx.expand_dims(final_attention_mask_4d, 1)
final_embedding = mx.array(final_embedding)
final_attention_mask_4d = mx.array(final_attention_mask_4d)
return final_embedding, final_attention_mask_4d

def __call__(
Expand Down

0 comments on commit 46a1a02

Please sign in to comment.