
Using the CLIPVisionModel to process images and generate image embeddings with Phi-3-vision

The following Python sample shows how to process images and generate image embeddings using the CLIPVisionModel.

What is CLIP

CLIP, which stands for Contrastive Language-Image Pre-training, is a model developed by OpenAI that efficiently learns visual concepts from natural language supervision. It’s a multimodal model that combines image and text understanding in a single framework. CLIP is trained on a variety of internet-sourced images and the text found with them, learning to predict which images were paired with which texts, effectively linking the two modalities.
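The objective of predicting which images were paired with which texts can be written as a symmetric contrastive loss over a batch of N image-text pairs. Matching pairs sit on the diagonal of an N x N similarity matrix. Cross-entropy then pulls each image toward its own caption and each caption toward its own image.

This is a minimal PyTorch sketch under two assumptions. First, the embeddings have already been produced by the image and text encoders; random tensors stand in for them here. Second, it uses a fixed temperature, whereas CLIP learns its temperature.

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(image_emb: torch.Tensor, text_emb: torch.Tensor,
                          temperature: float = 0.07) -> torch.Tensor:
    """Symmetric cross-entropy over an N x N image-text similarity matrix.

    Row i of image_emb and row i of text_emb are assumed to be a matching pair.
    """
    # L2-normalize so that the dot product is the cosine similarity
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)

    # logits[i, j] = similarity of image i and text j, scaled by the temperature
    logits = image_emb @ text_emb.t() / temperature
    targets = torch.arange(logits.size(0), device=logits.device)

    # Image-to-text and text-to-image classification, averaged
    loss_i2t = F.cross_entropy(logits, targets)
    loss_t2i = F.cross_entropy(logits.t(), targets)
    return (loss_i2t + loss_t2i) / 2


if __name__ == "__main__":
    torch.manual_seed(0)
    images = torch.randn(8, 512)                        # stand-in image embeddings
    texts_matched = images + 0.1 * torch.randn(8, 512)  # captions close to their images
    texts_random = torch.randn(8, 512)                  # unrelated captions
    print(clip_contrastive_loss(images, texts_matched).item())  # low loss
    print(clip_contrastive_loss(images, texts_random).item())   # high loss
```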

The model works by encoding an image and a text snippet into a shared embedding space. It then scores how likely it is that the text accurately describes the image, based on the similarity of the two embeddings. This approach allows CLIP to handle a wide range of visual tasks, such as object recognition and classification, and even matching descriptions to images it has never seen before.

One of the key advantages of CLIP is its ability to perform “zero-shot” learning. The model can handle tasks it wasn’t explicitly trained for, such as classifying images into new categories, simply by comparing an image with text descriptions of the candidate classes. This is possible because of the vast amount of diverse data it has been trained on, which helps it generalize well to new tasks.
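Zero-shot classification with CLIP reduces to a nearest-neighbor search in the shared embedding space:

  1. Embed one text prompt per class, for example "a photo of a dog".
  2. Embed the image.
  3. Pick the class whose prompt embedding has the highest cosine similarity.

The sketch below assumes the embeddings are already computed. The hand-made 3-dimensional vectors are purely illustrative stand-ins for real CLIP outputs.

```python
import torch
import torch.nn.functional as F


def zero_shot_classify(image_emb: torch.Tensor, class_text_emb: torch.Tensor,
                       temperature: float = 0.01) -> torch.Tensor:
    """Return class probabilities for each image.

    image_emb: (num_images, dim); class_text_emb: (num_classes, dim).
    """
    image_emb = F.normalize(image_emb, dim=-1)
    class_text_emb = F.normalize(class_text_emb, dim=-1)
    # Cosine similarity between every image and every class prompt
    logits = image_emb @ class_text_emb.t() / temperature
    return logits.softmax(dim=-1)


labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]
# Illustrative stand-ins for text-encoder outputs, one row per prompt
text_emb = torch.tensor([[1.0, 0.1, 0.0],
                         [0.1, 1.0, 0.0],
                         [0.0, 0.0, 1.0]])
# Illustrative stand-in for the image-encoder output of a dog picture
image_emb = torch.tensor([[0.2, 0.9, 0.1]])

probs = zero_shot_classify(image_emb, text_emb)
print(labels[probs.argmax(dim=-1).item()])  # a photo of a dog
```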

Phi-3-vision

Phi-3-vision is a 4.2B-parameter multimodal model with language and vision capabilities. It can reason over real-world images and digital documents, extract and reason over text in images, and generate insights and answers about charts and diagrams.

  • Example Purpose: This example demonstrates generating image embeddings with CLIP and how they can be applied to tasks related to the Phi-3 model. It serves as a reference for comparing the performance and characteristics of different embedding techniques (CLIP vs. Phi-3).
  • Integration Challenge: Integrating another vision encoder such as CLIP directly into Phi-3 is complex. The complexity comes from architectural differences and from the need to integrate seamlessly without losing context or performance. This integration has not yet been fully evaluated or implemented, which is why this example is included.
  • Comparison Approach: The code provides a side-by-side comparison rather than an integrated solution. It lets you see how CLIP embeddings perform alongside Phi-3 embeddings and gives insight into potential benefits or drawbacks.
  • Clarification: This Phi-3CookBook example shows how to use CLIP embeddings as a comparative tool, not as a direct integration into Phi-3.
  • Integration Work: Full integration of CLIP embeddings into Phi-3 remains a challenge and has not been fully explored, but the code is provided for customers to experiment with.

Sample Code

This code defines a class called Phi3ImageEmbedding that represents an image embedding model. The purpose of this class is to process images and generate embeddings that can be used for downstream tasks such as image classification or retrieval.

The __init__ method initializes the model by setting up its components: the embedding dropout, the image processor (a CLIP vision model), the HD transform parameters, and the image projection. It takes a config object that holds the model's configuration parameters. The optional wte parameter is the language model's word token embedding layer, which forward uses to embed the text tokens.

The get_img_features method takes a tensor img_embeds; despite its name, this holds the preprocessed pixel values. It runs the tensor through img_processor with output_hidden_states=True and returns the features from the hidden layer selected by layer_idx (default -2, the penultimate layer). With type_feature set to "patch" (the default), the leading CLS token is dropped; with "cls_patch", it is kept.
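The selection step can be reproduced without loading CLIP by using a stand-in tuple of hidden states. With output_hidden_states=True, CLIPVisionModel returns num_hidden_layers + 1 tensors: the embedding output plus one per layer. For a 336x336 input with 14x14 patches, each tensor holds 1 + 24 * 24 = 577 tokens.

In the sketch below, each fake layer is filled with its own index, so the output reveals which layer was picked. The feature size of 8 is an illustrative assumption; the real model uses 1024.

```python
import torch


def select_clip_features(hidden_states, layer_idx: int = -2, type_feature: str = "patch"):
    """Sketch of the selection logic in get_img_features."""
    img_feature = hidden_states[layer_idx]
    if type_feature == "patch":
        return img_feature[:, 1:]  # drop the CLS token at index 0
    if type_feature == "cls_patch":
        return img_feature         # keep the CLS token and the patch tokens
    raise NotImplementedError(type_feature)


# Stand-in for CLIPVisionModel(..., output_hidden_states=True).hidden_states:
# embedding output + 24 layers = 25 tensors of shape (batch, 1 + 24*24, dim).
num_layers, num_tokens, dim = 24, 1 + (336 // 14) ** 2, 8
hidden_states = tuple(torch.full((2, num_tokens, dim), float(i)) for i in range(num_layers + 1))

feats = select_clip_features(hidden_states)
print(feats.shape)            # torch.Size([2, 576, 8])
print(feats[0, 0, 0].item())  # 23.0, i.e. the output of the penultimate layer
```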

Explain the Code

Let's go through the code step by step:

The code imports the necessary libraries and modules, including math, datetime, torch, torch.nn, and several components from the transformers library.

The code defines a configuration object called CLIP_VIT_LARGE_PATCH14_336_CONFIG. It holds the hyperparameters of the CLIP ViT-L/14 vision encoder at 336x336 resolution (openai/clip-vit-large-patch14-336), which serves as the image processor.
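These config values determine how many image tokens reach the language model. The worked example below (plain Python, no model needed) derives them step by step:

  • A 336x336 crop split into 14x14 patches gives a 24x24 grid of 576 patch tokens.
  • The HD transform in forward merges each 2x2 block into one token of width 4 x 1024, giving 12 x 12 = 144 tokens per crop.
  • Counting the separator tokens reproduces the temp_len formula that forward asserts.

The crop grids in the final loop (1x1, 1x2, 2x2) are illustrative.

```python
IMAGE_SIZE = 336   # CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
PATCH_SIZE = 14    # CLIP_VIT_LARGE_PATCH14_336_CONFIG.patch_size
HIDDEN = 1024      # CLIP_VIT_LARGE_PATCH14_336_CONFIG.hidden_size

grid = IMAGE_SIZE // PATCH_SIZE   # 24 patches per side
patch_tokens = grid * grid        # 576 (the CLS token is dropped)
merged_side = grid // 2           # 12 per side after 2x2 merging
merged_tokens = merged_side ** 2  # 144 tokens per crop
merged_dim = 4 * HIDDEN           # 4096 features per merged token


def hd_num_tokens(h_crops: int, w_crops: int) -> int:
    """Image tokens produced by the HD transform for an h x w grid of 336px crops."""
    # Global view: 12 rows of (12 tokens + 1 separator from sub_GN)
    glb = merged_side * (merged_side + 1)
    # Crops stitched into (h*12) rows of (w*12 tokens + 1 separator)
    sub = (h_crops * merged_side) * (w_crops * merged_side + 1)
    # One glb_GN token separates the global view from the crops
    return glb + 1 + sub


def temp_len(h: int, w: int) -> int:
    # Closed form asserted in forward()
    return (h * w + 1) * 144 + 1 + (h + 1) * 12


print(patch_tokens, merged_tokens, merged_dim)  # 576 144 4096
for h, w in [(1, 1), (1, 2), (2, 2)]:
    print(h, w, hd_num_tokens(h, w), temp_len(h, w))  # 313, 457, 757 tokens
```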

The Phi3ImageEmbedding class is defined, which is a subclass of torch.nn.Module. This class represents the image embedding model and contains methods for forward propagation and setting image features.

The __init__ method initializes the Phi3ImageEmbedding object. It takes as input a config object, an instance of the PretrainedConfig class. It also accepts an optional wte argument and keyword arguments that control the HD transform and the projection type.

The __init__ method sets the attributes of the Phi3ImageEmbedding object from the provided config object and keyword arguments: the hidden size, dropout rate, image processor, image projection, and other parameters.
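When use_hd_transform is enabled, each 2x2 block of neighboring patch features is concatenated into a single token. That is why the learnable separators and the 'mlp' projection use image_dim_out * 4 as their width.

The sketch below reproduces that reshape on a tiny tensor and passes the result through a two-layer MLP projection like the one __init__ builds. The sizes (C=4, hidden size 16) are illustrative, not Phi-3-vision's real dimensions.

```python
import torch
import torch.nn as nn

H, C, hidden_size = 24, 4, 16  # 24x24 patch grid; C and hidden_size kept tiny for illustration

# Patch features for one crop, laid out as an H x H grid of C-dim vectors
feat = torch.arange(H * H * C, dtype=torch.float32).reshape(1, H, H, C)

# Same 2x2 merge as forward(): (1, 24, 24, C) -> (1, 12, 12, 4C)
merged = (
    feat.reshape(1, H // 2, 2, H // 2, 2, C)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(1, H // 2, H // 2, 4 * C)
)

# Two-layer MLP projection, as built for projection_cls == 'mlp' with the HD transform
img_projection = nn.Sequential(
    nn.Linear(4 * C, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)
tokens = img_projection(merged.reshape(1, -1, 4 * C))

print(merged.shape)  # torch.Size([1, 12, 12, 16])
print(tokens.shape)  # torch.Size([1, 144, 16])
# The first merged token holds patches (row, col) = (0,0), (0,1), (1,0), (1,1), in that order
assert torch.equal(merged[0, 0, 0], torch.cat([feat[0, 0, 0], feat[0, 0, 1], feat[0, 1, 0], feat[0, 1, 1]]))
```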

The set_img_features method sets the image features for the model. It takes a tensor of image features as input and assigns it to the img_features attribute of the object.

The set_img_sizes method sets the image sizes for the model. It takes a tensor of image sizes as input and assigns it to the img_sizes attribute of the object.

The get_img_features method extracts image features from its input. It runs the CLIP vision model on the tensor of pixel values and returns the selected hidden-state features.

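The forward method builds the input embeddings for the language model. It takes input IDs, pixel values, and optional image sizes, and returns hidden states of shape (batch, sequence length, hidden size). If image features or sizes were stored earlier with set_img_features or set_img_sizes, it uses those instead of the arguments.

Image placeholder tokens are marked by negative input IDs. The method proceeds as follows:

  1. Locate the placeholder positions.
  2. Extract image features with the configured image processor, optionally applying the HD transform.
  3. Apply the image projection to those features.
  4. Clamp the input IDs so the negative placeholders become valid indices, then embed them with wte.
  5. Overwrite the placeholder positions with the projected image features.
  6. Apply dropout, if configured, and return the result.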
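The core mechanism of forward is to embed the image placeholders as ordinary tokens and then overwrite them with image features. The sketch below reproduces only that splicing step. It uses a toy vocabulary, a toy embedding layer, and random stand-in "image features"; the placeholder ID -1 and all sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size, num_img_tokens = 10, 8, 3

wte = nn.Embedding(vocab_size, hidden_size)
# Two text tokens, three image placeholders (negative IDs), one more text token
input_ids = torch.tensor([[5, 7, -1, -1, -1, 2]])
# Stand-in for the projected image features (num_img_tokens x hidden_size)
img_set_tensor = torch.randn(num_img_tokens, hidden_size)

# 1. Locate the placeholder positions as (batch, seq) pairs, as forward() does
positions = torch.nonzero(input_ids < 0, as_tuple=False)

# 2. Clamp so placeholders become index 0, then embed all tokens
hidden_states = wte(input_ids.clamp(min=0))

# 3. Overwrite the contiguous run that starts at the first placeholder
b, start = positions[0].tolist()
hidden_states[b, start:start + num_img_tokens] = img_set_tensor

print(hidden_states.shape)  # torch.Size([1, 6, 8])
```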

Overall, this code defines a class that represents an image embedding model and provides methods for setting image features and performing forward propagation.

Code Sample

```python
import math
from datetime import datetime

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
from transformers.utils import logging

# Set up logging
logger = logging.get_logger(__name__)

# Configuration of the CLIP ViT-L/14 vision encoder at 336x336 resolution
# (matches openai/clip-vit-large-patch14-336)
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    attention_dropout=0.0,
    dropout=0.0,
    hidden_act="quick_gelu",
    hidden_size=1024,
    image_size=336,
    initializer_factor=1.0,
    initializer_range=0.02,
    intermediate_size=4096,
    layer_norm_eps=1e-05,
    num_attention_heads=16,
    num_channels=3,
    num_hidden_layers=24,
    patch_size=14,
    projection_dim=768,
)


class Phi3ImageEmbedding(nn.Module):
    """Phi3 Image embedding."""

    def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None:
        super().__init__()

        # Set up the embedding dropout
        hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size
        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
            embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop
            self.drop = nn.Dropout(embd_drop)
        else:
            self.drop = None

        # Word token embedding layer (nn.Embedding); forward() requires it
        self.wte = wte

        # Set up the image processor (CLIP vision model) based on the configuration
        if isinstance(config.img_processor, dict) and config.img_processor.get('name', None) == 'clip_vision_model':
            assert 'model_name' in config.img_processor, 'model_name must be provided for CLIPVisionModel'
            assert 'image_dim_out' in config.img_processor, 'image_dim_out must be provided for CLIPVisionModel'
            assert 'num_img_tokens' in config.img_processor, 'num_img_tokens must be provided for CLIPVisionModel'
            assert config.img_processor['model_name'] == 'openai/clip-vit-large-patch14-336'
            clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
            self.img_processor = CLIPVisionModel(clip_config)
            image_dim_out = config.img_processor['image_dim_out']
            self.num_img_tokens = config.img_processor['num_img_tokens']
        else:
            raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented')

        self.image_dim_out = image_dim_out
        self.img_sizes = None

        # Set up the HD transform parameters
        self.use_hd_transform = kwargs.get('use_hd_transform', False)
        self.with_learnable_separator = kwargs.get('with_learnable_separator', False)
        self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub')
        assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have same value'
        if self.with_learnable_separator:
            assert self.use_hd_transform, 'learnable separator is only for hd transform'
            # Learnable separator tokens; 4x wide because each 2x2 block of patches becomes one token
            self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * 4]))
            self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * 4]))
            logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}')

        # Set up the image projection (image features -> language model hidden size)
        projection_cls = kwargs.get('projection_cls', 'linear')
        if projection_cls == 'linear':
            self.img_projection = nn.Linear(image_dim_out, hidden_size)
        elif projection_cls == 'mlp' and self.use_hd_transform:
            dim_projection = hidden_size
            depth = 2
            layers = [nn.Linear(image_dim_out * 4, dim_projection)]
            for _ in range(1, depth):
                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
            self.img_projection = nn.Sequential(*layers)
        elif projection_cls == 'mlp':
            dim_projection = hidden_size
            depth = 2
            layers = [nn.Linear(image_dim_out, dim_projection)]
            for _ in range(1, depth):
                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
            self.img_projection = nn.Sequential(*layers)
        else:
            raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented')

        self.vocab_size = config.vocab_size
        self.img_features = None

        # Which CLIP hidden layer to read (default: penultimate) and whether to keep the CLS token
        if isinstance(config.img_processor, dict):
            self.layer_idx = config.img_processor.get('layer_idx', -2)
            self.type_feature = config.img_processor.get('type_feature', 'patch')
        else:
            self.layer_idx = -2
            self.type_feature = 'patch'

    def set_img_features(self, img_features: torch.FloatTensor) -> None:
        self.img_features = img_features

    def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
        self.img_sizes = img_sizes

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        LAYER_IDX = self.layer_idx
        TYPE_FEATURE = self.type_feature

        # img_embeds holds pixel values; run CLIP and keep all hidden states
        img_processor_output = self.img_processor(img_embeds, output_hidden_states=True)
        img_feature = img_processor_output.hidden_states[LAYER_IDX]

        if TYPE_FEATURE == "patch":
            # Drop the CLS token at position 0 and keep only the patch tokens
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError

    def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor:

        MAX_INPUT_ID = int(1e9)
        img_embeds = pixel_values
        img_sizes = image_sizes

        # Features/sizes stored with set_img_features()/set_img_sizes() take precedence
        if self.img_features is not None:
            img_embeds = self.img_features.clone()
            self.img_features = None

        if self.img_sizes is not None:
            img_sizes = self.img_sizes

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        # Image placeholder tokens are encoded as negative IDs
        with torch.no_grad():
            positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)

        select = False
        hd_transform = False  # must be defined even when no HD transform is applied

        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        if len(positions.tolist()) > 0:
            with torch.no_grad():
                g_values = abs(input_ids[positions[:, 0], positions[:, 1]])

            if self.use_hd_transform and img_sizes is not None and len(img_sizes):
                hd_transform = True
                assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform'
                # img_embeds: (num_images, num_crops, 3, 336, 336); img_sizes: (num_images, 2)
                start_time = datetime.now()
                bs = img_embeds.shape[0]
                img_features = self.get_img_features(img_embeds.flatten(0, 1))
                base_feat_height = base_feat_width = int(img_features.shape[1] ** 0.5)
                assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform'
                img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
                C = self.image_dim_out
                H = base_feat_height

                output_imgs = []
                output_len = []
                if isinstance(img_sizes, torch.Tensor):
                    img_sizes = img_sizes.view(-1, 2)
                for _bs in range(bs):
                    h, w = img_sizes[_bs]
                    h = h // 336
                    w = w // 336
                    B_ = h * w

                    # Global view: merge each 2x2 block of patches into one 4*C token (24x24 -> 12x12)
                    global_img_feature = img_features[_bs, :1]
                    glb_img = global_img_feature.reshape(1, H, H, C).reshape(1, H // 2, 2, H // 2, 2, C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape(1, H // 2, H // 2, 4 * C).contiguous()
                    temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
                    glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1, -1, 4 * C)

                    # Crops: same 2x2 merge, then stitch them back into an (h*12) x (w*12) grid
                    sub_img = img_features[_bs, 1:]
                    sub_img = sub_img[:B_]
                    sub_img = sub_img.reshape(B_, H, H, C).reshape(B_, H // 2, 2, H // 2, 2, C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C).contiguous()
                    sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C)
                    temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
                    sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1, -1, 4 * C)
                    if self.hd_transform_order == 'glb_sub':
                        output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
                    elif self.hd_transform_order == 'sub_glb':
                        output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
                    else:
                        raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented')
                    temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
                    assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'
                    output_len.append(temp_len)

                num_img_tokens = output_len
                img_set_tensor = []
                for _output_img in output_imgs:
                    img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype))
                    img_set_tensor.append(img_feature_proj)
                logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}')
            elif img_embeds.ndim == 4:
                # Pixel values without HD transform: (num_images, 3, 336, 336)
                selected_g_values = g_values[::self.num_img_tokens]
                assert len(img_embeds) == len(selected_g_values), f'img_embeds size: {img_embeds.size()}, selected_g_values size: {len(selected_g_values)}, selected_g_value {selected_g_values}'
                start_time = datetime.now()
                tt = (
                    self.get_img_features(img_embeds)
                    .to(target_device)
                    .to(target_dtype)
                    .reshape(-1, self.image_dim_out)
                )
                logger.info(f'img_embeds size: {img_embeds.size()}, loading time {datetime.now() - start_time}')
                img_set_tensor = self.img_projection(tt)
            elif img_embeds.ndim == 3:
                # Precomputed image features: (num_images, num_img_tokens, image_dim_out)
                selected_g_values = g_values[::self.num_img_tokens]
                assert len(img_embeds) == len(selected_g_values), f'img_embeds size: {img_embeds.size()}, selected_g_values size: {len(selected_g_values)}, selected_g_value {selected_g_values}'
                tt = (
                    img_embeds
                    .to(target_device)
                    .to(target_dtype)
                    .view(-1, self.image_dim_out)
                )
                img_set_tensor = self.img_projection(tt)
            else:
                raise NotImplementedError
            select = True

        # Clamp the IDs so the negative placeholders become 0 before the embedding lookup
        with torch.no_grad():
            input_ids.clamp_min_(0).clamp_max_(self.vocab_size)

        hidden_states = self.wte(input_ids)

        # Overwrite the placeholder positions with the projected image features
        if select:
            if hd_transform:
                idx = 0
                for i, cnt in enumerate(num_img_tokens):
                    hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = (
                        img_set_tensor[i]
                        .to(hidden_states.dtype)
                        .to(hidden_states.device)
                    )
                    idx += cnt
            else:
                idx = 0
                assert len(selected_g_values) * self.num_img_tokens == len(img_set_tensor), f'len(selected_g_values) * self.num_img_tokens = {len(selected_g_values) * self.num_img_tokens}, len(img_set_tensor) = {len(img_set_tensor)}'
                for i, g in enumerate(selected_g_values):
                    cnt = self.num_img_tokens
                    hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = (
                        img_set_tensor[i * cnt : (i + 1) * cnt]
                        .to(hidden_states.dtype)
                        .to(hidden_states.device)
                    )
                    idx += cnt

        if self.drop is not None:
            hidden_states = self.drop(hidden_states)

        return hidden_states
```

Building your Pipeline

When you work with embedding-generation code such as the example above, how you integrate it into your pipeline depends on your specific use case:

  1. Loading Pre-trained Models: Models loaded from the Hugging Face Hub come with pre-trained weights, so you can use them directly to generate embeddings without additional training. This is useful for tasks like feature extraction or semantic search, where you need embeddings out of the box.

  2. Fine-Tuning Pipeline: If you need to adapt the model to a specific task or dataset, you would integrate the code into a fine-tuning pipeline. This involves:

    • Loading the Pre-trained Model: Start with a pre-trained model from Hugging Face.
    • Preparing Your Dataset: Ensure your dataset is in the correct format for training.
    • Fine-Tuning: Use libraries like transformers and datasets from Hugging Face to fine-tune the model on your dataset. This step adjusts the model weights to better suit your specific task.
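For the semantic-search case, embeddings are typically L2-normalized and compared with a dot product, which gives their cosine similarity. The minimal top-k retrieval sketch below assumes you already have a matrix of image embeddings (for example, from get_image_features) and a query embedding in the same space. The random tensors are stand-ins.

```python
import torch
import torch.nn.functional as F


def top_k_similar(query_emb: torch.Tensor, corpus_emb: torch.Tensor, k: int = 3):
    """Return (scores, indices) of the k corpus rows most similar to the query."""
    query_emb = F.normalize(query_emb, dim=-1)
    corpus_emb = F.normalize(corpus_emb, dim=-1)
    scores = corpus_emb @ query_emb  # cosine similarity for each corpus item
    return torch.topk(scores, k=min(k, corpus_emb.size(0)))


torch.manual_seed(0)
corpus = torch.randn(100, 512)                # stand-in image embeddings
query = corpus[42] + 0.05 * torch.randn(512)  # a query close to item 42
scores, indices = top_k_similar(query, corpus, k=3)
print(indices.tolist()[0])  # 42
```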

For example, in the context of the Phi-3 Cookbook and CLIPVision, you might:

  • Generate Embeddings: Use the pre-trained CLIP model to generate embeddings for images.
  • Fine-Tune: If the embeddings need to be more specific to your application, fine-tune the CLIP model on a dataset relevant to your use case.

Here's a simplified example of how you might integrate this in code:

```python
import torch
from transformers import CLIPModel, CLIPProcessor

# Load a pre-trained model and processor from the Hugging Face Hub
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Prepare your data
images = [...]  # List of images, e.g. PIL.Image objects
inputs = processor(images=images, return_tensors="pt")

# Generate image embeddings (shape: num_images x projection_dim)
with torch.no_grad():
    embeddings = model.get_image_features(**inputs)

# Fine-tuning (if needed)
# Define your fine-tuning logic here
```

This approach allows you to leverage powerful pre-trained models and adapt them to your specific needs.
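The "Define your fine-tuning logic here" placeholder can be filled in many ways. One lightweight option keeps the CLIP encoders frozen and trains only small projection heads on precomputed image and text embeddings. The heads are trained with a symmetric contrastive loss like CLIP's, using a fixed temperature.

This is a swapped-in technique, not full fine-tuning of the CLIP weights. The synthetic paired embeddings, sizes, and learning rate are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, D, P = 64, 512, 128  # number of pairs, frozen embedding dim, projection dim

# Stand-ins for frozen CLIP outputs: text embeddings are a noisy linear map of the image ones
image_emb = torch.randn(N, D)
text_emb = image_emb @ torch.randn(D, D) / D ** 0.5 + 0.1 * torch.randn(N, D)

# Trainable heads on top of the frozen embeddings
image_head = nn.Linear(D, P)
text_head = nn.Linear(D, P)
optimizer = torch.optim.AdamW(list(image_head.parameters()) + list(text_head.parameters()), lr=1e-3)


def contrastive_loss(img: torch.Tensor, txt: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    img, txt = F.normalize(img, dim=-1), F.normalize(txt, dim=-1)
    logits = img @ txt.t() / temperature
    targets = torch.arange(img.size(0))
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2


losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = contrastive_loss(image_head(image_emb), text_head(text_emb))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

With real data, image_emb and text_emb would come from get_image_features and get_text_features on your own image-caption pairs.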

Integrating the Phi Family of Models

Integrating the Phi-3 model with the provided code example involving CLIP can indeed be challenging, especially when considering different vision encoders.

Here's a brief overview of how you might approach this:

Key Points

  • Data Processing: Ensure that the images are processed in a way that fits the input requirements of the Phi-3 model.
  • Embedding Generation: Replace the CLIP embedding generation with the corresponding method from your Phi-3 model.
  • Fine-Tuning: If you need to fine-tune the Phi-3 model, add that logic after the embeddings are generated.

Steps to Integrate Phi-3 Model

  1. Load the Phi-3 Model: Load the vanilla or fine-tuned Phi-3 model. The example below assumes a hypothetical load_phi3_model helper that does this.
  2. Modify the Data Preparation: Adjust the data preparation to suit the input requirements of the Phi-3 model.
  3. Integrate Phi-3 Embeddings: Replace the part where CLIP embeddings are generated with the Phi-3 model's embedding generation.

```python
import torch
from transformers import CLIPModel, CLIPProcessor

# Load pre-trained CLIP model and processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load Phi-3 model (vanilla or fine-tuned)
# load_phi3_model is a hypothetical helper: replace it with your own loading code
phi3_model = load_phi3_model(fine_tuned=True)

# Prepare your data
images = [...]  # List of images, e.g. PIL.Image objects
inputs = clip_processor(images=images, return_tensors="pt")

# Generate embeddings using CLIP (for comparison)
with torch.no_grad():
    clip_embeddings = clip_model.get_image_features(**inputs)

# Generate embeddings using Phi-3
# process_for_phi3_model and phi3_model.get_image_features are hypothetical placeholders:
# adjust this part according to how your Phi-3 model processes inputs
phi3_inputs = process_for_phi3_model(images)
with torch.no_grad():
    phi3_embeddings = phi3_model.get_image_features(phi3_inputs)

# Fine-tuning or further processing (if needed)
# Define your fine-tuning logic here
```
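CLIP and Phi-3 embeddings usually have different dimensions and live in unrelated spaces, so they cannot be compared element by element. One option, not part of the original cookbook, is linear centered kernel alignment (CKA). It compares how two models arrange the same set of images and is invariant to orthogonal transforms and isotropic scaling.

The minimal sketch below assumes clip_embeddings and phi3_embeddings are (num_images x dim) tensors for the same images in the same order. Random tensors stand in for them here.

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> float:
    """Linear CKA between two embedding matrices with the same number of rows.

    x: (n, d1), y: (n, d2); row i of each must describe the same image.
    Returns a value in [0, 1]; 1 means an identical similarity structure.
    """
    x = x.double() - x.double().mean(dim=0, keepdim=True)  # center each feature
    y = y.double() - y.double().mean(dim=0, keepdim=True)
    cross = torch.linalg.norm(y.t() @ x) ** 2  # ||Y^T X||_F^2
    norm_x = torch.linalg.norm(x.t() @ x)      # ||X^T X||_F
    norm_y = torch.linalg.norm(y.t() @ y)      # ||Y^T Y||_F
    return (cross / (norm_x * norm_y)).item()


torch.manual_seed(0)
clip_like = torch.randn(1000, 32)              # stand-in for clip_embeddings
q, _ = torch.linalg.qr(torch.randn(32, 32))    # random orthogonal matrix
rotated = 3.0 * clip_like @ q                  # same structure, rotated and scaled
unrelated = torch.randn(1000, 48)              # stand-in with a different dimension

print(round(linear_cka(clip_like, rotated), 4))  # 1.0
print(linear_cka(clip_like, unrelated))          # close to 0
```

Linear CKA is biased upward when the number of images is small relative to the embedding dimensions. Use many more images than dimensions when interpreting the score.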