Skip to content

Commit

Permalink
[Core] Dynamic image size support for VLMs (vllm-project#5276)
Browse files Browse the repository at this point in the history
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: ywang96 <ywang@roblox.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
  • Loading branch information
5 people authored Jul 3, 2024
1 parent 482045e commit 9831aec
Show file tree
Hide file tree
Showing 38 changed files with 1,455 additions and 666 deletions.
2 changes: 1 addition & 1 deletion docs/source/dev/input_processing/model_inputs_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Input Processing
vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
in :class:`~vllm.LLMEngine` before they are passed to model executors.

Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
Currently, this mechanism is only utilized in :ref:`multi-modal models <multi_modality>` for preprocessing multi-modal input
data in addition to input prompt, but it can be extended to text-only language models when needed.

Guides
Expand Down
124 changes: 124 additions & 0 deletions docs/source/dev/multimodal/adding_multimodal_model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
.. _adding_a_new_multimodal_model:

Adding a New Multimodal Model
=============================

This document provides a high-level guide on integrating a :ref:`multi-modal model <multi_modality>` into vLLM.

.. note::
The complexity of adding a new model depends heavily on the model's architecture.
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.

.. tip::
If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ repository.
We will be happy to help you out!


1. Set up the base vLLM model
-----------------------------

As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model in vLLM, but note the following:

- You should additionally implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.

.. code-block:: diff
+ from vllm.model_executor.models.interfaces import SupportsVision
- class YourModelForImage2Seq(nn.Module):
+ class YourModelForImage2Seq(nn.Module, SupportsVision):
.. note::
The model class does not have to be named :code:`*ForCausalLM`.
Check out `the HuggingFace Transformers documentation <https://huggingface.co/docs/transformers/model_doc/auto#multimodal>`__ for some examples.

- While implementing the :meth:`~torch.nn.Module.forward` method, reserve a keyword parameter
for each input tensor that corresponds to a multi-modal input, as shown in the following example:

.. code-block:: diff
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
2. Register input mappers
-------------------------

For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.

.. code-block:: diff
from vllm.model_executor.models.interfaces import SupportsVision
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+ @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.

.. seealso::
:ref:`input_processing_pipeline`


3. (Optional) Register dummy data
---------------------------------

During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.

.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
Here are some examples:

- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

.. seealso::
:ref:`input_processing_pipeline`


4. (Optional) Register input processor
--------------------------------------

Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
This is often due to the fact that unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside model's :meth:`~torch.nn.Module.forward` call.
You can register input processors via :meth:`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.

.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:

- Insert static number of image tokens: `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Insert dynamic number of image tokens: `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

.. seealso::
:ref:`input_processing_pipeline`
18 changes: 15 additions & 3 deletions docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.. _multi_modality:

Multi-Modality
==============

Expand All @@ -8,12 +10,18 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
which allows you to pass in multi-modal input alongside text and token prompts.

By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
you must decorate the model class with :meth:`InputRegistry.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper <MultiModalRegistry.register_input_mapper>` for each modality type to support.
By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, please follow :ref:`the guide for adding a new multimodal model. <adding_a_new_multimodal_model>`.

# TODO: Add more instructions on how to do that once embeddings is in.

Guides
++++++

.. toctree::
:maxdepth: 1

adding_multimodal_model

Module Contents
+++++++++++++++

Expand All @@ -35,6 +43,10 @@ Base Classes
:members:
:show-inheritance:

.. autoclass:: vllm.multimodal.MultiModalInputs
:members:
:show-inheritance:

.. autoclass:: vllm.multimodal.MultiModalPlugin
:members:
:show-inheritance:
Expand Down
24 changes: 16 additions & 8 deletions docs/source/models/vlm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ The following :ref:`engine arguments <engine_args>` are specific to VLMs:
Currently, the support for vision language models on vLLM has the following limitations:

* Only single image input is supported per text prompt.
* Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means our LLaVA-NeXT output may not exactly match the huggingface implementation.

We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.

Expand All @@ -42,12 +41,17 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
)
.. important::
Currently, you have to specify ``image_feature_size`` to support memory profiling.
To avoid OOM during runtime, you should set this to the maximum value supported by the model.
The calculation of feature size is specific to the model. For more details, please refer to
the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.

We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.


To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:

* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

.. note::
Expand All @@ -57,8 +61,8 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS

.. code-block:: python
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
# Load the image using PIL.Image
image = ...
Expand All @@ -74,8 +78,6 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.

.. important::
We will remove the need to format image tokens in a future release. Afterwards, the input text will follow the same format as that for the original HuggingFace model.

Online OpenAI Vision API Compatible Inference
----------------------------------------------
Expand Down Expand Up @@ -103,6 +105,11 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
--chat-template template_llava.jinja
.. important::
Currently, you have to specify ``image_feature_size`` to support memory profiling.
To avoid OOM during runtime, you should set this to the maximum value supported by the model.
The calculation of feature size is specific to the model. For more details, please refer to
the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.

We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.

To consume the server, you can use the OpenAI client like in the example below:
Expand All @@ -121,6 +128,8 @@ To consume the server, you can use the OpenAI client like in the example below:
messages=[{
"role": "user",
"content": [
# NOTE: The prompt formatting with the image token `<image>` is not needed
# since the prompt will be processed automatically by the API server.
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
Expand All @@ -144,5 +153,4 @@ A full code example can be found in `examples/openai_vision_api_client.py <https
export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
.. note::
The prompt formatting with the image token ``<image>`` is not needed when serving VLMs with the API server since the prompt will be
processed automatically by the server.
There is no need to format the prompt in the API request since it will be handled by the server.
3 changes: 1 addition & 2 deletions examples/llava_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def run_llava():
image_feature_size=576,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

image = Image.open("images/stop_sign.jpg")

Expand Down
11 changes: 3 additions & 8 deletions examples/llava_next_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,17 @@

from vllm import LLM, SamplingParams

# Dynamic image input is currently not supported and therefore
# a fixed image input shape and its corresponding feature size is required.
# See https://github.com/vllm-project/vllm/pull/4199 for the complete
# configuration matrix.


def run_llava_next():
llm = LLM(
model="llava-hf/llava-v1.6-mistral-7b-hf",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=1176,
# Use the maximum possible value for memory profiling
image_feature_size=2928,
)

prompt = "[INST] " + "<image>" * 1176 + (
"\nWhat is shown in this image? [/INST]")
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
image = Image.open(BytesIO(requests.get(url).content))
sampling_params = SamplingParams(temperature=0.8,
Expand Down
8 changes: 5 additions & 3 deletions examples/phi3v_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from vllm import LLM, SamplingParams

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them


def run_phi3v():
model_path = "microsoft/Phi-3-vision-128k-instruct"
Expand All @@ -18,16 +21,15 @@ def run_phi3v():
trust_remote_code=True,
image_token_id=32044,
image_input_shape="1,3,1008,1344",
image_feature_size=1921,
# Use the maximum possible value for memory profiling
image_feature_size=2653,
max_num_seqs=5,
)

image = Image.open("images/cherry_blossom.jpg")

# single-image prompt
prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n" # noqa: E501
prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "<s>")

sampling_params = SamplingParams(temperature=0, max_tokens=64)

outputs = llm.generate(
Expand Down
Loading

0 comments on commit 9831aec

Please sign in to comment.