From 6ffee492d79c1bf9a6bdef8291dc0a56117abe06 Mon Sep 17 00:00:00 2001 From: codingl2k1 <138426806+codingl2k1@users.noreply.github.com> Date: Fri, 29 Dec 2023 11:58:30 +0800 Subject: [PATCH] FEAT: Support qwen vl chat (#829) --- create_test_data.py | 0 xinference/core/supervisor.py | 39 ++ xinference/model/core.py | 6 + xinference/model/llm/pytorch/core.py | 30 +- xinference/model/llm/pytorch/spec_model.py | 3 +- xinference/model/multimodal/__init__.py | 45 ++ xinference/model/multimodal/core.py | 460 ++++++++++++++++++ xinference/model/multimodal/model_spec.json | 34 ++ xinference/model/multimodal/qwen_vl.py | 120 +++++ xinference/model/multimodal/tests/__init__.py | 13 + .../model/multimodal/tests/test_multimodal.py | 80 +++ xinference/model/utils.py | 28 ++ 12 files changed, 829 insertions(+), 29 deletions(-) delete mode 100644 create_test_data.py create mode 100644 xinference/model/multimodal/__init__.py create mode 100644 xinference/model/multimodal/core.py create mode 100644 xinference/model/multimodal/model_spec.json create mode 100644 xinference/model/multimodal/qwen_vl.py create mode 100644 xinference/model/multimodal/tests/__init__.py create mode 100644 xinference/model/multimodal/tests/test_multimodal.py diff --git a/create_test_data.py b/create_test_data.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index a0e83229ab..7554ca66be 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -37,6 +37,7 @@ from ..model.embedding import EmbeddingModelSpec from ..model.image import ImageModelFamilyV1 from ..model.llm import LLMFamilyV1 + from ..model.multimodal import LVLMFamilyV1 from ..model.rerank import RerankModelSpec from .worker import WorkerActor @@ -222,6 +223,25 @@ def _to_image_model_reg( "is_builtin": is_builtin, } + def _to_multimodal_reg( + self, model_family: "LVLMFamilyV1", is_builtin: bool + ) -> Dict[str, Any]: + from ..model.llm import get_cache_status + + if self.is_local_deployment(): + specs = [] + # TODO: does not work when the supervisor and worker are running on separate nodes. 
+ for spec in model_family.model_specs: + cache_status = get_cache_status(model_family, spec) + specs.append({**spec.dict(), "cache_status": cache_status}) + return { + **model_family.dict(), + "is_builtin": is_builtin, + "model_specs": specs, + } + else: + return {**model_family.dict(), "is_builtin": is_builtin} + @log_sync(logger=logger) def list_model_registrations( self, model_type: str, detailed: bool = False @@ -302,6 +322,18 @@ def sort_helper(item): {"model_name": model_spec.model_name, "is_builtin": False} ) + ret.sort(key=sort_helper) + return ret + elif model_type == "multimodal": + from ..model.multimodal import BUILTIN_LVLM_FAMILIES + + ret = [] + for family in BUILTIN_LVLM_FAMILIES: + if detailed: + ret.append(self._to_multimodal_reg(family, True)) + else: + ret.append({"model_name": family.model_name, "is_builtin": True}) + ret.sort(key=sort_helper) return ret else: @@ -342,6 +374,13 @@ def get_model_registration(self, model_type: str, model_name: str) -> Any: if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") + elif model_type == "multimodal": + from ..model.multimodal import BUILTIN_LVLM_FAMILIES + + for f in BUILTIN_LVLM_FAMILIES: + if f.model_name == model_name: + return f + raise ValueError(f"Model {model_name} not found") else: raise ValueError(f"Unsupported model type: {model_type}") diff --git a/xinference/model/core.py b/xinference/model/core.py index 9414c504e4..bcc465247c 100644 --- a/xinference/model/core.py +++ b/xinference/model/core.py @@ -44,6 +44,7 @@ def create_model_instance( from .embedding.core import create_embedding_model_instance from .image.core import create_image_model_instance from .llm.core import create_llm_model_instance + from .multimodal.core import create_multimodal_model_instance from .rerank.core import create_rerank_model_instance if model_type == "LLM": @@ -74,5 +75,10 @@ def create_model_instance( return create_rerank_model_instance( subpool_addr, devices, model_uid, model_name, **kwargs ) + elif model_type == "multimodal": + kwargs.pop("trust_remote_code", None) + return create_multimodal_model_instance( + subpool_addr, devices, model_uid, model_name, **kwargs + ) else: raise ValueError(f"Unsupported model type: {model_type}.") diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index b943098586..db75c721d0 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -29,6 +29,7 @@ PytorchGenerateConfig, PytorchModelConfig, ) +from ...utils import select_device from ..core import LLM from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import ChatModelMixin @@ -122,7 +123,7 @@ def load(self): quantization = self.quantization num_gpus = len(cuda_visible_devices) if cuda_visible_devices_env != "-1" else 0 device = self._pytorch_model_config.get("device", "auto") - self._pytorch_model_config["device"] = self._select_device(device) + self._pytorch_model_config["device"] = select_device(device) self._device = self._pytorch_model_config["device"] if self._device == "cpu": @@ -185,33 +186,6 @@ def load(self): self._model.to(self._device) logger.debug(f"Model Memory: {self._model.get_memory_footprint()}") - def _select_device(self, device: str) -> str: - try: - import torch - except ImportError: - raise ImportError( - f"Failed to import module 'torch'. 
Please make sure 'torch' is installed.\n\n" - ) - - if device == "auto": - # When env CUDA_VISIBLE_DEVICES=-1, torch.cuda.is_available() return False - if torch.cuda.is_available(): - return "cuda" - elif torch.backends.mps.is_available(): - return "mps" - return "cpu" - elif device == "cuda": - if not torch.cuda.is_available(): - raise ValueError("cuda is unavailable in your environment") - elif device == "mps": - if not torch.backends.mps.is_available(): - raise ValueError("mps is unavailable in your environment") - elif device == "cpu": - pass - else: - raise ValueError(f"Device {device} is not supported in temporary") - return device - @classmethod def match( cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str diff --git a/xinference/model/llm/pytorch/spec_model.py b/xinference/model/llm/pytorch/spec_model.py index e438bbb264..a66f6fbfc1 100644 --- a/xinference/model/llm/pytorch/spec_model.py +++ b/xinference/model/llm/pytorch/spec_model.py @@ -17,6 +17,7 @@ from typing import Iterator, List, Optional, Union from ....types import Completion, CompletionChunk, Embedding +from ...utils import select_device from .. import LLMFamilyV1, LLMSpecV1 from .core import PytorchChatModel, PytorchGenerateConfig, PytorchModelConfig @@ -85,7 +86,7 @@ def load(self): num_gpus = len(cuda_visible_devices) if cuda_visible_devices_env != "-1" else 0 device = self._pytorch_model_config.get("device", "auto") - self._pytorch_model_config["device"] = self._select_device(device) + self._pytorch_model_config["device"] = select_device(device) self._device = self._pytorch_model_config["device"] if self._device == "cpu": diff --git a/xinference/model/multimodal/__init__.py b/xinference/model/multimodal/__init__.py new file mode 100644 index 0000000000..bae4627739 --- /dev/null +++ b/xinference/model/multimodal/__init__.py @@ -0,0 +1,45 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import codecs +import json +import os + +from .core import ( + BUILTIN_LVLM_FAMILIES, + BUILTIN_MODELSCOPE_LVLM_FAMILIES, + MODEL_CLASSES, + MODEL_NAME_TO_REVISION, + LVLMFamilyV1, + LVLMPromptStyleV1, +) +from .qwen_vl import QwenVLChat + +MODEL_CLASSES.append(QwenVLChat) + + +def _install(): + json_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "model_spec.json" + ) + for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")): + model_family = LVLMFamilyV1.parse_obj(json_obj) + BUILTIN_LVLM_FAMILIES.append(model_family) + for model_spec in model_family.model_specs: + MODEL_NAME_TO_REVISION[model_family.model_name].append( + model_spec.model_revision + ) + + +_install() diff --git a/xinference/model/multimodal/core.py b/xinference/model/multimodal/core.py new file mode 100644 index 0000000000..678c8583b1 --- /dev/null +++ b/xinference/model/multimodal/core.py @@ -0,0 +1,460 @@ +# Copyright 2022-2023 XProbe Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +import os +from abc import abstractmethod +from collections import defaultdict +from typing import Dict, Iterator, List, Literal, Optional, Tuple, Type, Union + +from pydantic import BaseModel, validator + +from ...constants import XINFERENCE_CACHE_DIR +from ...core.utils import parse_replica_model_uid +from ...types import ChatCompletion, ChatCompletionChunk +from ..core import ModelDescription +from ..utils import ( + download_from_modelscope, + is_model_cached, + retry_download, + symlink_local_file, + valid_model_revision, +) + +logger = logging.getLogger(__name__) + +DEFAULT_CONTEXT_LENGTH = 2048 +# Used for check whether the model is cached. +# Init when registering all the builtin models. +MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list) + + +class LVLMSpecV1(BaseModel): + model_format: Literal["pytorch", "gptq"] + # Must in order that `str` first, then `int` + model_size_in_billions: Union[str, int] + quantizations: List[str] + model_id: str + model_hub: str = "huggingface" + model_uri: Optional[str] + model_revision: Optional[str] + + @validator("model_size_in_billions", pre=False) + def validate_model_size_with_radix(cls, v: object) -> object: + if isinstance(v, str): + if ( + "_" in v + ): # for example, "1_8" just returns "1_8", otherwise int("1_8") returns 18 + return v + else: + return int(v) + return v + + +class LVLMPromptStyleV1(BaseModel): + style_name: str + system_prompt: str = "" + roles: List[str] + + +class LVLMFamilyV1(BaseModel): + version: Literal[1] + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH + model_name: str + model_lang: List[str] + model_ability: List[Literal["chat"]] + model_description: Optional[str] + model_specs: List["LVLMSpecV1"] + prompt_style: Optional["LVLMPromptStyleV1"] + + +class LVLMDescription(ModelDescription): + def __init__( + self, + address: Optional[str], + devices: Optional[List[str]], + model_family: "LVLMFamilyV1", + model_spec: "LVLMSpecV1", + quantization: Optional[str], + ): + super().__init__(address, devices) + self._model_family = model_family + self._model_spec = model_spec + self._quantization = quantization + + def to_dict(self): + return { + "model_type": "LVLM", + "address": self.address, + "accelerators": self.devices, + "model_name": self._model_family.model_name, + "model_lang": self._model_family.model_lang, + "model_ability": self._model_family.model_ability, + "model_description": self._model_family.model_description, + "model_format": self._model_spec.model_format, + "model_size_in_billions": self._model_spec.model_size_in_billions, + "quantization": self._quantization, + "model_hub": self._model_spec.model_hub, + "revision": self._model_spec.model_revision, + "context_length": self._model_family.context_length, + } + + +class LVLM(abc.ABC): + def __init__( + self, + replica_model_uid: str, + model_family: "LVLMFamilyV1", + model_spec: "LVLMSpecV1", + quantization: str, + model_path: str, + kwargs: Dict, + ): + 
        self.model_uid, self.replica, self.rep_id = parse_replica_model_uid(
+            replica_model_uid
+        )
+        self.model_family = model_family
+        self.model_spec = model_spec
+        self.quantization = quantization
+        self.model_path = model_path
+        self.kwargs = kwargs
+        logger.info("Init model %s with kwargs: %s", self.model_uid, kwargs)
+
+    @abstractmethod
+    def load(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def chat(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[Dict]] = None,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        raise NotImplementedError
+
+    @classmethod
+    def match(
+        cls, model_family: "LVLMFamilyV1", model_spec: "LVLMSpecV1", quantization: str
+    ) -> bool:
+        raise NotImplementedError
+
+
+BUILTIN_LVLM_FAMILIES: List["LVLMFamilyV1"] = []
+BUILTIN_MODELSCOPE_LVLM_FAMILIES: List["LVLMFamilyV1"] = []
+
+
+def match_multimodal(
+    model_name: str,
+    model_format: Optional[str] = None,
+    model_size_in_billions: Optional[int] = None,
+    quantization: Optional[str] = None,
+) -> Optional[Tuple[LVLMFamilyV1, LVLMSpecV1, str]]:
+    """
+    Find a multimodal family, spec, and quantization that satisfy the given criteria.
+    """
+
+    def _match_quantization(q: Union[str, None], quantizations: List[str]):
+        # Currently, the quantization name could include both uppercase and lowercase letters,
+        # so it is necessary to ensure that case sensitivity does not
+        # affect the matching results.
+        if q is None:
+            return q
+        for quant in quantizations:
+            if q.lower() == quant.lower():
+                return quant
+
+    def _apply_format_to_model_id(spec: LVLMSpecV1, q: str) -> LVLMSpecV1:
+        # Different quantized versions of some models use different model ids.
+        # Here we check for `{}` in the model id and format it with the quantization.
+ if "{" in spec.model_id: + spec.model_id = spec.model_id.format(quantization=q) + return spec + + if download_from_modelscope(): + all_families = BUILTIN_MODELSCOPE_LVLM_FAMILIES + BUILTIN_LVLM_FAMILIES + else: + all_families = BUILTIN_LVLM_FAMILIES + + for family in all_families: + if model_name != family.model_name: + continue + for spec in family.model_specs: + matched_quantization = _match_quantization(quantization, spec.quantizations) + if ( + model_format + and model_format != spec.model_format + or model_size_in_billions + and model_size_in_billions != spec.model_size_in_billions + or quantization + and matched_quantization is None + ): + continue + if quantization: + return ( + family, + _apply_format_to_model_id(spec, matched_quantization), + matched_quantization, + ) + else: + return family, _apply_format_to_model_id(spec, "none"), "none" + return None + + +def create_multimodal_model_instance( + subpool_addr: str, + devices: List[str], + model_uid: str, + model_name: str, + model_format: Optional[str] = None, + model_size_in_billions: Optional[int] = None, + quantization: Optional[str] = None, + **kwargs, +) -> Tuple[LVLM, LVLMDescription]: + match_result = match_multimodal( + model_name, + model_format, + model_size_in_billions, + quantization, + ) + if not match_result: + raise ValueError( + f"Model not found, name: {model_name}, format: {model_format}," + f" size: {model_size_in_billions}, quantization: {quantization}" + ) + model_family, model_spec, quantization = match_result + + assert quantization is not None + save_path = cache(model_family, model_spec, quantization) + + cls = match_cls(model_family, model_spec, quantization) + logger.debug(f"Launching {model_uid} with {cls.__name__}") + + model = cls(model_uid, model_family, model_spec, quantization, save_path, kwargs) + return model, LVLMDescription( + subpool_addr, devices, model_family, model_spec, quantization + ) + + +MODEL_CLASSES: List[Type[LVLM]] = [] + + +def match_cls( + model_family: LVLMFamilyV1, model_spec: "LVLMSpecV1", quantization: str +) -> Type[LVLM]: + """ + Find an multimodal implementation for given multimodal family and spec. 
+ """ + for cls in MODEL_CLASSES: + if cls.match(model_family, model_spec, quantization): + return cls + raise Exception(f"Model {model_family.model_name} is not supported") + + +def _get_cache_dir( + model_family: LVLMFamilyV1, + model_spec: "LVLMSpecV1", + create_if_not_exist=True, +): + cache_dir_name = ( + f"{model_family.model_name}-{model_spec.model_format}" + f"-{model_spec.model_size_in_billions}b" + ) + cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name)) + if create_if_not_exist and not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +def _get_meta_path( + cache_dir: str, + model_format: str, + model_hub: str, + quantization: Optional[str] = None, +): + if model_format == "pytorch": + if model_hub == "huggingface": + return os.path.join(cache_dir, "__valid_download") + else: + return os.path.join(cache_dir, f"__valid_download_{model_hub}") + elif model_format in ["ggmlv3", "ggufv2", "gptq"]: + assert quantization is not None + if model_hub == "huggingface": + return os.path.join(cache_dir, f"__valid_download_{quantization}") + else: + return os.path.join( + cache_dir, f"__valid_download_{model_hub}_{quantization}" + ) + else: + raise ValueError(f"Unsupported format: {model_format}") + + +def _skip_download( + cache_dir: str, + model_format: str, + model_hub: str, + model_revision: Optional[str], + quantization: Optional[str] = None, +) -> bool: + if model_format == "pytorch": + model_hub_to_meta_path = { + "huggingface": _get_meta_path( + cache_dir, model_format, "huggingface", quantization + ), + "modelscope": _get_meta_path( + cache_dir, model_format, "modelscope", quantization + ), + } + if valid_model_revision(model_hub_to_meta_path[model_hub], model_revision): + logger.info(f"Cache {cache_dir} exists") + return True + else: + for hub, meta_path in model_hub_to_meta_path.items(): + if hub != model_hub and os.path.exists(meta_path): + # PyTorch models from modelscope can also be loaded by transformers. + logger.warning(f"Cache {cache_dir} exists, but it was from {hub}") + return True + return False + else: + raise ValueError(f"Unsupported format: {model_format}") + + +def _generate_meta_file( + meta_path: str, + model_family: "LVLMFamilyV1", + model_spec: "LVLMSpecV1", + quantization: Optional[str] = None, +): + assert not valid_model_revision( + meta_path, model_spec.model_revision + ), f"meta file {meta_path} should not be valid" + with open(meta_path, "w") as f: + import json + + desc = LVLMDescription(None, None, model_family, model_spec, quantization) + json.dump(desc.to_dict(), f) + + +def cache_from_modelscope( + model_family: LVLMFamilyV1, + model_spec: "LVLMSpecV1", + quantization: Optional[str] = None, +) -> str: + """ + Cache model from Modelscope. Return the cache directory. 
+ """ + from modelscope.hub.snapshot_download import snapshot_download + + cache_dir = _get_cache_dir(model_family, model_spec) + if _skip_download( + cache_dir, + model_spec.model_format, + model_spec.model_hub, + model_spec.model_revision, + quantization, + ): + return cache_dir + + if model_spec.model_format in ["pytorch", "gptq"]: + download_dir = retry_download( + snapshot_download, + model_family.model_name, + { + "model_size": model_spec.model_size_in_billions, + "model_format": model_spec.model_format, + }, + model_spec.model_id, + revision=model_spec.model_revision, + ) + for subdir, dirs, files in os.walk(download_dir): + for file in files: + relpath = os.path.relpath(os.path.join(subdir, file), download_dir) + symlink_local_file(os.path.join(subdir, file), cache_dir, relpath) + else: + raise ValueError(f"Unsupported format: {model_spec.model_format}") + + meta_path = _get_meta_path( + cache_dir, model_spec.model_format, model_spec.model_hub, quantization + ) + _generate_meta_file(meta_path, model_family, model_spec, quantization) + + return cache_dir + + +def cache_from_huggingface( + model_family: LVLMFamilyV1, + model_spec: "LVLMSpecV1", + quantization: Optional[str] = None, +) -> str: + """ + Cache model from Hugging Face. Return the cache directory. + """ + import huggingface_hub + + cache_dir = _get_cache_dir(model_family, model_spec) + if _skip_download( + cache_dir, + model_spec.model_format, + model_spec.model_hub, + model_spec.model_revision, + quantization, + ): + return cache_dir + + if model_spec.model_format in ["pytorch"]: + assert isinstance(model_spec, LVLMSpecV1) + retry_download( + huggingface_hub.snapshot_download, + model_family.model_name, + { + "model_size": model_spec.model_size_in_billions, + "model_format": model_spec.model_format, + }, + model_spec.model_id, + revision=model_spec.model_revision, + local_dir=cache_dir, + local_dir_use_symlinks=True, + ) + else: + raise ValueError(f"Unsupported model format: {model_spec.model_format}") + + meta_path = _get_meta_path( + cache_dir, model_spec.model_format, model_spec.model_hub, quantization + ) + _generate_meta_file(meta_path, model_family, model_spec, quantization) + + return cache_dir + + +def cache( + model_family: LVLMFamilyV1, + model_spec: "LVLMSpecV1", + quantization: Optional[str] = None, +) -> str: + if model_spec.model_hub == "huggingface": + logger.info(f"Caching from Hugging Face: {model_spec.model_id}") + return cache_from_huggingface(model_family, model_spec, quantization) + elif model_spec.model_hub == "modelscope": + logger.info(f"Caching from Modelscope: {model_spec.model_id}") + return cache_from_modelscope(model_family, model_spec, quantization) + else: + raise ValueError(f"Unknown model hub: {model_spec.model_hub}") + + +def get_cache_status( + model_spec: LVLMSpecV1, +) -> bool: + return is_model_cached(model_spec, MODEL_NAME_TO_REVISION) diff --git a/xinference/model/multimodal/model_spec.json b/xinference/model/multimodal/model_spec.json new file mode 100644 index 0000000000..07af7f2f19 --- /dev/null +++ b/xinference/model/multimodal/model_spec.json @@ -0,0 +1,34 @@ +[ + { + "version": 1, + "context_length": 4096, + "model_name": "qwen-vl-chat", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "chat" + ], + "model_description": "Qwen-VL-Chat supports more flexible interaction, such as multiple image inputs, multi-round question answering, and creative capabilities.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantizations": [ + 
"none" + ], + "model_id": "Qwen/Qwen-VL-Chat", + "model_revision": "6665c780ade5ff3f08853b4262dcb9c8f9598d42" + } + ], + "prompt_style": { + "style_name": "QWEN", + "system_prompt": "You are a helpful assistant.", + "roles": [ + "user", + "assistant" + ] + } + } +] diff --git a/xinference/model/multimodal/qwen_vl.py b/xinference/model/multimodal/qwen_vl.py new file mode 100644 index 0000000000..55e29fe182 --- /dev/null +++ b/xinference/model/multimodal/qwen_vl.py @@ -0,0 +1,120 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +import time +import uuid +from typing import Dict, Iterator, List, Optional, Union + +from ...types import ( + ChatCompletion, + ChatCompletionChoice, + ChatCompletionChunk, + CompletionUsage, +) +from ..utils import select_device +from .core import LVLM, LVLMFamilyV1, LVLMSpecV1 + + +class QwenVLChat(LVLM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._tokenizer = None + self._model = None + + @classmethod + def match( + cls, model_family: "LVLMFamilyV1", model_spec: "LVLMSpecV1", quantization: str + ) -> bool: + if "qwen" in model_family.model_name: + return True + return False + + def load(self): + from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers.generation import GenerationConfig + + device = self.kwargs.get("device", "auto") + device = select_device(device) + + self._tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + trust_remote_code=True, + code_revision=self.model_spec.model_revision, + ) + self._model = AutoModelForCausalLM.from_pretrained( + self.model_path, + device_map=device, + trust_remote_code=True, + code_revision=self.model_spec.model_revision, + ).eval() + # Specify hyperparameters for generation + self._model.generation_config = GenerationConfig.from_pretrained( + self.model_path, + trust_remote_code=True, + code_revision=self.model_spec.model_revision, + ) + + def _message_content_to_qwen(self, content) -> str: + if not isinstance(content, str): + content = [ + {"image": c["image_url"]["url"], "type": "image"} + if c.get("type") == "image_url" + else c + for c in content + ] + content = sorted(content, key=operator.itemgetter("type")) + return self._tokenizer.from_list_format(content) + return content + + def chat( + self, + prompt: Union[str, List[Dict]], + system_prompt: Optional[str] = None, + chat_history: Optional[List[Dict]] = None, + generate_config: Optional[Dict] = None, + ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: + prompt = self._message_content_to_qwen(prompt) + # Convert openai history to qwen vl history + qwen_history = [] + query_to_response: List = [] + for h in chat_history or []: + role = h["role"] + content = self._message_content_to_qwen(h["content"]) + if len(query_to_response) == 0 and role == "user": + query_to_response.append(content) + if len(query_to_response) == 1 and role == "assistant": + query_to_response.append(content) + if len(query_to_response) == 2: + 
qwen_history.append(query_to_response) + query_to_response = [] + response, history = self._model.chat( + self._tokenizer, query=prompt, history=qwen_history + ) + return ChatCompletion( + id="chat" + str(uuid.uuid1()), + object="chat.completion", + created=int(time.time()), + model=self.model_uid, + choices=[ + ChatCompletionChoice( + index=0, + message={"role": "assistant", "content": response}, + finish_reason="stop", + ) + ], + usage=CompletionUsage( + prompt_tokens=-1, completion_tokens=-1, total_tokens=-1 + ), + ) diff --git a/xinference/model/multimodal/tests/__init__.py b/xinference/model/multimodal/tests/__init__.py new file mode 100644 index 0000000000..37f6558d95 --- /dev/null +++ b/xinference/model/multimodal/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/xinference/model/multimodal/tests/test_multimodal.py b/xinference/model/multimodal/tests/test_multimodal.py new file mode 100644 index 0000000000..38317049b8 --- /dev/null +++ b/xinference/model/multimodal/tests/test_multimodal.py @@ -0,0 +1,80 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest
+
+
+@pytest.mark.skip(reason="Cost too many resources.")
+def test_restful_api_for_qwen_vl(setup):
+    endpoint, _ = setup
+    from ....client import Client
+
+    client = Client(endpoint)
+
+    model_uid = client.launch_model(
+        model_uid="my_controlnet",
+        model_name="qwen-vl-chat",
+        model_type="multimodal",
+        device="cpu",
+    )
+    model = client.get_model(model_uid)
+    assert model
+
+    # openai client
+    import openai
+
+    client = openai.Client(api_key="not empty", base_url=f"{endpoint}/v1")
+    completion = client.chat.completions.create(
+        model=model_uid,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What’s in this image?"},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                        },
+                    },
+                ],
+            }
+        ],
+    )
+    assert "grass" in completion.choices[0].message.content
+    assert "tree" in completion.choices[0].message.content
+    assert "sky" in completion.choices[0].message.content
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "这是什么?"},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                    },
+                },
+            ],
+        }
+    ]
+    completion = client.chat.completions.create(model=model_uid, messages=messages)
+    assert "女" in completion.choices[0].message.content
+    assert "狗" in completion.choices[0].message.content
+    assert "沙滩" in completion.choices[0].message.content
+    messages.append(completion.choices[0].message.model_dump())
+    messages.append({"role": "user", "content": "框出图中击掌的位置"})
+    completion = client.chat.completions.create(model=model_uid, messages=messages)
+    assert "击掌" in completion.choices[0].message.content
+    assert "<ref>" in completion.choices[0].message.content
+    assert "<box>" in completion.choices[0].message.content
diff --git a/xinference/model/utils.py b/xinference/model/utils.py
index dafc25fd2a..490e1b5cb7 100644
--- a/xinference/model/utils.py
+++ b/xinference/model/utils.py
@@ -255,3 +255,31 @@ def _patched_resolve_trust_remote_code(*args, **kwargs):
         resolve_trust_remote_code.__code__ = (
             _patched_resolve_trust_remote_code.__code__
         )
+
+
+def select_device(device):
+    try:
+        import torch
+    except ImportError:
+        raise ImportError(
+            "Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
+        )
+
+    if device == "auto":
+        # When CUDA_VISIBLE_DEVICES=-1 is set, torch.cuda.is_available() returns False.
+        if torch.cuda.is_available():
+            return "cuda"
+        elif torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
+    elif device == "cuda":
+        if not torch.cuda.is_available():
+            raise ValueError("cuda is unavailable in your environment")
+    elif device == "mps":
+        if not torch.backends.mps.is_available():
+            raise ValueError("mps is unavailable in your environment")
+    elif device == "cpu":
+        pass
+    else:
+        raise ValueError(f"Device {device} is currently not supported")
+    return device
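
For reference, below is a minimal usage sketch of the new "multimodal" model type, distilled from test_restful_api_for_qwen_vl above; it is not part of the patch. It assumes an xinference server is already running (the ENDPOINT address is illustrative) and that the qwen-vl-chat weights can be downloaded on first launch.

# Minimal sketch: launch qwen-vl-chat as a "multimodal" model and query it
# through the OpenAI-compatible chat API. ENDPOINT is an assumed address.
import openai

from xinference.client import Client

ENDPOINT = "http://127.0.0.1:9997"  # assumed address of a running server

client = Client(ENDPOINT)
model_uid = client.launch_model(
    model_name="qwen-vl-chat",
    model_type="multimodal",
)

# Images are passed as "image_url" content parts, mirroring the test above.
openai_client = openai.Client(api_key="not empty", base_url=f"{ENDPOINT}/v1")
completion = openai_client.chat.completions.create(
    model=model_uid,
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
                    },
                },
            ],
        }
    ],
)
print(completion.choices[0].message.content)

Because the server exposes the standard chat completions schema, the stock openai client is enough to drive the model; no xinference-specific chat API is needed for image input.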