diff --git a/docs/source/dev/input_processing/model_inputs_index.rst b/docs/source/dev/input_processing/model_inputs_index.rst
index 594edeb746bb..2dde251aa144 100644
--- a/docs/source/dev/input_processing/model_inputs_index.rst
+++ b/docs/source/dev/input_processing/model_inputs_index.rst
@@ -8,7 +8,7 @@ Input Processing
 vLLM provides a mechanism for defining input processors for each model so that the
 inputs are processed in :class:`~vllm.LLMEngine` before they are passed to model executors.
 
-Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
+Currently, this mechanism is only utilized in :ref:`multi-modal models <multi_modality>` for preprocessing multi-modal input
 data in addition to input prompt, but it can be extended to text-only language models when needed.
 
 Guides
diff --git a/docs/source/dev/multimodal/adding_multimodal_model.rst b/docs/source/dev/multimodal/adding_multimodal_model.rst
new file mode 100644
index 000000000000..0e9590639b22
--- /dev/null
+++ b/docs/source/dev/multimodal/adding_multimodal_model.rst
@@ -0,0 +1,124 @@
+.. _adding_a_new_multimodal_model:
+
+Adding a New Multimodal Model
+=============================
+
+This document provides a high-level guide on integrating a :ref:`multi-modal model <multi_modality>` into vLLM.
+
+.. note::
+    The complexity of adding a new model depends heavily on the model's architecture.
+    The process is considerably more straightforward if the model shares a similar architecture with an existing model in vLLM.
+    However, for models that include new operators (e.g., a new attention mechanism), the process can be more complex.
+
+.. tip::
+    If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/vllm-project/vllm>`_ repository.
+    We will be happy to help you out!
+
+
+1. Set up the base vLLM model
+-----------------------------
+
+As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model in vLLM, but note the following:
+
+- You should additionally implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.
+
+  .. code-block:: diff
+
+    + from vllm.model_executor.models.interfaces import SupportsVision
+
+    - class YourModelForImage2Seq(nn.Module):
+    + class YourModelForImage2Seq(nn.Module, SupportsVision):
+
+  .. note::
+      The model class does not have to be named :code:`*ForCausalLM`.
+      Check out `the HuggingFace Transformers documentation `__ for some examples.
+
+- While implementing the :meth:`~torch.nn.Module.forward` method, reserve a keyword parameter
+  for each input tensor that corresponds to a multi-modal input, as shown in the following example:
+
+  .. code-block:: diff
+
+      def forward(
+          self,
+          input_ids: torch.Tensor,
+          positions: torch.Tensor,
+          kv_caches: List[torch.Tensor],
+          attn_metadata: AttentionMetadata,
+    +     pixel_values: torch.Tensor,
+      ) -> SamplerOutput:
+
+
+2. Register input mappers
+-------------------------
+
+For each modality type you want to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper `.
+This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.
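+For reference, a custom mapper is sketched below. This is only an illustration, not the exact API: the function name is a placeholder and the precise signature expected by the registry may differ across vLLM versions, so check the registry documentation before copying it.
+
+.. code-block:: python
+
+    import torch
+
+    from vllm.inputs import InputContext
+    from vllm.multimodal import MultiModalInputs
+
+    def map_input_for_your_model(ctx: InputContext,
+                                 data: object) -> MultiModalInputs:
+        # Run your preprocessing (e.g. the HuggingFace image processor) on the
+        # raw multi-modal data and return the resulting tensors under the same
+        # keyword names that ``forward`` expects (here, ``pixel_values``).
+        pixel_values: torch.Tensor = ...  # preprocess ``data`` into a tensor
+        return MultiModalInputs({"pixel_values": pixel_values})
+
+A function like this is passed to the corresponding ``register_*_input_mapper`` decorator; the decorators themselves are applied to the model class as follows: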
+
+.. code-block:: diff
+
+      from vllm.model_executor.models.interfaces import SupportsVision
+    + from vllm.multimodal import MULTIMODAL_REGISTRY
+
+    + @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+    + @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+      class YourModelForImage2Seq(nn.Module, SupportsVision):
+
+A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
+
+.. seealso::
+    :ref:`input_processing_pipeline`
+
+
+3. (Optional) Register dummy data
+---------------------------------
+
+During startup, dummy data is passed to the vLLM model to allocate memory. By default, this consists of text input only, which may not be applicable to multi-modal models.
+In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data `.
+
+.. code-block:: diff
+
+    + from vllm.inputs import INPUT_REGISTRY
+      from vllm.model_executor.models.interfaces import SupportsVision
+      from vllm.multimodal import MULTIMODAL_REGISTRY
+
+      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+    + @INPUT_REGISTRY.register_dummy_data()
+      class YourModelForImage2Seq(nn.Module, SupportsVision):
+
+Here are some examples:
+
+- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
+- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
+
+.. seealso::
+    :ref:`input_processing_pipeline`
+
+
+4. (Optional) Register input processor
+--------------------------------------
+
+Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
+This is often because, unlike the implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside the model's :meth:`~torch.nn.Module.forward` call.
+You can register input processors via :meth:`INPUT_REGISTRY.register_input_processor `.
+
+.. code-block:: diff
+
+      from vllm.inputs import INPUT_REGISTRY
+      from vllm.model_executor.models.interfaces import SupportsVision
+      from vllm.multimodal import MULTIMODAL_REGISTRY
+
+      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+      @INPUT_REGISTRY.register_dummy_data()
+    + @INPUT_REGISTRY.register_input_processor()
+      class YourModelForImage2Seq(nn.Module, SupportsVision):
+
+A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
+Here are some examples:
+
+- Insert static number of image tokens: `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
+- Insert dynamic number of image tokens: `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
+
+.. seealso::
+    :ref:`input_processing_pipeline`
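+
+To make steps 3 and 4 more concrete, here is a rough sketch of the registered callables, loosely following the LLaVA implementation in this repository. The function names are placeholders and the elided parts (``...``) depend on your model; treat this as an illustration rather than the definitive API.
+
+.. code-block:: python
+
+    from vllm.inputs import InputContext, LLMInputs
+
+    def dummy_data_for_your_model(ctx: InputContext, seq_len: int):
+        # Return dummy sequence data and dummy multi-modal data that are at
+        # least as large as the biggest input expected at runtime; vLLM uses
+        # them to profile memory usage during startup.
+        seq_data = ...  # e.g. a sequence of image placeholder tokens
+        mm_data = ...   # e.g. {"image": a maximum-size dummy image}
+        return seq_data, mm_data
+
+    def input_processor_for_your_model(ctx: InputContext,
+                                       llm_inputs: LLMInputs) -> LLMInputs:
+        multi_modal_data = llm_inputs.get("multi_modal_data")
+        if multi_modal_data is None or "image" not in multi_modal_data:
+            # Text-only request: nothing to do.
+            return llm_inputs
+
+        # Typically: repeat/pad the image placeholder tokens so that the
+        # prompt reserves one position per image feature.
+        new_prompt = ...
+        new_token_ids = ...
+        return LLMInputs(prompt_token_ids=new_token_ids,
+                         prompt=new_prompt,
+                         multi_modal_data=multi_modal_data)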
diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst
index 4d5fb3246b68..d01f39284377 100644
--- a/docs/source/dev/multimodal/multimodal_index.rst
+++ b/docs/source/dev/multimodal/multimodal_index.rst
@@ -1,3 +1,5 @@
+.. _multi_modality:
+
 Multi-Modality
 ==============
 
@@ -8,12 +10,18 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
 :class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
 which allows you to pass in multi-modal input alongside text and token prompts.
 
-By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
-you must decorate the model class with :meth:`InputRegistry.register_dummy_data `,
-as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper ` for each modality type to support.
+By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, please follow :ref:`the guide for adding a new multimodal model <adding_a_new_multimodal_model>`.
 
 # TODO: Add more instructions on how to do that once embeddings is in.
 
+Guides
+++++++
+
+.. toctree::
+   :maxdepth: 1
+
+   adding_multimodal_model
+
 Module Contents
 +++++++++++++++
 
@@ -35,6 +43,10 @@ Base Classes
    :members:
    :show-inheritance:
 
+.. autoclass:: vllm.multimodal.MultiModalInputs
+    :members:
+    :show-inheritance:
+
 .. autoclass:: vllm.multimodal.MultiModalPlugin
     :members:
     :show-inheritance:
diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst
index 053f5b8609ce..f8c61018a08d 100644
--- a/docs/source/models/vlm.rst
+++ b/docs/source/models/vlm.rst
@@ -23,7 +23,6 @@ The following :ref:`engine arguments ` are specific to VLMs:
 Currently, the support for vision language models on vLLM has the following limitations:
 
 * Only single image input is supported per text prompt.
-* Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means our LLaVA-NeXT output may not exactly match the huggingface implementation.
 
 We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub `_ if you have any feedback or feature requests.
 
@@ -42,12 +41,17 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
     )
 
 .. important::
+    Currently, you have to specify ``image_feature_size`` to support memory profiling.
+    To avoid OOM during runtime, you should set this to the maximum value supported by the model.
+    The calculation of feature size is specific to the model. For more details, please refer to
+    the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
+
     We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
 
 To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
 
-* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
+* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
 
 .. note::
@@ -57,8 +61,8 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
 
     .. code-block:: python
 
-        prompt = "<image>" * 576 + (
-            "\nUSER: What is the content of this image?\nASSISTANT:")
+        # Refer to the HuggingFace repo for the correct format to use
+        prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
 
         # Load the image using PIL.Image
         image = ...
@@ -74,8 +78,6 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
 
 A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
 
-.. important::
-    We will remove the need to format image tokens in a future release. Afterwards, the input text will follow the same format as that for the original HuggingFace model.
 
 Online OpenAI Vision API Compatible Inference
----------------------------------------------
@@ -103,6 +105,11 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
         --chat-template template_llava.jinja
 
 .. important::
+    Currently, you have to specify ``image_feature_size`` to support memory profiling.
+    To avoid OOM during runtime, you should set this to the maximum value supported by the model.
+    The calculation of feature size is specific to the model. For more details, please refer to
+    the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
+
     We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
 
 To consume the server, you can use the OpenAI client like in the example below:
@@ -121,6 +128,8 @@ To consume the server, you can use the OpenAI client like in the example below:
         messages=[{
             "role": "user",
             "content": [
+                # NOTE: The prompt formatting with the image token `<image>` is not needed
+                # since the prompt will be processed automatically by the API server.
                 {"type": "text", "text": "What's in this image?"},
                 {
                     "type": "image_url",
@@ -144,5 +153,4 @@ A full code example can be found in `examples/openai_vision_api_client.py
 
 .. note::
-    The prompt formatting with the image token ``<image>`` is not needed when serving VLMs with the API server since the prompt will be
-    processed automatically by the server.
+    There is no need to format the prompt in the API request since it will be handled by the server.
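+
+.. tip::
+    If the image is stored locally instead of behind a URL, you can send it as a ``data:`` URL.
+    The snippet below is a sketch that reuses the ``client`` and model name from the example above and the standard OpenAI ``image_url`` content format; adjust the MIME type to match your file.
+
+.. code-block:: python
+
+    import base64
+
+    # Encode the local image as a base64 data URL
+    with open("stop_sign.jpg", "rb") as f:
+        base64_image = base64.b64encode(f.read()).decode("utf-8")
+
+    chat_completion = client.chat.completions.create(
+        model="llava-hf/llava-1.5-7b-hf",
+        messages=[{
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "What's in this image?"},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
+                },
+            ],
+        }],
+    )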
diff --git a/examples/llava_example.py b/examples/llava_example.py
index 7f3d84f99f76..f5cb2a661e83 100644
--- a/examples/llava_example.py
+++ b/examples/llava_example.py
@@ -17,8 +17,7 @@ def run_llava():
         image_feature_size=576,
     )
 
-    prompt = "<image>" * 576 + (
-        "\nUSER: What is the content of this image?\nASSISTANT:")
+    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
 
     image = Image.open("images/stop_sign.jpg")
 
diff --git a/examples/llava_next_example.py b/examples/llava_next_example.py
index 3c39590e7fb8..20d4791ffaf9 100644
--- a/examples/llava_next_example.py
+++ b/examples/llava_next_example.py
@@ -5,22 +5,17 @@
 from vllm import LLM, SamplingParams
 
-# Dynamic image input is currently not supported and therefore
-# a fixed image input shape and its corresponding feature size is required.
-# See https://github.com/vllm-project/vllm/pull/4199 for the complete
-# configuration matrix.
-
 def run_llava_next():
     llm = LLM(
         model="llava-hf/llava-v1.6-mistral-7b-hf",
         image_token_id=32000,
         image_input_shape="1,3,336,336",
-        image_feature_size=1176,
+        # Use the maximum possible value for memory profiling
+        image_feature_size=2928,
     )
 
-    prompt = "[INST] " + "<image>" * 1176 + (
-        "\nWhat is shown in this image? [/INST]")
+    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
     url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
     image = Image.open(BytesIO(requests.get(url).content))
     sampling_params = SamplingParams(temperature=0.8,
diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py
index 7d6c58d7fcd8..0aabfee6ab63 100644
--- a/examples/phi3v_example.py
+++ b/examples/phi3v_example.py
@@ -5,6 +5,9 @@
 from vllm import LLM, SamplingParams
 
+# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
+# You can use `.buildkite/download-images.sh` to download them + def run_phi3v(): model_path = "microsoft/Phi-3-vision-128k-instruct" @@ -18,7 +21,8 @@ def run_phi3v(): trust_remote_code=True, image_token_id=32044, image_input_shape="1,3,1008,1344", - image_feature_size=1921, + # Use the maximum possible value for memory profiling + image_feature_size=2653, max_num_seqs=5, ) @@ -26,8 +30,6 @@ def run_phi3v(): # single-image prompt prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n" # noqa: E501 - prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "") - sampling_params = SamplingParams(temperature=0, max_tokens=64) outputs = llm.generate( diff --git a/tests/conftest.py b/tests/conftest.py index fd088d566d7a..608a5f49d593 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,13 @@ import contextlib import gc import os +import sys from collections import UserList from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, - TypedDict, TypeVar) +from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict, + TypeVar) import pytest import torch @@ -22,13 +23,10 @@ destroy_model_parallel) from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm.multimodal.utils import fetch_image from vllm.sequence import SampleLogprobs from vllm.utils import cuda_device_count_stateless, is_cpu -if TYPE_CHECKING: - # it will call torch.cuda.device_count() - from vllm.multimodal import MultiModalDataDict - logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) @@ -47,30 +45,42 @@ def _read_prompts(filename: str) -> List[str]: @dataclass(frozen=True) class ImageAsset: - name: Literal["stop_sign", "cherry_blossom"] + name: Literal["stop_sign", "cherry_blossom", "boardwalk"] @cached_property def pil_image(self) -> Image.Image: - return Image.open(_IMAGE_DIR / f"{self.name}.jpg") - - def for_hf(self) -> Image.Image: - return self.pil_image + if self.name == "boardwalk": + return fetch_image( + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + ) - def for_vllm(self) -> Dict[str, Any]: - return {"image": self.pil_image} + return Image.open(_IMAGE_DIR / f"{self.name}.jpg") class _ImageAssetPrompts(TypedDict): stop_sign: str cherry_blossom: str + boardwalk: str + + +if sys.version_info < (3, 9): + # UserList cannot be subscripted + class _ImageAssetsBase(UserList): + pass +else: + class _ImageAssetsBase(UserList[ImageAsset]): + pass -class _ImageAssets(UserList): + +class _ImageAssets(_ImageAssetsBase): def __init__(self) -> None: - super().__init__( - [ImageAsset("stop_sign"), - ImageAsset("cherry_blossom")]) + super().__init__([ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ImageAsset("boardwalk") + ]) def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: """ @@ -79,7 +89,10 @@ def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: The order of the returned prompts matches the order of the assets when iterating through this object. 
""" - return [prompts["stop_sign"], prompts["cherry_blossom"]] + return [ + prompts["stop_sign"], prompts["cherry_blossom"], + prompts["boardwalk"] + ] IMAGE_ASSETS = _ImageAssets() @@ -220,7 +233,7 @@ def generate( self, prompts: List[str], images: Optional[List[Image.Image]] = None, - **kwargs, + **kwargs: Any, ) -> List[Tuple[List[List[int]], List[str]]]: if images: assert len(prompts) == len(images) @@ -255,7 +268,7 @@ def generate_greedy( prompts: List[str], max_tokens: int, images: Optional[List[Image.Image]] = None, - **kwargs, + **kwargs: Any, ) -> List[Tuple[List[int], str]]: outputs = self.generate(prompts, do_sample=False, @@ -291,19 +304,30 @@ def generate_greedy_logprobs( self, prompts: List[str], max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, ) -> List[List[torch.Tensor]]: - all_logprobs = [] - for prompt in prompts: - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + all_logprobs: List[List[torch.Tensor]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + output = self.model.generate( - self.wrap_device(input_ids), + **self.wrap_device(inputs), use_cache=True, do_sample=False, max_new_tokens=max_tokens, output_hidden_states=True, return_dict_in_generate=True, + **kwargs, ) - seq_logprobs = [] + seq_logprobs: List[torch.Tensor] = [] for hidden_states in output.hidden_states: last_hidden_states = hidden_states[-1][0] logits = torch.matmul( @@ -323,20 +347,32 @@ def generate_greedy_logprobs_limit( prompts: List[str], max_tokens: int, num_logprobs: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: all_logprobs: List[List[Dict[int, float]]] = [] all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] - for prompt in prompts: - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + input_ids = inputs.input_ids + output = self.model.generate( - self.wrap_device(input_ids), + **self.wrap_device(inputs), use_cache=True, do_sample=False, max_new_tokens=max_tokens, output_hidden_states=True, return_dict_in_generate=True, + **kwargs, ) seq_logprobs: List[torch.Tensor] = [] @@ -431,7 +467,7 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[List["MultiModalDataDict"]] = None, + images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: assert len(prompts) == len(images) @@ -439,7 +475,7 @@ def generate( inputs = [TextPrompt(prompt=prompt) for prompt in prompts] if images is not None: for i, image in enumerate(images): - inputs[i]["multi_modal_data"] = image + inputs[i]["multi_modal_data"] = {"image": image} req_outputs = self.model.generate(inputs, sampling_params=sampling_params) @@ -462,10 +498,19 @@ def generate_w_logprobs( self, prompts: List[str], sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: assert sampling_params.logprobs is not None - req_outputs = 
self.model.generate(prompts, + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] for req_output in req_outputs: @@ -480,7 +525,7 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, - images: Optional[List["MultiModalDataDict"]] = None, + images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, images=images) @@ -492,11 +537,14 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, + images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) - outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params) + outputs = self.generate_w_logprobs(prompts, + greedy_logprobs_params, + images=images) return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py index 41c3fd9e7f6b..1d143a8526f4 100644 --- a/tests/distributed/test_multimodal_broadcast.py +++ b/tests/distributed/test_multimodal_broadcast.py @@ -30,9 +30,10 @@ @pytest.mark.parametrize("tensor_parallel_size", [2]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, image_assets, - tensor_parallel_size: int, dtype: str, - max_tokens: int) -> None: + tensor_parallel_size: int, dtype: str, max_tokens: int, + num_logprobs: int) -> None: if cuda_device_count_stateless() < tensor_parallel_size: pytest.skip( f"Need at least {tensor_parallel_size} GPUs to run the test.") @@ -44,8 +45,10 @@ def test_models(hf_runner, vllm_runner, image_assets, vllm_runner, image_assets, model_and_config=model_and_vl_config[0], + size_factors=[1.0], dtype=dtype, max_tokens=max_tokens, + num_logprobs=num_logprobs, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, ) diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index c6313c52e4e3..2f4b85bc1617 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -4,18 +4,21 @@ from transformers import AutoTokenizer from vllm.config import VisionLanguageConfig +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_outputs_equal +from .utils import check_logprobs_close pytestmark = pytest.mark.vlm -# The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": - "\nUSER: What's the content of the image?\nASSISTANT:", + "USER: \nWhat's the content of the image?\nASSISTANT:", "cherry_blossom": - "\nUSER: What is the season?\nASSISTANT:", + "USER: \nWhat is the season?\nASSISTANT:", + "boardwalk": + "USER: \nWhat's in this image?\nASSISTANT:", }) @@ -37,27 +40,34 @@ def iter_llava_configs(model_name: 
str): ] -def vllm_to_hf_output(vllm_output: Tuple[List[int], str], +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], vlm_config: VisionLanguageConfig, model_id: str): """Sanitize vllm output to be comparable with hf output. The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... It also reduces `output_str` from "bla" to "bla". """ - output_ids, output_str = vllm_output + output_ids, output_str, out_logprobs = vllm_output image_token_id = vlm_config.image_token_id tokenizer = AutoTokenizer.from_pretrained(model_id) image_token_str = tokenizer.decode(image_token_id) + eos_token_id = tokenizer.eos_token_id hf_output_ids = [ token_id for idx, token_id in enumerate(output_ids) if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] + hf_output_str = output_str \ .replace(image_token_str * vlm_config.image_feature_size, "") + assert hf_output_str[0] == " " + hf_output_str = hf_output_str[1:] + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) - return hf_output_ids, hf_output_str + return hf_output_ids, hf_output_str, out_logprobs def run_test( @@ -66,8 +76,10 @@ def run_test( image_assets: _ImageAssets, model_and_config: Tuple[str, VisionLanguageConfig], *, + size_factors: List[float], dtype: str, max_tokens: int, + num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): @@ -81,61 +93,85 @@ def run_test( The text output is sanitized to be able to compare with hf. """ model_id, vlm_config = model_and_config - hf_images = [asset.for_hf() for asset in image_assets] + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). + # max_model_len should be greater than image_feature_size with vllm_runner(model_id, dtype=dtype, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True, **vlm_config.as_cli_args_dict()) as vllm_model: - - # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()` - # we must put it inside the vllm_runner context manager - # i.e. after creating vLLM instance. 
- vllm_images = [asset.for_vllm() for asset in image_assets] - - vllm_image_prompts = [ - p.replace("", "" * vlm_config.image_feature_size) - for p in HF_IMAGE_PROMPTS + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image ] - vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, - max_tokens, - images=vllm_images) - with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model: - hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, - max_tokens, - images=hf_images) - - check_outputs_equal( - hf_outputs, - [ - vllm_to_hf_output(vllm_output, vlm_config, model_id) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + # TODO: Check whether using original CLIPVisionModel can improve + # consistency against HF + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, vlm_config, model_id) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) @pytest.mark.parametrize("model_and_config", model_and_vl_config) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, image_assets, model_and_config, - dtype: str, max_tokens: int) -> None: + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: run_test( hf_runner, vllm_runner, image_assets, model_and_config, + size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, + num_logprobs=num_logprobs, tensor_parallel_size=1, ) diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index e9babba13c47..8817f41a62f7 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,12 +1,15 @@ -from typing import List, Tuple +import re +from typing import List, Optional, Tuple import pytest from transformers import AutoTokenizer from vllm.config import VisionLanguageConfig +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs from ..conftest import IMAGE_ASSETS -from .utils import check_outputs_equal +from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -15,21 +18,20 @@ "The assistant gives helpful, detailed, and polite answers to the human's " "questions.") -# The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": - f"{_PREFACE} \nUSER: What's the content of the image?\nASSISTANT:", + f"{_PREFACE} USER: \nWhat's the content of the image? ASSISTANT:", "cherry_blossom": - f"{_PREFACE} \nUSER: What is the season?\nASSISTANT:", + f"{_PREFACE} USER: \nWhat is the season? ASSISTANT:", + "boardwalk": + f"{_PREFACE} USER: \nWhat's in this image? 
ASSISTANT:", }) def iter_llava_next_configs(model_name: str): + # Need to use the max possible feature size for profile_run image_hw_to_feature_size = { - (336, 336): 1176, - (672, 672): 2928, - (1344, 336): 1944, - (336, 1344): 1890, + (336, 336): 2928, } for (h, w), f in image_hw_to_feature_size.items(): @@ -47,37 +49,55 @@ def iter_llava_next_configs(model_name: str): ] -def vllm_to_hf_output(vllm_output: Tuple[List[int], str], +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], vlm_config: VisionLanguageConfig, model_id: str): """Sanitize vllm output to be comparable with hf output. The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... It also reduces `output_str` from "bla" to "bla". """ - output_ids, output_str = vllm_output + output_ids, output_str, out_logprobs = vllm_output image_token_id = vlm_config.image_token_id tokenizer = AutoTokenizer.from_pretrained(model_id) image_token_str = tokenizer.decode(image_token_id) + eos_token_id = tokenizer.eos_token_id hf_output_ids = [ token_id for idx, token_id in enumerate(output_ids) if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] - hf_output_str = output_str \ - .replace(image_token_str * vlm_config.image_feature_size, " ") - return hf_output_ids, hf_output_str + hf_output_str = re.sub(fr"({image_token_str})+", "", output_str) + assert hf_output_str[0] == " " + hf_output_str = hf_output_str[1:] + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) + + return hf_output_ids, hf_output_str, out_logprobs -@pytest.mark.xfail( - reason="Inconsistent image processor being used due to lack " - "of support for dynamic image token replacement") @pytest.mark.parametrize("model_and_config", model_and_vl_config) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, image_assets, model_and_config, - dtype: str, max_tokens: int) -> None: + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: """Inference result should be the same between hf and vllm. All the image fixtures for the test is under tests/images. @@ -88,37 +108,46 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config, The text output is sanitized to be able to compare with hf. 
""" model_id, vlm_config = model_and_config - hf_images = [asset.for_hf() for asset in image_assets] - vllm_images = [asset.for_vllm() for asset in image_assets] + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + # max_model_len should be greater than image_feature_size + with vllm_runner(model_id, + dtype=dtype, + max_model_len=4096, + enforce_eager=True, + **vlm_config.as_cli_args_dict()) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model: - hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, - max_tokens, - images=hf_images) - - vllm_image_prompts = [ - p.replace("", "" * vlm_config.image_feature_size) - for p in HF_IMAGE_PROMPTS - ] - - with vllm_runner( - model_id, - dtype=dtype, - # should be greater than image_feature_size - max_model_len=4096, - enforce_eager=True, - **vlm_config.as_cli_args_dict(), - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, - max_tokens, - images=vllm_images) - - check_outputs_equal( - hf_outputs, - [ - vllm_to_hf_output(vllm_output, vlm_config, model_id) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + # TODO: Check whether using original CLIPVisionModel can improve + # consistency against HF + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, vlm_config, model_id) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 917bdbf94ab9..f144f97551c0 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -1,29 +1,33 @@ +import re from typing import List, Optional, Tuple, Type import pytest from transformers import AutoTokenizer from vllm.config import VisionLanguageConfig +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_outputs_equal +from .utils import check_logprobs_close pytestmark = pytest.mark.vlm -# The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 "cherry_blossom": - "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", # noqa: E501 + "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", + "boardwalk": + "<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n", }) def iter_phi3v_configs(model_name: str): + # Need to use the max possible feature size for profile_run image_hw_to_feature_size = { - (1008, 1344): 1921, - (2016, 2688): 1933, + (1008, 1344): 2653, } for (h, w), f in image_hw_to_feature_size.items(): @@ -39,29 +43,29 @@ def iter_phi3v_configs(model_name: 
str): ] -def vllm_to_hf_output(vllm_output: Tuple[List[int], str], +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], vlm_config: VisionLanguageConfig, model_id: str): """Sanitize vllm output to be comparable with hf output. The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... It also reduces `output_str` from "bla" to "bla". """ - output_ids, output_str = vllm_output - image_token_id = vlm_config.image_token_id + output_ids, output_str, out_logprobs = vllm_output - tokenizer = AutoTokenizer.from_pretrained(model_id) - image_token_str = tokenizer.decode(image_token_id) - - hf_output_ids = [ - token_id if token_id != image_token_id else 0 - for idx, token_id in enumerate(output_ids) - ] - hf_output_str = output_str \ - .replace(image_token_str * vlm_config.image_feature_size, "") \ - .replace("", " ").replace("<|user|>", "") \ + output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str) + assert output_str_without_image[0] == " " + output_str_without_image = output_str_without_image[1:] + + hf_output_str = output_str_without_image.replace("<|user|>", "") \ .replace("<|end|>\n<|assistant|>", " ") - return hf_output_ids, hf_output_str + tokenizer = AutoTokenizer.from_pretrained(model_id) + hf_output_ids = tokenizer.encode(output_str_without_image) + assert hf_output_ids[0] == 1 + hf_output_ids = hf_output_ids[1:] + + return hf_output_ids, hf_output_str, out_logprobs target_dtype = "half" @@ -75,8 +79,10 @@ def run_test( image_assets: _ImageAssets, model_and_config: Tuple[str, VisionLanguageConfig], *, + size_factors: List[float], dtype: str, max_tokens: int, + num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): @@ -90,73 +96,91 @@ def run_test( The text output is sanitized to be able to compare with hf. """ model_id, vlm_config = model_and_config - hf_images = [asset.for_hf() for asset in image_assets] + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). + # max_model_len should be greater than image_feature_size with vllm_runner(model_id, - max_model_len=2048, + max_model_len=4096, dtype=dtype, tensor_parallel_size=tensor_parallel_size, - enforce_eager=True, distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, **vlm_config.as_cli_args_dict()) as vllm_model: - # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()` - # we must put it inside the vllm_runner context manager - # i.e. after creating vLLM instance. 
- - vllm_images = [asset.for_vllm() for asset in image_assets] - - vllm_image_prompts = [ - p.replace("<|image_1|>", - "<|image|>" * vlm_config.image_feature_size + "") - for p in HF_IMAGE_PROMPTS + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=vllm_images) + for prompts, vllm_images in inputs_per_image ] - vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, - max_tokens, - images=vllm_images) - # use eager mode for hf runner, since phi3_v didn't work with flash_attn hf_model_kwargs = {"_attn_implementation": "eager"} with hf_runner(model_id, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model: - hf_outputs = hf_model.generate_greedy( - HF_IMAGE_PROMPTS, - max_tokens, - images=hf_images, - eos_token_id=hf_model.processor.tokenizer.eos_token_id) - - check_outputs_equal( - hf_outputs, - [ - vllm_to_hf_output(vllm_output, vlm_config, model_id) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - + eos_token_id = hf_model.processor.tokenizer.eos_token_id + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=hf_images, + eos_token_id=eos_token_id) + for prompts, hf_images in inputs_per_image + ] -# Since we use _attn_implementation="eager" for hf_runner, here is -# numeric difference for longer context and test can't pass -@pytest.mark.xfail( - reason="Inconsistent image processor being used due to lack " - "of support for dynamic image token replacement") + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, vlm_config, model_id) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +# Since we use _attn_implementation="eager" for hf_runner, there is more +# significant numerical difference. The basic `logprobs=5` fails to pass. @pytest.mark.parametrize("model_and_config", model_and_vl_config) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) def test_models(hf_runner, vllm_runner, image_assets, model_and_config, - dtype: str, max_tokens: int) -> None: + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: run_test( hf_runner, vllm_runner, image_assets, model_and_config, + size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, + num_logprobs=num_logprobs, tensor_parallel_size=1, ) diff --git a/tests/models/utils.py b/tests/models/utils.py index 0d5e304d8446..51d57129d9d2 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,11 +1,18 @@ -from typing import Dict, List, Tuple +import warnings +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from vllm.sequence import SampleLogprobs TokensText = Tuple[List[int], str] -def check_outputs_equal(outputs_0_lst: List[TokensText], - outputs_1_lst: List[TokensText], name_0: str, - name_1: str): +def check_outputs_equal( + *, + outputs_0_lst: Sequence[TokensText], + outputs_1_lst: Sequence[TokensText], + name_0: str, + name_1: str, +): """ Compare the two sequences generated by different models, which should be equal. 
@@ -18,20 +25,28 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], output_ids_0, output_str_0 = outputs_0 output_ids_1, output_str_1 = outputs_1 - assert output_str_0 == output_str_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + # The text and token outputs should exactly match + fail_msg = (f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + assert output_str_0 == output_str_1, fail_msg + assert output_ids_0 == output_ids_1, fail_msg -TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] +TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int, + float]], + SampleLogprobs]]] -def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], - outputs_1_lst: List[TokensTextLogprobs], name_0: str, - name_1: str): +def check_logprobs_close( + *, + outputs_0_lst: Sequence[TokensTextLogprobs], + outputs_1_lst: Sequence[TokensTextLogprobs], + name_0: str, + name_1: str, + warn_on_mismatch: bool = True, +): """ Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. @@ -45,21 +60,52 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], output_ids_0, output_str_0, logprobs_0 = outputs_0 output_ids_1, output_str_1, logprobs_1 = outputs_1 + if logprobs_0 is None: + logprobs_0 = [None] * len(output_ids_0) + if logprobs_1 is None: + logprobs_1 = [None] * len(output_ids_1) + # Loop through generated tokens. for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): # If generated tokens don't match, then if output_id_0 != output_id_1: + logprobs_elem_0 = logprobs_0[idx] + logprobs_elem_1 = logprobs_1[idx] + # Each predicted token must be in top N logprobs of the other - assert output_id_0 in logprobs_1[idx], ( - f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_id_1 in logprobs_0[idx], ( + fail_msg = ( f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}" + f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}") + + assert logprobs_elem_0 is not None, fail_msg + assert logprobs_elem_1 is not None, fail_msg + assert output_id_0 in logprobs_elem_1, fail_msg + assert output_id_1 in logprobs_elem_0, fail_msg + + if warn_on_mismatch: + with warnings.catch_warnings(): + # This ensures that repeated warnings are shown + # in the output, not just the first occurrence + warnings.simplefilter("always") + + warnings.warn(fail_msg, stacklevel=2) # Break out since sequences will now diverge. 
break + else: + if output_str_0 != output_str_1 and warn_on_mismatch: + # The token outputs exactly match, + # so the text outputs should exactly match as well + fail_msg = (f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + with warnings.catch_warnings(): + # This ensures that repeated warnings are shown + # in the output, not just the first occurrence + warnings.simplefilter("always") + + warnings.warn(fail_msg, stacklevel=2) diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py index bdbbd9abfc5c..321566ad53a5 100644 --- a/tests/multimodal/test_mapper.py +++ b/tests/multimodal/test_mapper.py @@ -4,12 +4,12 @@ from vllm.config import ModelConfig from vllm.multimodal import MULTIMODAL_REGISTRY - -from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE +from vllm.multimodal.utils import rescale_image_size @pytest.mark.parametrize("dtype", ["half", "float"]) -def test_clip_image_processor(image_assets, dtype): +@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) +def test_clip_image_processor(image_assets, dtype, size_factor): MODEL_NAME = "llava-hf/llava-1.5-7b-hf" hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) @@ -26,13 +26,15 @@ def test_clip_image_processor(image_assets, dtype): ) for asset in image_assets: + image = rescale_image_size(asset.pil_image, size_factor) + hf_result = hf_processor.preprocess( - asset.pil_image, + image, return_tensors="pt", - ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) + ) vllm_result = MULTIMODAL_REGISTRY.map_input( model_config, - {"image": asset.pil_image}, + {"image": image}, ) assert hf_result.keys() == vllm_result.keys() @@ -44,12 +46,10 @@ def test_clip_image_processor(image_assets, dtype): assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" -@pytest.mark.xfail( - reason="Inconsistent image processor being used due to lack " - "of support for dynamic image token replacement") @pytest.mark.parametrize("dtype", ["half", "float"]) -def test_llava_next_image_processor(image_assets, dtype): - MODEL_NAME = "llava-hf/llava-v1.6-34b-hf" +@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) +def test_llava_next_image_processor(image_assets, dtype, size_factor): + MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf" hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME) assert isinstance(hf_processor, LlavaNextImageProcessor) @@ -65,13 +65,15 @@ def test_llava_next_image_processor(image_assets, dtype): ) for asset in image_assets: + image = rescale_image_size(asset.pil_image, size_factor) + hf_result = hf_processor.preprocess( - asset.pil_image, + image, return_tensors="pt", - ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) + ) vllm_result = MULTIMODAL_REGISTRY.map_input( model_config, - {"image": asset.pil_image}, + {"image": image}, ) assert hf_result.keys() == vllm_result.keys() @@ -81,36 +83,3 @@ def test_llava_next_image_processor(image_assets, dtype): assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" - - -@pytest.mark.xfail( - reason="Example image pixels were not processed using HuggingFace") -@pytest.mark.parametrize("dtype", ["float"]) -def test_image_pixel_types(image_assets, dtype): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - - model_config = ModelConfig( - model=MODEL_NAME, - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype=dtype, - revision=None, - ) - for asset in image_assets: - image_result = MULTIMODAL_REGISTRY.map_input( - 
model_config, - {"image": asset.pil_image}, - ) - tensor_result = MULTIMODAL_REGISTRY.map_input( - model_config, - {"image": asset.pil_image}, - ) - - assert image_result.keys() == tensor_result.keys() - for key, image_arr in image_result.items(): - tensor_arr: np.ndarray = tensor_result[key].numpy() - - assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}" - assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 5a6395ac9e42..10cabdadb1dc 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -5,10 +5,9 @@ import numpy as np import pytest -import pytest_asyncio from PIL import Image -from vllm.multimodal.utils import ImageFetchAiohttp +from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -19,12 +18,9 @@ ] -@pytest_asyncio.fixture(scope="session") -async def url_images() -> Dict[str, Image.Image]: - return { - image_url: await ImageFetchAiohttp.fetch_image(image_url) - for image_url in TEST_IMAGE_URLS - } +@pytest.fixture(scope="module") +def url_images() -> Dict[str, Image.Image]: + return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS} def get_supported_suffixes() -> Tuple[str, ...]: @@ -41,7 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool: return (np.asarray(a) == np.asarray(b.convert(a.mode))).all() -@pytest.mark.asyncio +@pytest.mark.asyncio(scope="module") +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_fetch_image_http(image_url: str): + image_sync = fetch_image(image_url) + image_async = await ImageFetchAiohttp.fetch_image(image_url) + assert _image_equals(image_sync, image_async) + + +@pytest.mark.asyncio(scope="module") @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("suffix", get_supported_suffixes()) async def test_fetch_image_base64(url_images: Dict[str, Image.Image], @@ -68,8 +72,11 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], base64_image = base64.b64encode(f.read()).decode("utf-8") data_url = f"data:{mime_type};base64,{base64_image}" - data_image = await ImageFetchAiohttp.fetch_image(data_url) + data_image_sync = fetch_image(data_url) if _image_equals(url_image, Image.open(f)): - assert _image_equals(url_image, data_image) + assert _image_equals(url_image, data_image_sync) else: pass # Lossy format; only check that image can be opened + + data_image_async = await ImageFetchAiohttp.fetch_image(data_url) + assert _image_equals(data_image_sync, data_image_async) diff --git a/vllm/config.py b/vllm/config.py index 8c449323f7a1..de8e119c9498 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -5,7 +5,7 @@ Union) import torch -from transformers import PretrainedConfig, PreTrainedTokenizerBase +from transformers import PretrainedConfig import vllm.envs as envs from vllm.logger import init_logger @@ -1303,16 +1303,6 @@ class VisionLanguageConfig: image_input_shape: tuple image_feature_size: int - #TODO(ywang96): make this a cached property once we refactor the - # VisionLanguageConfig class. - def get_image_token_text( - self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]: - """Get the image token placeholder text to be inserted into the - text prompt and the string representation of the image token id. 
- """ - image_token_str = tokenizer.decode(self.image_token_id) - return image_token_str * self.image_feature_size, image_token_str - def as_cli_args_dict(self) -> Dict[str, Any]: """Flatten vision language config to pure args. diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e5b6b7f573a2..57ad7bdd3105 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,6 +1,7 @@ import codecs import time from dataclasses import dataclass, field +from functools import cached_property from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional) from typing import Sequence as GenericSequence @@ -10,7 +11,7 @@ from openai.types.chat import (ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam) -from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( ChatCompletionContentPartParam, ChatCompletionLogProb, @@ -27,8 +28,7 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import (async_get_and_parse_image, - get_full_image_text_prompt) +from vllm.multimodal.utils import async_get_and_parse_image from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, @@ -97,6 +97,36 @@ def _load_chat_template(self, chat_template: Optional[str]): logger.warning( "No chat template provided. Chat API will not work.") + @cached_property + def image_token_str(self) -> Optional[str]: + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + model_type = self.model_config.hf_config.model_type + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return "<|image_1|>" + if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", + "paligemma"): + # These models do not use image tokens in the prompt + return None + + # The default behaviour assumes that the image token is + # available to the tokenizer. 
+ # (Suitable for LLaVA, Idefics2, DeepSeek-VL) + vlm_config = self.model_config.multimodal_config + if vlm_config is None: + raise ValueError( + "'image_url' input is not supported as the loaded " + "model is not multimodal.") + + image_token_id = vlm_config.image_token_id + if vlm_config.image_token_id is None: + raise ValueError( + "'image_url' input is not supported as the loaded " + "model does not specify an image token.") + + return self.tokenizer.decode(image_token_id) + def _parse_chat_message_content_parts( self, role: str, @@ -105,21 +135,26 @@ def _parse_chat_message_content_parts( texts: List[str] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] - vlm_config: Optional[VisionLanguageConfig] = getattr( - self.engine.engine, "vision_language_config", None) - model_config = getattr(self.engine.engine, "model_config", None) - for part in parts: part_type = part["type"] if part_type == "text": text = cast(ChatCompletionContentPartTextParam, part)["text"] texts.append(text) elif part_type == "image_url": - if vlm_config is None: - raise ValueError( - "'image_url' input is not supported as the loaded " - "model is not multimodal.") - assert self.tokenizer is not None + if len(mm_futures) > 0: + raise NotImplementedError( + "Multiple 'image_url' input is currently not supported." + ) + + image_token_str = self.image_token_str + if image_token_str is not None: + if any(image_token_str in text for text in texts): + logger.warning( + "Detected image token string in the text prompt. " + "Skipping prompt formatting.") + else: + texts.append(image_token_str) + image_url = cast(ChatCompletionContentPartImageParam, part)["image_url"] @@ -128,43 +163,13 @@ def _parse_chat_message_content_parts( "'image_url.detail' is currently not supported and " "will be ignored.") - mm_future = async_get_and_parse_image(image_url["url"]) - mm_futures.append(mm_future) - + image_future = async_get_and_parse_image(image_url["url"]) + mm_futures.append(image_future) else: raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) - - if vlm_config is not None and len(mm_futures): - - assert len( - mm_futures - ) == 1, "Multiple 'image_url' input is currently not supported." - (image_token_prompt, - image_token_str) = vlm_config.get_image_token_text(self.tokenizer) - - # NOTE: If image token string (e.g, ) is already present - # in the text prompt, we assume it follows the same format required - # by the engine. - if image_token_str in text_prompt: - logger.warning( - "Detected image token string in the text prompt. 
" - "Skipping prompt formatting.") - messages = [ - ConversationMessage(role=role, content=text_prompt) - ] - - else: - full_prompt = get_full_image_text_prompt( - image_prompt=image_token_prompt, - text_prompt=text_prompt, - config=model_config) - messages = [ - ConversationMessage(role=role, content=full_prompt) - ] - else: - messages = [ConversationMessage(role=role, content=text_prompt)] + messages = [ConversationMessage(role=role, content=text_prompt)] return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) @@ -267,7 +272,7 @@ async def create_chat_completion( "prompt": prompt_text, "prompt_token_ids": prompt_ids, } - if mm_data is not None: + if mm_data: inputs["multi_modal_data"] = mm_data is_tracing_enabled = await self.engine.is_tracing_enabled() diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 84e4127725bb..8d281c51f02b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -36,6 +36,7 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, super().__init__() self.engine = engine + self.model_config = model_config self.max_model_len = model_config.max_model_len # A separate tokenizer to map token IDs to strings. diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 3e28733383cb..936909eb33f6 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -140,7 +140,8 @@ def dummy_data_for_profiling(self, model_config: "ModelConfig", The model is identified by ``model_config``. - TODO: Add guide [ref: PR #5276] + See also: + :ref:`adding_a_new_multimodal_model` """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 5212e2808fb3..4533e8cbdb41 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -8,10 +8,14 @@ from transformers import CLIPVisionConfig from transformers.models.clip.modeling_clip import CLIPAttention +from vllm.config import ModelConfig +from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal.image import (cached_get_tokenizer, + repeat_and_pad_image_tokens) from vllm.sequence import SequenceData @@ -64,6 +68,39 @@ def dummy_image_for_clip( return {"image": image} +def input_processor_for_clip( + model_config: ModelConfig, + hf_config: CLIPVisionConfig, + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_feature_size = get_clip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_image_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + image_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + # Adapted from 
https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa class CLIPVisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index bbec4dbd897c..2588d8b06551 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -6,7 +6,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VisionLanguageConfig -from vllm.inputs import INPUT_REGISTRY, InputContext +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -20,8 +20,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SamplerOutput -from .clip import dummy_image_for_clip, dummy_seq_data_for_clip +from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, + input_processor_for_clip) from .interfaces import SupportsVision +from .utils import merge_vision_embeddings _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", @@ -51,28 +53,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -def merge_vision_embeddings(input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - vision_embeddings: torch.Tensor, - image_token_id: int) -> torch.Tensor: - """In place merges in vision_embeddings with inputs_embeds.""" - mask = (input_ids == image_token_id) - - image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1] - if mask.sum() != image_feature_size: - raise ValueError(f"image_feature_size should be {image_feature_size}, " - f"but found: {mask.sum()}") - - inputs_embeds[mask] = vision_embeddings.view(image_feature_size, - vision_embeddings.shape[-1]) - - return inputs_embeds - - class LlavaImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: (batch_size, num_channels, height, width)""" + """Shape: `(batch_size, num_channels, height, width)`""" LlavaImageInputs = LlavaImagePixelInputs @@ -96,8 +80,30 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int): raise NotImplementedError(msg) +def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + return input_processor_for_clip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + @MULTIMODAL_REGISTRY.register_image_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) +@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) class LlavaForConditionalGeneration(nn.Module, SupportsVision): def __init__(self, @@ -112,7 +118,6 @@ def __init__(self, # TODO: Optionally initializes this for supporting embeddings. 
self.vision_tower = CLIPVisionModel(config.vision_config) - self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index f67598c4004b..92604cdf3760 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import torch import torch.nn as nn @@ -10,7 +10,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VisionLanguageConfig -from vllm.inputs import INPUT_REGISTRY, InputContext +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -21,13 +21,14 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, - get_clip_patch_grid_length) + get_clip_patch_grid_length, input_processor_for_clip) from .interfaces import SupportsVision -from .llava import LlavaMultiModalProjector, merge_vision_embeddings +from .llava import LlavaMultiModalProjector +from .utils import merge_vision_embeddings logger = init_logger(__name__) @@ -39,16 +40,27 @@ class LlavaNextImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: torch.Tensor - """Shape: (batch_size, 1 + num_patches, num_channels, height, width)""" + data: BatchedTensors + """ + Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different for each batch. + """ image_sizes: NotRequired[torch.Tensor] - """Shape: (batch_size, 2)""" + """ + Shape: `(batch_size, 2)` + + This should be in `(height, width)` format. 
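+
+    If this is omitted, a default square size of `vision_config.image_size`
+    is assumed for each image (see `_process_image_input`).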
+ """ LlavaNextImageInputs = LlavaNextImagePixelInputs +# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91 +# NOTE: new_height and new_width are further incremented to properly invert the +# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133 def _get_llava_next_num_unpadded_features( height: int, width: int, @@ -56,7 +68,6 @@ def _get_llava_next_num_unpadded_features( num_patch_height: int, num_patch_width: int, ) -> Tuple[int, int]: - # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111 current_height = npatches * num_patch_height current_width = npatches * num_patch_width @@ -64,9 +75,13 @@ def _get_llava_next_num_unpadded_features( current_aspect_ratio: float = current_width / current_height if aspect_ratio > current_aspect_ratio: new_height = (height * current_width) // width + if new_height % 2 == 1: + new_height += 1 current_height = new_height else: new_width = (width * current_height) // height + if new_width % 2 == 1: + new_width += 1 current_width = new_width unpadded_features = current_height * current_width @@ -74,7 +89,8 @@ def _get_llava_next_num_unpadded_features( return (unpadded_features, newline_features) -def _get_llava_next_image_feature_size( +# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111 +def get_llava_next_image_feature_size( hf_config: LlavaNextConfig, *, input_height: int, @@ -89,7 +105,9 @@ def _get_llava_next_image_feature_size( ) base_feature_size = num_patches * num_patches - num_patch_height, num_patch_width = get_anyres_image_grid_shape( + # Note: We follow the "wrong" width/height order + # [ref: PR huggingface/transformers#31588] + num_patch_width, num_patch_height = get_anyres_image_grid_shape( image_size=(input_height, input_width), grid_pinpoints=hf_config.image_grid_pinpoints, patch_size=vision_config.image_size, @@ -110,14 +128,16 @@ def _get_llava_next_image_feature_size( def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): - multimodal_config = ctx.get_multimodal_config() hf_config = ctx.get_hf_config(LlavaNextConfig) vision_config = hf_config.vision_config - #TODO: change the logic for dummy data to support dynamic shape - _, _, dummy_height, dummy_width = multimodal_config.image_input_shape - image_feature_size = _get_llava_next_image_feature_size( - hf_config, input_height=dummy_height, input_width=dummy_width) + # Result in the max possible feature size (2x2 grid of 336x336px tiles) + dummy_height = dummy_width = 448 + image_feature_size = get_llava_next_image_feature_size( + hf_config, + input_height=dummy_height, + input_width=dummy_width, + ) if isinstance(vision_config, CLIPVisionConfig): seq_data = dummy_seq_data_for_clip( @@ -139,27 +159,47 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): raise NotImplementedError(msg) -def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]: +def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs - if isinstance(image, Image.Image): + model_config = ctx.model_config + hf_config = 
ctx.get_hf_config(LlavaNextConfig) + vision_config = hf_config.vision_config - # Temporary patch before dynamic number of image tokens is supported - _, _, h, w = ctx.get_multimodal_config().image_input_shape - if (w, h) != (image.width, image.height): - logger.warning( - "Dynamic image shape is currently not supported. " - "Resizing input image to (%d, %d).", w, h) + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + width, height = image_data.size + + image_feature_size = get_llava_next_image_feature_size( + hf_config, + input_height=height, + input_width=width, + ) + elif isinstance(image_data, torch.Tensor): + raise NotImplementedError("Embeddings input is not supported yet") + else: + raise TypeError(f"Invalid image type: {type(image_data)}") - image = image.resize((w, h)) + vision_config = hf_config.vision_config - return MULTIMODAL_REGISTRY._get_plugin("image") \ - ._default_input_mapper(ctx, image) + if isinstance(vision_config, CLIPVisionConfig): + return input_processor_for_clip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) - raise TypeError(f"Invalid type for 'image': {type(image)}") + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper) +@MULTIMODAL_REGISTRY.register_image_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) +@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): def __init__(self, @@ -172,8 +212,8 @@ def __init__(self, self.config = config self.vlm_config = vlm_config + # TODO: Optionally initializes this for supporting embeddings. self.vision_tower = CLIPVisionModel(config=config.vision_config) - self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, @@ -196,24 +236,6 @@ def __init__(self, self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) - def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor: - _, num_channels, _, _ = self.vlm_config.image_input_shape - - # Note that this is different from that of vLLM vision_language_config - # since the image is resized by the HuggingFace preprocessor - height = width = self.config.vision_config.image_size - - if list(data.shape[2:]) != [num_channels, height, width]: - raise ValueError( - f"The expected image tensor shape is batch dimension plus " - f"num_patches plus {[num_channels, height, width]}. " - f"You supplied {data.shape}. 
" - f"If you are using vLLM's entrypoint, make sure your " - f"supplied image input is consistent with " - f"image_input_shape in engine args.") - - return data - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: if list(data.shape[1:]) != [2]: raise ValueError( @@ -223,14 +245,14 @@ def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: return data def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[LlavaNextImageInputs]: + self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) - if pixel_values is None or image_sizes is None: + if pixel_values is None: return None - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") @@ -240,7 +262,7 @@ def _parse_and_validate_image_input( return LlavaNextImagePixelInputs( type="pixel_values", - data=self._validate_image_pixels(pixel_values), + data=pixel_values, image_sizes=self._validate_image_sizes(image_sizes), ) @@ -267,15 +289,14 @@ def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, strategy=self.config.vision_feature_select_strategy, ) + # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py def _merge_image_patch_embeddings(self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str) -> torch.Tensor: - # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py if strategy == "flat": return patch_embeddings.flatten(0, 1) if strategy.startswith("spatial"): - orig_width, orig_height = image_size height = width = self.config.vision_config.image_size \ // self.config.vision_config.patch_size @@ -289,13 +310,15 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, other_patch_embeds = patch_embeddings[1:] # image_aspect_ratio == "anyres" + # Note: We follow the "wrong" width/height order + # [ref: PR huggingface/transformers#31588] num_patch_width, num_patch_height = get_anyres_image_grid_shape( - (orig_width, orig_height), + image_size, self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) other_patch_embeds = other_patch_embeds \ - .view(num_patch_width, num_patch_height, height, width, -1) + .view(num_patch_height, num_patch_width, height, width, -1) if "unpad" in strategy: other_patch_embeds = other_patch_embeds \ @@ -333,44 +356,53 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, raise ValueError(f"Unexpected patch merge strategy: {strategy}") def _process_image_pixels( - self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor: + self, + inputs: LlavaNextImagePixelInputs, + ) -> BatchedTensors: assert self.vision_tower is not None pixel_values = inputs["data"] - b, num_patches, c, h, w = pixel_values.shape - stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + if isinstance(pixel_values, torch.Tensor): + b, num_patches, c, h, w = pixel_values.shape + stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + stacked_patch_embeddings = self.multi_modal_projector( + stacked_image_features) + return stacked_patch_embeddings.view( + b, num_patches, *stacked_patch_embeddings.shape[1:]) + + num_patches_per_batch = [v.shape[0] for v in pixel_values] + stacked_pixel_values = 
torch.cat(pixel_values) stacked_image_features = self._image_pixels_to_features( self.vision_tower, stacked_pixel_values) - return stacked_image_features.view(b, num_patches, - *stacked_image_features.shape[-2:]) + return [ + self.multi_modal_projector(image_features) for image_features in + torch.split(stacked_image_features, num_patches_per_batch) + ] def _process_image_input( - self, image_input: LlavaNextImageInputs) -> torch.Tensor: - assert self.vision_tower is not None - image_features = self._process_image_pixels(image_input) - - patch_embeddings = self.multi_modal_projector(image_features) + self, image_input: LlavaNextImageInputs) -> BatchedTensors: + patch_embeddings = self._process_image_pixels(image_input) image_sizes = image_input.get("image_sizes") if image_sizes is None: - batch_size = image_input["data"].shape[0] + batch_size = len(image_input["data"]) vision_config = self.config.vision_config - default_width = default_height = vision_config.image_size - image_sizes = torch.as_tensor([[default_width, default_height] + default_height = default_width = vision_config.image_size + image_sizes = torch.as_tensor([[default_height, default_width] for _ in range(batch_size)]) - merged_patch_embeddings = [ + return [ self._merge_image_patch_embeddings(image_sizes[i], - patch_features, + patch_features_batch, strategy="spatial_unpad") - for i, patch_features in enumerate(patch_embeddings) + for i, patch_features_batch in enumerate(patch_embeddings) ] - return torch.stack(merged_patch_embeddings, dim=0) - def forward( self, input_ids: torch.Tensor, @@ -404,8 +436,8 @@ def forward( input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values: The pixels in each grid patch for each input image. - Expects a batch with shape `[1, num_patches, 3, 336, 336]`. - image_sizes: The original `(width, height)` for each input image. + Expects a batch with shape `[1, num_patches, 3, h, w]`. + image_sizes: The original `(height, width)` for each input image. Expects a batch with shape `[1, 2]`. See also: diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index d73a42026bc3..3d247c9ed2e6 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -13,7 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict +import re +from functools import lru_cache +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import numpy as np import torch @@ -22,8 +24,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, VisionLanguageConfig -from vllm.inputs import INPUT_REGISTRY, InputContext +from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -34,10 +36,12 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors +from vllm.multimodal.image import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput -from .clip import dummy_image_for_clip, dummy_seq_data_for_clip +from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, + input_processor_for_clip) from .interfaces import SupportsVision logger = init_logger(__name__) @@ -251,50 +255,22 @@ def forward(self, input_ids: torch.LongTensor, class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: torch.Tensor - """Shape: (batch_size, 1 + num_patches, num_channels, height, width)""" - - image_sizes: torch.Tensor - """Shape: (batch_size, 2)""" - - -def _get_phi3v_image_feature_size( - *, - input_height: int, - input_width: int, -) -> int: - h, w = input_height, input_width - - # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178 - return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12 + data: BatchedTensors + """ + Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` + Note that `num_patches` may be different for each batch. + """ -def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): - multimodal_config = ctx.get_multimodal_config() - - #TODO: change the logic for dummy data to support dynamic shape - _, _, dummy_height, dummy_width = multimodal_config.image_input_shape - image_feature_size = _get_phi3v_image_feature_size( - input_height=dummy_height, - input_width=dummy_width, - ) - - seq_data = dummy_seq_data_for_clip( - CLIP_VIT_LARGE_PATCH14_336_CONFIG, - seq_len, - image_token_id=32044, - image_feature_size_override=image_feature_size, - ) - mm_data = dummy_image_for_clip( - CLIP_VIT_LARGE_PATCH14_336_CONFIG, - image_width_override=dummy_width, - image_height_override=dummy_height, - ) + image_sizes: torch.Tensor + """ + Shape: `(batch_size, 2)` - return seq_data, mm_data + This should be in `(height, width)` format. 
+ """ -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py +# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): target_height = int(np.ceil(height / padding_unit) * padding_unit) top_padding = int((target_height - height) / 2) @@ -304,7 +280,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): return padded_width, padded_height -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py +# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): transposed = False if width < height: @@ -329,27 +305,133 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): return padded_width, padded_height -def _image_processor(ctx: InputContext, - image: object) -> Dict[str, torch.Tensor]: +# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181 +def get_phi3v_image_feature_size( + hf_config: PretrainedConfig, + *, + input_height: int, + input_width: int, +) -> int: + num_crops = getattr(hf_config, "num_crops", 16) + new_width, new_height = _calc_hd_transform_size(width=input_width, + height=input_height, + hd_num=num_crops) - if isinstance(image, Image.Image): - # Temporary patch before dynamic number of image tokens is supported - _, _, h, w = ctx.get_multimodal_config().image_input_shape - if (w, h) != _calc_hd_transform_size(width=image.width, - height=image.height): - logger.warning( - "Dynamic image shape is currently not supported. 
" - "Resizing input image to (%d, %d).", w, h) + return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \ + + (new_height // 336 + 1) * 12 - image = image.resize((w, h)) - return MULTIMODAL_REGISTRY._get_plugin("image") \ - ._default_input_mapper(ctx, image) - raise TypeError(f"Invalid type for 'image': {type(image)}") +def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): + # Result in the max possible feature size (h:w = 16:1) + dummy_height, dummy_width = 8000, 50 + image_feature_size = get_phi3v_image_feature_size( + ctx.get_hf_config(PretrainedConfig), + input_height=dummy_height, + input_width=dummy_width, + ) + + seq_data = dummy_seq_data_for_clip( + CLIP_VIT_LARGE_PATCH14_336_CONFIG, + seq_len, + image_token_id=32044, + image_feature_size_override=image_feature_size, + ) + mm_data = dummy_image_for_clip( + CLIP_VIT_LARGE_PATCH14_336_CONFIG, + image_width_override=dummy_width, + image_height_override=dummy_height, + ) + + return seq_data, mm_data + +# Reserve this function to also handle placeholders for additional images +# [ref: PR #5820] +@lru_cache +def _get_image_placeholder_token_ids(model_config: ModelConfig, + idx: int) -> List[int]: + assert idx > 0 -@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor) + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + # We need to get the token for "<", not "▁<" + # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json + a_token_id, = tokenizer.encode("a", add_special_tokens=False) + a_token_id_, *image_placeholder_token_ids = tokenizer.encode( + f"a<|image_{idx}|>", add_special_tokens=False) + assert a_token_id == a_token_id_ + + return image_placeholder_token_ids + + +def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + multimodal_config = ctx.get_multimodal_config() + hf_config = ctx.get_hf_config(PretrainedConfig) + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + w, h = image_data.size + w, h = _calc_hd_transform_size(width=w, height=h) + + image_feature_size = get_phi3v_image_feature_size(hf_config, + input_width=w, + input_height=h) + elif isinstance(image_data, torch.Tensor): + raise NotImplementedError("Embeddings input is not supported yet") + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + prompt = llm_inputs.get("prompt") + if prompt is None: + new_prompt = None + else: + if prompt.count("<|image|>") > 0: + logger.warning("Please follow the prompt format that is " + "documented on HuggingFace which does not involve " + "repeating <|image|> tokens.") + elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1: + logger.warning("Multiple image input is not supported yet, " + "so any extra image tokens will be treated " + "as plain text.") + + new_prompt = prompt + + prompt_token_ids = llm_inputs["prompt_token_ids"] + image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1) + + new_token_ids: List[int] = [] + for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1): + if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids: + new_token_ids.append(multimodal_config.image_token_id) + + # No need to further scan the list since we only replace once + new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):]) + break + else: + new_token_ids.append(prompt_token_ids[i]) 
+ + # NOTE: Create a defensive copy of the original inputs + llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + return input_processor_for_clip( + model_config, + CLIP_VIT_LARGE_PATCH14_336_CONFIG, + llm_inputs, + image_token_id=multimodal_config.image_token_id, + image_feature_size_override=image_feature_size, + ) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) +@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) class Phi3VForCausalLM(nn.Module, SupportsVision): def __init__(self, @@ -363,6 +445,8 @@ def __init__(self, self.vlm_config = vlm_config self.model = LlamaModel(config, cache_config, quant_config) + + # TODO: Optionally initializes this for supporting embeddings. self.vision_embed_tokens = Phi3HDImageEmbedding( vlm_config, config, self.model.embed_tokens) self.lm_head = ParallelLMHead(config.vocab_size, @@ -376,12 +460,20 @@ def _parse_and_validate_image_input( pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) - if pixel_values is not None and image_sizes is not None: - return Phi3VImagePixelInputs(type="pixel_values", - data=pixel_values, - image_sizes=image_sizes) + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, torch.Tensor): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") - return None + return Phi3VImagePixelInputs(type="pixel_values", + data=pixel_values, + image_sizes=image_sizes) def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py new file mode 100644 index 000000000000..ef2562b073e6 --- /dev/null +++ b/vllm/model_executor/models/utils.py @@ -0,0 +1,41 @@ +import torch + +from vllm.multimodal import BatchedTensors + + +def merge_vision_embeddings(input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + vision_embeddings: BatchedTensors, + image_token_id: int) -> torch.Tensor: + """ + Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions + in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`. + + Note: + This updates `inputs_embeds` in place. 
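+
+    For example (illustrative shapes): three placeholder tokens in `input_ids`
+    require exactly three embedding rows, e.g. a tensor of shape
+    `(1, 3, hidden_size)` or a list of per-image tensors whose first
+    dimensions sum to 3; otherwise a `ValueError` is raised.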
+ """ + mask = (input_ids == image_token_id) + num_expected_tokens = mask.sum() + + if isinstance(vision_embeddings, torch.Tensor): + batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape + total_tokens = batch_size * batch_tokens + if num_expected_tokens != total_tokens: + expr = f"{batch_size} x {batch_tokens}" + raise ValueError( + f"Attempted to assign {expr} = {total_tokens} " + f"image tokens to {num_expected_tokens} placeholders") + + inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim) + else: + size_per_batch = [t.shape[0] for t in vision_embeddings] + total_tokens = sum(size_per_batch) + if num_expected_tokens != total_tokens: + expr = ' + '.join(map(str, size_per_batch)) + raise ValueError( + f"Attempted to assign {expr} = {total_tokens} " + f"image tokens to {num_expected_tokens} placeholders") + + inputs_embeds[mask] = torch.cat(vision_embeddings) + + return inputs_embeds diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 256eadd2d7df..b6d930659a8c 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,4 +1,5 @@ -from .base import MultiModalDataDict, MultiModalPlugin +from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs, + MultiModalPlugin) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -11,8 +12,10 @@ """ __all__ = [ + "BatchedTensors", + "MultiModalDataDict", + "MultiModalInputs", "MultiModalPlugin", "MULTIMODAL_REGISTRY", "MultiModalRegistry", - "MultiModalDataDict", ] diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 558cd1175298..e7b45649d728 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,23 +1,90 @@ +import sys from abc import ABC, abstractmethod -from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type, - TypedDict, TypeVar, Union) +from collections import UserDict, defaultdict +from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict, + TypeVar, Union) + +import torch +import torch.types +from PIL import Image +from torch import nn from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger -if TYPE_CHECKING: - import torch - from PIL import Image - from torch import nn - logger = init_logger(__name__) -N = TypeVar("N", bound=Type["nn.Module"]) +BatchedTensors = Union[torch.Tensor, List[torch.Tensor]] +""" +If each input tensor in the batch has the same size, this is a single batched +tensor; otherwise, this is a list of tensors with one element per batch. +""" + +if sys.version_info < (3, 9): + # UserDict cannot be subscripted + class _MultiModalInputsBase(UserDict): + pass +else: + + class _MultiModalInputsBase(UserDict[str, torch.Tensor]): + pass + + +class MultiModalInputs(_MultiModalInputsBase): + """ + A dictionary that represents the keyword arguments to + :meth:`~torch.nn.Module.forward`. 
+ """ + + @staticmethod + def try_concat( + tensors: List[torch.Tensor], + *, + device: torch.types.Device, + ) -> BatchedTensors: + # Avoid initializing CUDA too early + import torch + + unbatched_shape = tensors[0].shape[1:] + + for tensor in tensors: + if tensor.shape[1:] != unbatched_shape: + return [ + tensor.squeeze(0).to(device=device) for tensor in tensors + ] + + return torch.cat(tensors, dim=0).to(device=device) + + @staticmethod + def batch( + inputs_list: List["MultiModalInputs"], + device: torch.types.Device, + ) -> Dict[str, BatchedTensors]: + """Batch multiple inputs together into a dictionary.""" + if len(inputs_list) == 0: + return {} + + keys = inputs_list[0].keys() + + item_lists: Dict[str, List[torch.Tensor]] = defaultdict(list) + + for inputs in inputs_list: + if inputs.keys() != keys: + msg = f"Inputs do not share the same keys ({keys})" + raise ValueError(msg) + + for k, v in inputs.items(): + item_lists[k].append(v) + + return { + k: MultiModalInputs.try_concat(item_list, device=device) + for k, item_list in item_lists.items() + } class MultiModalDataBuiltins(TypedDict, total=False): - image: "Image.Image" + image: Image.Image MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]] @@ -29,12 +96,13 @@ class MultiModalDataBuiltins(TypedDict, total=False): the corresponding plugin with the same modality key is applied. """ -MultiModalInputMapper = Callable[[InputContext, object], Dict[str, - "torch.Tensor"]] +MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs] """Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers and processors in HuggingFace Transformers.""" +N = TypeVar("N", bound=Type[nn.Module]) + class MultiModalPlugin(ABC): """ @@ -48,8 +116,7 @@ class MultiModalPlugin(ABC): """ def __init__(self) -> None: - self._input_mappers: Dict[Type["nn.Module"], - MultiModalInputMapper] = {} + self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {} @abstractmethod def get_data_key(self) -> str: @@ -60,7 +127,7 @@ def get_data_key(self) -> str: @abstractmethod def _default_input_mapper(self, ctx: InputContext, - data: object) -> Dict[str, "torch.Tensor"]: + data: object) -> MultiModalInputs: """Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers and processors in HuggingFace Transformers. @@ -80,6 +147,7 @@ def register_input_mapper( See also: :ref:`input_processing_pipeline` + :ref:`adding_a_new_multimodal_model` """ def wrapper(model_cls: N) -> N: @@ -97,7 +165,7 @@ def wrapper(model_cls: N) -> N: return wrapper def map_input(self, model_config: ModelConfig, - data: object) -> Dict[str, "torch.Tensor"]: + data: object) -> MultiModalInputs: """ Apply an input mapper to a data passed to the model, transforming the data into a dictionary of model inputs. @@ -106,7 +174,8 @@ def map_input(self, model_config: ModelConfig, The model is identified by ``model_config``. 
- TODO: Add guide [ref: PR #5276] + See also: + :ref:`adding_a_new_multimodal_model` """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index a0b4206bf2ee..dfef33121cbf 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,19 +1,102 @@ from functools import lru_cache -from typing import Dict +from typing import List, Optional, Tuple, TypeVar import torch from PIL import Image +from transformers import PreTrainedTokenizerBase from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor +from vllm.transformers_utils.tokenizer import get_tokenizer -from .base import MultiModalPlugin +from .base import MultiModalInputs, MultiModalPlugin logger = init_logger(__name__) cached_get_image_processor = lru_cache(get_image_processor) +cached_get_tokenizer = lru_cache(get_tokenizer) + +# Utilities for image input processors +_T = TypeVar("_T", str, int) + + +def repeat_and_pad_token( + token: _T, + *, + repeat_count: int = 1, + pad_token_left: Optional[_T] = None, + pad_token_right: Optional[_T] = None, +) -> List[_T]: + replacement = [token] * repeat_count + if pad_token_left is not None: + replacement = [pad_token_left] + replacement + if pad_token_right is not None: + replacement = replacement + [pad_token_right] + + return replacement + + +def repeat_and_pad_image_tokens( + tokenizer: PreTrainedTokenizerBase, + prompt: Optional[str], + prompt_token_ids: List[int], + *, + image_token_id: int, + repeat_count: int = 1, + pad_token_left: Optional[int] = None, + pad_token_right: Optional[int] = None, +) -> Tuple[Optional[str], List[int]]: + if prompt is None: + new_prompt = None + else: + image_token_str = tokenizer.decode(image_token_id) + pad_token_str_left = (None if pad_token_left is None else + tokenizer.decode(pad_token_left)) + pad_token_str_right = (None if pad_token_right is None else + tokenizer.decode(pad_token_right)) + replacement_str = "".join( + repeat_and_pad_token( + image_token_str, + repeat_count=repeat_count, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + + image_token_count = prompt.count(image_token_str) + # This is an arbitrary number to distinguish between the two cases + if image_token_count > 16: + logger.warning( + "Please follow the prompt format that is " + "documented on HuggingFace which does not involve " + "repeating %s tokens.", image_token_str) + elif image_token_count > 1: + logger.warning("Multiple image input is not supported yet, " + "so any extra image tokens will be treated " + "as plain text.") + + # The image tokens are removed to be consistent with HuggingFace + new_prompt = prompt.replace(image_token_str, replacement_str, 1) + + new_token_ids: List[int] = [] + for i, token in enumerate(prompt_token_ids): + if token == image_token_id: + replacement_ids = repeat_and_pad_token( + image_token_id, + repeat_count=repeat_count, + pad_token_left=pad_token_left, + pad_token_right=pad_token_right, + ) + new_token_ids.extend(replacement_ids) + + # No need to further scan the list since we only replace once + new_token_ids.extend(prompt_token_ids[i + 1:]) + break + else: + new_token_ids.append(token) + + return new_prompt, new_token_ids class ImagePlugin(MultiModalPlugin): @@ -27,7 +110,7 @@ def _get_hf_image_processor(self, model_config: ModelConfig): 
trust_remote_code=model_config.trust_remote_code) def _default_input_mapper(self, ctx: InputContext, - data: object) -> Dict[str, torch.Tensor]: + data: object) -> MultiModalInputs: model_config = ctx.model_config if isinstance(data, Image.Image): image_processor = self._get_hf_image_processor(model_config) @@ -35,10 +118,15 @@ def _default_input_mapper(self, ctx: InputContext, raise RuntimeError("No HuggingFace processor is available" "to process the image object") try: - return image_processor.preprocess(data, return_tensors="pt") \ - .to(model_config.dtype).data + batch_data = image_processor \ + .preprocess(data, return_tensors="pt") \ + .data except Exception: logger.error("Failed to process image (%s)", data) raise - raise TypeError(f"Invalid type for 'image': {type(data)}") + return MultiModalInputs(batch_data) + elif isinstance(data, torch.Tensor): + raise NotImplementedError("Embeddings input is not supported yet") + + raise TypeError(f"Invalid image type: {type(data)}") diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index a09a80f89f4b..f17b04149ede 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,18 +1,17 @@ import functools -from typing import Optional, Sequence, Type, TypeVar +from typing import Dict, Optional, Sequence -from torch import nn +import torch from vllm.config import ModelConfig from vllm.logger import init_logger -from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin +from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, + MultiModalPlugin) from .image import ImagePlugin logger = init_logger(__name__) -N = TypeVar("N", bound=Type[nn.Module]) - class MultiModalRegistry: """ @@ -61,7 +60,7 @@ def register_image_input_mapper( return self.register_input_mapper("image", mapper) def _process_input(self, key: str, value: object, - model_config: ModelConfig): + model_config: ModelConfig) -> MultiModalInputs: plugin = self._plugins.get(key) if plugin: return plugin.map_input(model_config, value) @@ -93,16 +92,28 @@ def register_image_input(self, """ return self.register_input_mapper("image", mapper) - def map_input(self, model_config: ModelConfig, data: MultiModalDataDict): + def map_input(self, model_config: ModelConfig, + data: MultiModalDataDict) -> MultiModalInputs: """ Apply an input mapper to the data passed to the model. See :meth:`MultiModalPlugin.map_input` for more details. 
""" - result_list = [ - self._process_input(k, v, model_config) for k, v in data.items() - ] - return {k: v for d in result_list for k, v in d.items()} + merged_dict: Dict[str, torch.Tensor] = {} + + for data_key, data_value in data.items(): + input_dict = self._process_input(data_key, data_value, + model_config) + + for input_key, input_tensor in input_dict.items(): + if input_key in merged_dict: + raise ValueError(f"The input mappers (keys={set(data)}) " + f"resulted in a conflicting keyword " + f"argument to `forward()`: {input_key}") + + merged_dict[input_key] = input_tensor + + return MultiModalInputs(merged_dict) def create_input_mapper(self, model_config: ModelConfig): """ diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 321b51e5a883..e55b8bbfdeaa 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -4,11 +4,56 @@ from urllib.parse import urlparse import aiohttp +import requests from PIL import Image -from vllm.config import ModelConfig from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT from vllm.multimodal.base import MultiModalDataDict +from vllm.version import __version__ as VLLM_VERSION + + +def _validate_remote_url(url: str, *, name: str): + parsed_url = urlparse(url) + if parsed_url.scheme not in ["http", "https"]: + raise ValueError(f"Invalid '{name}': A valid '{name}' " + "must have scheme 'http' or 'https'.") + + +def _get_request_headers(): + return {"User-Agent": f"vLLM/{VLLM_VERSION}"} + + +def _load_image_from_bytes(b: bytes): + image = Image.open(BytesIO(b)) + image.load() + return image + + +def _load_image_from_data_url(image_url: str): + # Only split once and assume the second part is the base64 encoded image + _, image_base64 = image_url.split(",", 1) + return load_image_from_base64(image_base64) + + +def fetch_image(image_url: str) -> Image.Image: + """Load PIL image from a url or base64 encoded openai GPT4V format""" + if image_url.startswith('http'): + _validate_remote_url(image_url, name="image_url") + + headers = _get_request_headers() + + with requests.get(url=image_url, headers=headers) as response: + response.raise_for_status() + image_raw = response.content + image = _load_image_from_bytes(image_raw) + + elif image_url.startswith('data:image'): + image = _load_image_from_data_url(image_url) + else: + raise ValueError("Invalid 'image_url': A valid 'image_url' must start " + "with either 'data:image' or 'http'.") + + return image class ImageFetchAiohttp: @@ -29,34 +74,31 @@ async def fetch_image(cls, image_url: str) -> Image.Image: """Load PIL image from a url or base64 encoded openai GPT4V format""" if image_url.startswith('http'): - parsed_url = urlparse(image_url) - if parsed_url.scheme not in ["http", "https"]: - raise ValueError("Invalid 'image_url': A valid 'image_url' " - "must have scheme 'http' or 'https'.") - # Avoid circular import - from vllm import __version__ as VLLM_VERSION + _validate_remote_url(image_url, name="image_url") client = cls.get_aiohttp_client() - headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"} + headers = _get_request_headers() async with client.get(url=image_url, headers=headers) as response: response.raise_for_status() image_raw = await response.read() - image = Image.open(BytesIO(image_raw)) + image = _load_image_from_bytes(image_raw) - # Only split once and assume the second part is the base64 encoded image elif image_url.startswith('data:image'): - image = load_image_from_base64(image_url.split(',', 1)[1]) - + image = _load_image_from_data_url(image_url) else: raise ValueError( "Invalid 
'image_url': A valid 'image_url' must start " "with either 'data:image' or 'http'.") - image.load() return image +async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: + image = await ImageFetchAiohttp.fetch_image(image_url) + return {"image": image} + + def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str: """Encode a pillow image to base64 format.""" @@ -69,26 +111,11 @@ def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str: def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: """Load image from base64 format.""" - return Image.open(BytesIO(base64.b64decode(image))) + return _load_image_from_bytes(base64.b64decode(image)) -# TODO(ywang96): move this to a model registry for preprocessing vision -# language prompts based on the model type. -def get_full_image_text_prompt(image_prompt: str, text_prompt: str, - config: ModelConfig) -> str: - """Combine image and text prompts for vision language model depending on - the model architecture.""" - - if config.hf_config.model_type in ("llava", "llava_next"): - full_prompt = f"{image_prompt}\n{text_prompt}" - elif config.hf_config.model_type == 'phi3_v': - full_prompt = f"{image_prompt}\n{text_prompt}" - else: - raise ValueError( - f"Unsupported model type: {config.hf_config.model_type}") - return full_prompt - - -async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: - image = await ImageFetchAiohttp.fetch_image(image_url) - return {"image": image} +def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image: + """Rescale the dimensions of an image by a constant factor.""" + new_width = int(image.width * size_factor) + new_height = int(image.height * size_factor) + return image.resize((new_width, new_height)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 7e08586cdfd9..d200115aa092 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -457,7 +457,7 @@ def prompt_token_ids(self) -> List[int]: return next(iter(self.seqs_dict.values())).prompt_token_ids @property - def multi_modal_data(self) -> Optional["MultiModalDataDict"]: + def multi_modal_data(self) -> "MultiModalDataDict": # All sequences in the group should have the same multi-modal data. # We use the multi-modal data of an arbitrary sequence. 
return next(iter(self.seqs_dict.values())).multi_modal_data diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index 354dcb526395..c7d9eabd06f0 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -1,9 +1,4 @@ -from transformers import AutoImageProcessor -from transformers.image_processing_utils import BaseImageProcessor - -from vllm.logger import init_logger - -logger = init_logger(__name__) +from typing import cast def get_image_processor( @@ -11,10 +6,15 @@ def get_image_processor( *args, trust_remote_code: bool = False, **kwargs, -) -> BaseImageProcessor: +): """Gets an image processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoImageProcessor + from transformers.image_processing_utils import BaseImageProcessor + try: - processor: BaseImageProcessor = AutoImageProcessor.from_pretrained( + processor = AutoImageProcessor.from_pretrained( processor_name, *args, trust_remote_code=trust_remote_code, @@ -34,4 +34,4 @@ def get_image_processor( else: raise e - return processor + return cast(BaseImageProcessor, processor) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index fd6c2b8546df..d8397ac22a58 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,6 +1,6 @@ -from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, + Type, Union) import torch from torch import nn @@ -12,7 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, + MultiModalInputs) from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) from vllm.utils import make_tensor_with_pad @@ -40,7 +41,7 @@ class CPUModelInput(ModelRunnerInputBase): input_positions: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -132,15 +133,14 @@ def load_model(self) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[ - str, torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + Mapping[str, BatchedTensors]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_modal_kwargs_list: Dict[str, - List[torch.Tensor]] = defaultdict(list) + multi_modal_inputs_list: List[MultiModalInputs] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -162,10 +162,9 @@ def _prepare_prompt( input_positions.extend(list(range(computed_len, seq_len))) mm_data = seq_group_metadata.multi_modal_data - if mm_data is not None: + if mm_data: mm_kwargs = self.multi_modal_input_mapper(mm_data) - for 
k, v in mm_kwargs.items(): - multi_modal_kwargs_list[k].append(v) + multi_modal_inputs_list.append(mm_kwargs) # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] @@ -189,11 +188,6 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - multi_modal_kwargs = { - k: torch.cat(v, dim=0).to(self.device) - for k, v in multi_modal_kwargs_list.items() - } - num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, @@ -217,6 +211,10 @@ def _prepare_prompt( block_tables=torch.tensor([]), slot_mapping=slot_mapping, ) + + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, + device=self.device) + return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) @@ -367,10 +365,8 @@ def execute_model( "positions": model_input.input_positions, "kv_caches": kv_caches, "attn_metadata": model_input.attn_metadata, + **(model_input.multi_modal_kwargs or {}), } - if (self.vision_language_config - and model_input.multi_modal_kwargs is not None): - execute_model_kwargs.update(model_input.multi_modal_kwargs) hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 0e1bb1bfe273..d3a2643cb62f 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -92,10 +92,9 @@ def execute_model( "positions": model_input.input_positions, "kv_caches": kv_caches, "attn_metadata": model_input.attn_metadata, + **(model_input.multi_modal_kwargs or {}), } - if self.vision_language_config: - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - execute_model_kwargs.update({"image_input": multi_modal_kwargs}) + hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. 
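The same batching pattern recurs in each of the worker runners in this change: every sequence's ``multi_modal_data`` is mapped to a ``MultiModalInputs`` dict, the per-sequence dicts are combined with ``MultiModalInputs.batch``, and the result is splatted into the model's ``forward`` as keyword arguments. A minimal sketch of that flow (the ``pixel_values`` key and tensor shape are illustrative only):

.. code-block:: python

    import torch

    from vllm.multimodal import MultiModalInputs

    # One MultiModalInputs per sequence, as produced by the registered
    # input mapper (e.g. a HuggingFace image processor's output).
    per_seq = [
        MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)}),
        MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)}),
    ]

    # Same-shaped tensors are concatenated into one batched tensor; tensors
    # with mismatched shapes would instead be returned as a list.
    batched = MultiModalInputs.batch(per_seq, device="cpu")
    assert batched["pixel_values"].shape == (2, 3, 336, 336)

    # The runners then pass these as keyword arguments, mirroring
    # `model_executable(**execute_model_kwargs)` above.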
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index bd30281471d1..530c631d5767 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -3,8 +3,8 @@
 import time
 import warnings
 from collections import defaultdict
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
-                    TypeVar, Union)
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
+                    Tuple, Type, TypeVar, Union)
 
 import numpy as np
 import torch
@@ -37,7 +37,8 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models.interfaces import supports_lora
-from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
@@ -83,7 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     lora_mapping: Optional["LoRAMapping"] = None
     lora_requests: Optional[Set[LoRARequest]] = None
     attn_metadata: Optional["AttentionMetadata"] = None
-    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
     request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
     finished_requests_ids: Optional[List[str]] = None
     virtual_engine: int = 0
@@ -356,8 +357,7 @@ def _prepare_model_input_tensors(
         context_lens: List[int] = []
         query_lens: List[int] = []
         block_tables: List[List[int]] = []
-        multi_modal_kwargs_list: Dict[str,
-                                      List[torch.Tensor]] = defaultdict(list)
+        multi_modal_inputs_list: List[MultiModalInputs] = []
         request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
         decode_only = True
         num_prefills = 0
@@ -528,8 +528,7 @@ def _prepare_model_input_tensors(
             if mm_data:
                 # Process multi-modal data
                 mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                for k, v in mm_kwargs.items():
-                    multi_modal_kwargs_list[k].append(v)
+                multi_modal_inputs_list.append(mm_kwargs)
 
             is_profile_run = _is_block_tables_empty(
                 seq_group_metadata.block_tables)
@@ -746,10 +745,8 @@ def _prepare_model_input_tensors(
         else:
             lora_mapping = None
 
-        multi_modal_kwargs = {
-            k: torch.cat(v, dim=0).to(self.device)
-            for k, v in multi_modal_kwargs_list.items()
-        }
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
         request_ids_to_seq_ids = {
             seq_group_metadata.request_id:
             list(seq_group_metadata.seq_data.keys())
@@ -821,7 +818,11 @@ def profile_run(self) -> None:
 
             seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                 .dummy_data_for_profiling(model_config, seq_len)
-            assert len(seq_data.prompt_token_ids) == seq_len
+
+            # Having more tokens is over-conservative but otherwise fine
+            assert len(seq_data.prompt_token_ids) >= seq_len, (
+                f"Expected at least {seq_len} dummy tokens for profiling, "
+                f"but got: {len(seq_data.prompt_token_ids)}")
 
             seq = SequenceGroupMetadata(
                 request_id=str(group_id),
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 8b96966be470..423f44085e31 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Union)
 
 import torch
 from torch import nn
@@ -9,6 +10,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader.neuron import get_neuron_model
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -29,6 +32,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
     input_positions: Optional[torch.Tensor] = None
     input_block_ids: Optional[torch.Tensor] = None
     sampling_metadata: Optional["SamplingMetadata"] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -65,6 +69,10 @@ def __init__(
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
 
@@ -76,13 +84,15 @@ def load_model(self) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
+            str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         input_block_ids: List[int] = []
 
         seq_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -102,6 +112,12 @@ def _prepare_prompt(
             assert len(block_table) == 1
             input_block_ids.append(block_table[0])
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                # Process multi-modal data
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         max_seq_len = max(seq_lens)
         assert max_seq_len > 0
         input_tokens = make_tensor_with_pad(input_tokens,
@@ -118,7 +134,11 @@ def _prepare_prompt(
                                            dtype=torch.long,
                                            device=self.device)
 
-        return input_tokens, input_positions, input_block_ids, seq_lens
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, input_block_ids, seq_lens,
+                multi_modal_kwargs)
 
     def _prepare_decode(
         self,
@@ -184,8 +204,9 @@ def prepare_model_input(
         is_prompt = seq_group_metadata_list[0].is_prompt
         # Prepare input tensors.
         if is_prompt:
-            (input_tokens, input_positions, input_block_ids,
-             seq_lens) = self._prepare_prompt(seq_group_metadata_list)
+            (input_tokens, input_positions, input_block_ids, seq_lens,
+             multi_modal_kwargs
+             ) = self._prepare_prompt(seq_group_metadata_list)
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
@@ -203,7 +224,8 @@ def prepare_model_input(
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
-                                   sampling_metadata=sampling_metadata)
+                                   sampling_metadata=sampling_metadata,
+                                   multi_modal_kwargs=multi_modal_kwargs)
 
     @torch.inference_mode()
     def execute_model(
@@ -221,6 +243,7 @@ def execute_model(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             input_block_ids=model_input.input_block_ids,
+            **(model_input.multi_modal_kwargs or {}),
         )
 
         # Compute the logits.
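Each runner now collects one `MultiModalInputs` object per sequence and defers collation to `MultiModalInputs.batch(..., device=...)` instead of concatenating tensors inline. The exact semantics live in `vllm.multimodal`; the helper below is only a hedged sketch of the behaviour assumed here (group per-sequence kwargs by key, stack along a new batch dimension, move to the target device), not the library implementation:

# Hedged sketch of a batching helper along the lines of
# MultiModalInputs.batch(inputs_list, device=...). Illustrative only.
from collections import defaultdict
from typing import Dict, List, Mapping

import torch


def batch_multi_modal_inputs(
    inputs_list: List[Mapping[str, torch.Tensor]],
    *,
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    if not inputs_list:
        # Text-only batch: nothing extra to pass to the model.
        return {}

    grouped: Dict[str, List[torch.Tensor]] = defaultdict(list)
    for per_seq_kwargs in inputs_list:
        for key, tensor in per_seq_kwargs.items():
            grouped[key].append(tensor)

    # Assumes every sequence produced the same keys with compatible shapes.
    return {
        key: torch.stack(tensors).to(device)
        for key, tensors in grouped.items()
    }


# Example: two sequences, each mapped to a 3x336x336 pixel tensor.
batched = batch_multi_modal_inputs(
    [{"pixel_values": torch.rand(3, 336, 336)} for _ in range(2)],
    device=torch.device("cpu"),
)
print(batched["pixel_values"].shape)  # torch.Size([2, 3, 336, 336])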
diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py
index 336eaf814fb3..f064048888a7 100644
--- a/vllm/worker/openvino_model_runner.py
+++ b/vllm/worker/openvino_model_runner.py
@@ -1,4 +1,4 @@
-from typing import List, NamedTuple, Optional, Tuple
+from typing import List, Mapping, NamedTuple, Optional, Tuple
 
 import openvino as ov
 import torch
@@ -12,6 +12,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader.openvino import get_model
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 
 logger = init_logger(__name__)
@@ -23,7 +25,7 @@ class ModelInput(NamedTuple):
     attn_metadata: Optional[OpenVINOAttentionMetadata]
     seq_lens: List[int]
     query_lens: List[int]
-    multi_modal_input: Optional[torch.Tensor]
+    multi_modal_kwargs: Mapping[str, BatchedTensors]
 
     @classmethod
     def empty(cls, device):
@@ -32,7 +34,7 @@ def empty(cls, device):
                          attn_metadata=None,
                          seq_lens=[],
                          query_lens=[],
-                         multi_modal_input=None)
+                         multi_modal_kwargs={})
 
 
 class OpenVINOModelRunner:
@@ -78,6 +80,10 @@ def __init__(
             self.block_size,
         )
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
 
@@ -108,6 +114,8 @@ def _prepare_model_input(
         seq_lens: List[int] = []
         past_lens: List[int] = []
         query_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
+
         subsequence_begins: List[int] = []
         block_indices: List[int] = []
         block_indices_begins: List[int] = []
@@ -160,6 +168,11 @@ def _prepare_model_input(
                                            and self.sliding_window is None
                                            and is_prompt)
 
+                mm_data = seq_group_metadata.multi_modal_data
+                if mm_data:
+                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                    multi_modal_inputs_list.append(mm_kwargs)
+
                 block_table = seq_group_metadata.block_tables[seq_id]
                 # TODO(sang): Combine chunked prefill and prefix caching by
                 # only allowing multiple of block_size chunk size.
@@ -251,22 +264,24 @@ def _prepare_model_input(
             block_indices_begins=block_indices_begins_tensor,
             max_context_len=max_context_len_tensor,
         )
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
         return ModelInput(
             input_tokens,
             input_positions,
             attn_metadata,
             seq_lens,
             query_lens,
-            None,
+            multi_modal_kwargs=multi_modal_kwargs,
         )
 
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
-               SamplingMetadata, Optional[torch.Tensor], ]:
-        multi_modal_input = None
-
+               SamplingMetadata, Mapping[str, BatchedTensors]]:
         # Prepare input tensors.
         (
             input_tokens,
@@ -274,7 +289,7 @@
             attn_metadata,
             seq_lens,
             query_lens,
-            multi_modal_input,
+            multi_modal_kwargs,
         ) = self._prepare_model_input(seq_group_metadata_list)
 
         sampling_metadata = SamplingMetadata.prepare(
@@ -290,7 +305,7 @@
             input_positions,
             attn_metadata,
             sampling_metadata,
-            multi_modal_input,
+            multi_modal_kwargs,
         )
 
     @torch.inference_mode()
@@ -304,7 +319,7 @@
             input_positions,
             attn_metadata,
             sampling_metadata,
-            multi_modal_input,
+            multi_modal_kwargs,
         ) = self.prepare_input_tensors(seq_group_metadata_list)
 
         model_executable = self.model
@@ -313,9 +328,8 @@
             "positions": input_positions,
             "kv_caches": kv_caches,
             "attn_metadata": attn_metadata,
+            **(multi_modal_kwargs or {}),
         }
-        if self.vision_language_config:
-            execute_model_kwargs.update({"image_input": multi_modal_input})
 
         hidden_states = model_executable(**execute_model_kwargs)
 
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index dd08536efc5f..4ea8e62cc1fd 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,6 +12,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -66,6 +68,10 @@ def __init__(
             False,
         )
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
 
@@ -193,12 +199,14 @@ def warmup_model(
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -224,6 +232,11 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -261,17 +274,24 @@ def _prepare_prompt(
             block_tables=None,
             context_lens=None,
         )
-        return input_tokens, input_positions, attn_metadata, prompt_lens
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_kwargs)
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
@@ -297,6 +317,11 @@ def _prepare_decode(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -330,7 +355,12 @@ def _prepare_decode(
             block_tables=block_tables,
             context_lens=context_lens,
         )
-        return input_tokens, input_positions, attn_metadata, input_lens
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, attn_metadata, input_lens,
+                multi_modal_kwargs)
 
     def _prepare_sample(
         self,
@@ -483,6 +513,7 @@ def forward(
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
+        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -496,6 +527,8 @@ def forward(
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
+            multi_modal_kwargs: Keyword arguments from multi-modal data to
+                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -540,6 +573,7 @@ def forward(
             position_ids,
             kv_caches,
             attn_metadata,
+            **(multi_modal_kwargs or {}),
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index e652f1b1042e..f4fc42328027 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Type, Union)
 
 import torch
 import torch.nn as nn
@@ -9,10 +10,13 @@
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
+from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
 from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
@@ -44,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
     input_positions: Optional[torch.Tensor] = None
     attn_metadata: Optional["AttentionMetadata"] = None
     sampling_metadata: Optional["SamplingMetadata"] = None
-    multi_modal_input: Optional[Dict[str, torch.Tensor]] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -116,6 +120,10 @@ def __init__(
             self.block_size,
         )
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
 
@@ -156,12 +164,26 @@ def profile_run(self) -> None:
         # To exercise the worst scenario for GPU memory consumption,
         # the number of seqs (batch_size) is chosen to maximize the number
         # of images processed.
+        model_config = self.model_config
+        vlm_config = self.vision_language_config
+
+        if vlm_config:
+            max_num_seqs = min(
+                max_num_seqs,
+                int(max_num_batched_tokens / vlm_config.image_feature_size))
+
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +
                        (group_id < max_num_batched_tokens % max_num_seqs))
-            seq_data = SequenceData([0] * seq_len)
-            dummy_multi_modal_data = None
+
+            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
+                .dummy_data_for_profiling(model_config, seq_len)
+
+            # Having more tokens is over-conservative but otherwise fine
+            assert len(seq_data.prompt_token_ids) >= seq_len, (
+                f"Expected at least {seq_len} dummy tokens for profiling, "
+                f"but got: {len(seq_data.prompt_token_ids)}")
+
             seq = SequenceGroupMetadata(
                 request_id=str(group_id),
                 is_prompt=True,
@@ -194,7 +216,7 @@ def prepare_model_input(
             virtual_engine: int = 0,
             finished_requests_ids: Optional[List[str]] = None
     ) -> ModelInputForXPU:
-        multi_modal_input = None
+        multi_modal_kwargs = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
@@ -202,7 +224,7 @@ def prepare_model_input(
             # Prepare input tensors.
             if is_prompt:
                 (input_tokens, input_positions, attn_metadata, seq_lens,
-                 multi_modal_input
+                 multi_modal_kwargs
                  ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions,
@@ -223,6 +245,7 @@ def prepare_model_input(
                 "input_positions": input_positions,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
+                "multi_modal_kwargs": multi_modal_kwargs,
             }
             metadata_dict.update(attn_metadata.asdict_zerocopy())
             broadcast_tensor_dict(metadata_dict, src=0)
@@ -232,6 +255,7 @@ def prepare_model_input(
             input_positions = metadata_dict.pop("input_positions")
             selected_token_indices = metadata_dict.pop(
                 "selected_token_indices")
+            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
             attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
@@ -244,7 +268,7 @@ def prepare_model_input(
                                 input_positions=input_positions,
                                 attn_metadata=attn_metadata,
                                 sampling_metadata=sampling_metadata,
-                                multi_modal_input=multi_modal_input)
+                                multi_modal_kwargs=multi_modal_kwargs)
 
     def _prepare_decode(
         self,
@@ -350,10 +374,8 @@ def execute_model(
             "positions": model_input.input_positions,
             "kv_caches": kv_caches,
             "attn_metadata": model_input.attn_metadata,
+            **(model_input.multi_modal_kwargs or {}),
         }
-        if self.vision_language_config:
-            execute_model_kwargs.update(
-                {"image_input": model_input.multi_modal_input})
 
         hidden_states = model_executable(**execute_model_kwargs)
 
@@ -376,13 +398,13 @@ def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
-               Optional[torch.Tensor]]:
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
-        multi_modal_input_list: List[torch.Tensor] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -403,9 +425,10 @@ def _prepare_prompt(
             # is always the first token in the sequence.
             input_positions.extend(list(range(computed_len, seq_len)))
 
-            if seq_group_metadata.multi_modal_data:
-                multi_modal_input_list.append(
-                    seq_group_metadata.multi_modal_data.data)
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
 
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
@@ -435,15 +458,6 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-        if multi_modal_input_list:
-            assert self.vision_language_config, (
-                "Multi-modal inputs are only supported by "
-                "vision language models.")
-            multi_modal_input = torch.cat(multi_modal_input_list,
-                                          dim=0).to(self.device)
-        else:
-            multi_modal_input = None
-
         num_prompt_tokens = len(input_tokens)
 
         input_tokens = torch.tensor(input_tokens,
@@ -475,5 +489,9 @@ def _prepare_prompt(
             num_decode_tokens=0,
             block_tables=torch.tensor([], device=self.device, dtype=torch.int),
         )
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
         return (input_tokens, input_positions, attn_metadata, seq_lens,
-                multi_modal_input)
+                multi_modal_kwargs)
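The XPU profile_run change above caps the number of profiled sequences when a vision-language config is present, so the dummy batch still fits within `max_num_batched_tokens` once each sequence carries `image_feature_size` multi-modal tokens. A hedged sketch of that arithmetic, with illustrative numbers only (the concrete values are not taken from this patch):

# Illustrative walk-through of the profiling batch-size cap; numbers are
# made up, and the >= check mirrors the relaxed assertion in the patch.
max_num_batched_tokens = 8192
max_num_seqs = 256
image_feature_size = 576  # e.g. a 24 x 24 ViT patch grid

capped = min(max_num_seqs, int(max_num_batched_tokens / image_feature_size))
print(capped)  # 14 sequences, each with room for one image's worth of tokens

# Per-sequence dummy lengths then follow the same split as the runner loop:
for group_id in range(capped):
    seq_len = (max_num_batched_tokens // capped +
               (group_id < max_num_batched_tokens % capped))
    # Each profiled sequence is long enough to hold the image placeholder
    # tokens, which is why dummy data may legitimately exceed seq_len.
    assert seq_len >= image_feature_size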