FEAT: Support qwen vl chat (#829)
codingl2k1 authored Dec 29, 2023
1 parent 817b622 commit 6ffee49
Showing 12 changed files with 829 additions and 29 deletions.
Empty file removed create_test_data.py
39 changes: 39 additions & 0 deletions xinference/core/supervisor.py
@@ -37,6 +37,7 @@
from ..model.embedding import EmbeddingModelSpec
from ..model.image import ImageModelFamilyV1
from ..model.llm import LLMFamilyV1
from ..model.multimodal import LVLMFamilyV1
from ..model.rerank import RerankModelSpec
from .worker import WorkerActor

@@ -222,6 +223,25 @@ def _to_image_model_reg(
            "is_builtin": is_builtin,
        }

    def _to_multimodal_reg(
        self, model_family: "LVLMFamilyV1", is_builtin: bool
    ) -> Dict[str, Any]:
        from ..model.llm import get_cache_status

        if self.is_local_deployment():
            specs = []
            # TODO: does not work when the supervisor and worker are running on separate nodes.
            for spec in model_family.model_specs:
                cache_status = get_cache_status(model_family, spec)
                specs.append({**spec.dict(), "cache_status": cache_status})
            return {
                **model_family.dict(),
                "is_builtin": is_builtin,
                "model_specs": specs,
            }
        else:
            return {**model_family.dict(), "is_builtin": is_builtin}

    @log_sync(logger=logger)
    def list_model_registrations(
        self, model_type: str, detailed: bool = False
@@ -302,6 +322,18 @@ def sort_helper(item):
                    {"model_name": model_spec.model_name, "is_builtin": False}
                )

            ret.sort(key=sort_helper)
            return ret
        elif model_type == "multimodal":
            from ..model.multimodal import BUILTIN_LVLM_FAMILIES

            ret = []
            for family in BUILTIN_LVLM_FAMILIES:
                if detailed:
                    ret.append(self._to_multimodal_reg(family, True))
                else:
                    ret.append({"model_name": family.model_name, "is_builtin": True})

            ret.sort(key=sort_helper)
            return ret
        else:
@@ -342,6 +374,13 @@ def get_model_registration(self, model_type: str, model_name: str) -> Any:
            if f.model_name == model_name:
                return f
            raise ValueError(f"Model {model_name} not found")
        elif model_type == "multimodal":
            from ..model.multimodal import BUILTIN_LVLM_FAMILIES

            for f in BUILTIN_LVLM_FAMILIES:
                if f.model_name == model_name:
                    return f
            raise ValueError(f"Model {model_name} not found")
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

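The new _to_multimodal_reg mirrors the existing _to_image_model_reg helper: it flattens the family into a dict, tags it with is_builtin, and on local deployments annotates each spec with its cache_status. A hedged sketch of roughly what one entry of list_model_registrations("multimodal", detailed=True) could look like; the field values and any field names beyond is_builtin, model_specs, and cache_status are illustrative assumptions, not taken from the real model_spec.json:

# Illustrative only: shape of one detailed multimodal registration entry.
example_reg = {
    "model_name": "qwen-vl-chat",          # from LVLMFamilyV1.dict()
    "is_builtin": True,                    # injected by _to_multimodal_reg
    "model_specs": [
        {
            "model_format": "pytorch",     # assumed spec field
            "model_size_in_billions": 7,   # assumed spec field
            "cache_status": False,         # added per spec on local deployments
        }
    ],
}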
6 changes: 6 additions & 0 deletions xinference/model/core.py
@@ -44,6 +44,7 @@ def create_model_instance(
    from .embedding.core import create_embedding_model_instance
    from .image.core import create_image_model_instance
    from .llm.core import create_llm_model_instance
    from .multimodal.core import create_multimodal_model_instance
    from .rerank.core import create_rerank_model_instance

    if model_type == "LLM":
@@ -74,5 +75,10 @@ def create_model_instance(
        return create_rerank_model_instance(
            subpool_addr, devices, model_uid, model_name, **kwargs
        )
    elif model_type == "multimodal":
        kwargs.pop("trust_remote_code", None)
        return create_multimodal_model_instance(
            subpool_addr, devices, model_uid, model_name, **kwargs
        )
    else:
        raise ValueError(f"Unsupported model type: {model_type}.")
30 changes: 2 additions & 28 deletions xinference/model/llm/pytorch/core.py
@@ -29,6 +29,7 @@
    PytorchGenerateConfig,
    PytorchModelConfig,
)
from ...utils import select_device
from ..core import LLM
from ..llm_family import LLMFamilyV1, LLMSpecV1
from ..utils import ChatModelMixin
@@ -122,7 +123,7 @@ def load(self):
        quantization = self.quantization
        num_gpus = len(cuda_visible_devices) if cuda_visible_devices_env != "-1" else 0
        device = self._pytorch_model_config.get("device", "auto")
        self._pytorch_model_config["device"] = self._select_device(device)
        self._pytorch_model_config["device"] = select_device(device)
        self._device = self._pytorch_model_config["device"]

        if self._device == "cpu":
@@ -185,33 +186,6 @@ def load(self):
            self._model.to(self._device)
        logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")

    def _select_device(self, device: str) -> str:
        try:
            import torch
        except ImportError:
            raise ImportError(
                f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
            )

        if device == "auto":
            # When env CUDA_VISIBLE_DEVICES=-1, torch.cuda.is_available() return False
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        elif device == "cuda":
            if not torch.cuda.is_available():
                raise ValueError("cuda is unavailable in your environment")
        elif device == "mps":
            if not torch.backends.mps.is_available():
                raise ValueError("mps is unavailable in your environment")
        elif device == "cpu":
            pass
        else:
            raise ValueError(f"Device {device} is not supported in temporary")
        return device

    @classmethod
    def match(
        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
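The deleted _select_device is not dropped but hoisted into a shared helper, imported above as from ...utils import select_device, so the PyTorch, speculative-decoding, and new multimodal models can all reuse the same device-selection logic. A sketch of the extracted function, reconstructed from the removed method (error messages lightly tidied: the stray f prefix and the "in temporary" wording; the actual helper in xinference/model/utils.py may differ):

def select_device(device: str) -> str:
    try:
        import torch
    except ImportError:
        raise ImportError(
            "Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
        )

    if device == "auto":
        # With CUDA_VISIBLE_DEVICES=-1, torch.cuda.is_available() returns False.
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        return "cpu"
    elif device == "cuda":
        if not torch.cuda.is_available():
            raise ValueError("cuda is unavailable in your environment")
    elif device == "mps":
        if not torch.backends.mps.is_available():
            raise ValueError("mps is unavailable in your environment")
    elif device == "cpu":
        pass
    else:
        raise ValueError(f"Device {device} is not supported for now")
    return device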
3 changes: 2 additions & 1 deletion xinference/model/llm/pytorch/spec_model.py
@@ -17,6 +17,7 @@
from typing import Iterator, List, Optional, Union

from ....types import Completion, CompletionChunk, Embedding
from ...utils import select_device
from .. import LLMFamilyV1, LLMSpecV1
from .core import PytorchChatModel, PytorchGenerateConfig, PytorchModelConfig

@@ -85,7 +86,7 @@ def load(self):

        num_gpus = len(cuda_visible_devices) if cuda_visible_devices_env != "-1" else 0
        device = self._pytorch_model_config.get("device", "auto")
        self._pytorch_model_config["device"] = self._select_device(device)
        self._pytorch_model_config["device"] = select_device(device)
        self._device = self._pytorch_model_config["device"]

        if self._device == "cpu":
45 changes: 45 additions & 0 deletions xinference/model/multimodal/__init__.py
@@ -0,0 +1,45 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import json
import os

from .core import (
    BUILTIN_LVLM_FAMILIES,
    BUILTIN_MODELSCOPE_LVLM_FAMILIES,
    MODEL_CLASSES,
    MODEL_NAME_TO_REVISION,
    LVLMFamilyV1,
    LVLMPromptStyleV1,
)
from .qwen_vl import QwenVLChat

MODEL_CLASSES.append(QwenVLChat)


def _install():
    json_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "model_spec.json"
    )
    for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
        model_family = LVLMFamilyV1.parse_obj(json_obj)
        BUILTIN_LVLM_FAMILIES.append(model_family)
        for model_spec in model_family.model_specs:
            MODEL_NAME_TO_REVISION[model_family.model_name].append(
                model_spec.model_revision
            )


_install()
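_install() runs at import time: each JSON object in model_spec.json is parsed into an LVLMFamilyV1, appended to the builtin family list, and every spec's pinned revision is recorded under the family's model name. A hedged sketch of the shape one entry might take, based only on the fields this module actually reads (model_name, model_specs, model_revision); everything else is an assumption modeled on the LLM spec format:

# Illustrative model_spec.json entry; real entries likely carry more
# required fields (e.g. a prompt style matching LVLMPromptStyleV1).
example_entry = {
    "model_name": "qwen-vl-chat",
    "model_specs": [
        {
            "model_format": "pytorch",          # assumed field
            "model_id": "Qwen/Qwen-VL-Chat",    # assumed field
            "model_revision": "<pinned-revision>",
        }
    ],
}

# _install() effectively does:
#   family = LVLMFamilyV1.parse_obj(example_entry)
#   BUILTIN_LVLM_FAMILIES.append(family)
#   MODEL_NAME_TO_REVISION["qwen-vl-chat"].append("<pinned-revision>")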