update vllm backend to support offline and online modes and arbitrary engine args
IlyasMoutawwakil committed Jul 22, 2024
1 parent 6351e36 commit 94f9961
Showing 7 changed files with 122 additions and 103 deletions.
22 changes: 6 additions & 16 deletions optimum_benchmark/backends/config.py
@@ -30,10 +30,10 @@ class BackendConfig(ABC):
processor: Optional[str] = None

device: Optional[str] = None
device_ids: Optional[str] = None
# yes we use a string here instead of a list
# we use a string here instead of a list
# because it's easier to pass in a yaml or from cli
# and it's consistent with GPU environment variables
device_ids: Optional[str] = None

seed: int = 42
inter_op_num_threads: Optional[int] = None
@@ -44,33 +44,23 @@ class BackendConfig(ABC):
# processor kwargs that are added to its init method/constructor
processor_kwargs: Dict[str, Any] = field(default_factory=dict)

# deprecated
hub_kwargs: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
if self.model is None:
raise ValueError("`model` must be specified.")

if self.processor is None:
self.processor = self.model

if self.hub_kwargs:
LOGGER.warning(
"`hub_kwargs` is deprecated and will be removed in future versions."
"Please use `model_kwargs` and `processor_kwargs` seperately."
)
self.model_kwargs = {**self.model_kwargs, **self.hub_kwargs}
self.processor_kwargs = {**self.processor_kwargs, **self.hub_kwargs}

# TODO: add cache_dir, token, etc. to these methods
if self.task is None:
self.task = infer_task_from_model_name_or_path(self.model, self.hub_kwargs.get("revision", None))
self.task = infer_task_from_model_name_or_path(self.model, self.model_kwargs.get("revision", None))

if self.library is None:
self.library = infer_library_from_model_name_or_path(self.model, self.hub_kwargs.get("revision", None))
self.library = infer_library_from_model_name_or_path(self.model, self.model_kwargs.get("revision", None))

if self.model_type is None:
self.model_type = infer_model_type_from_model_name_or_path(
self.model, self.hub_kwargs.get("revision", None)
self.model, self.model_kwargs.get("revision", None)
)

if self.device is None:
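For illustration, a minimal sketch (not part of this commit, values hypothetical) of what replaces the deprecated `hub_kwargs`: hub-related options are now passed per artifact, and the revision used for task/library/model_type inference is read from `model_kwargs`:

# hypothetical values, for illustration only
model_kwargs = {"revision": "main", "trust_remote_code": False}
processor_kwargs = {"revision": "main"}

# the inference helpers now read the revision from model_kwargs instead of hub_kwargs, e.g.
# task = infer_task_from_model_name_or_path(model, model_kwargs.get("revision", None))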
8 changes: 7 additions & 1 deletion optimum_benchmark/backends/transformers_utils.py
@@ -1,3 +1,4 @@
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Optional, Union

@@ -107,7 +108,12 @@ def extract_transformers_shapes_from_artifacts(
processor_dict = {k: v for k, v in processor.to_dict().items() if v is not None}
artifacts_dict.update(processor_dict)
elif processor is not None:
processor_dict = {k: getattr(processor, k) for k in dir(processor) if isinstance(getattr(processor, k), int)}
try:
processor_dict = {
k: getattr(processor, k) for k in dir(processor) if isinstance(getattr(processor, k), int)
}
except Exception:
warnings.warn(f"Could not extract shapes from processor {processor}")

shapes = {}

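For context, a standalone sketch of the guarded attribute scan above (the tokenizer checkpoint is an assumption, not taken from this commit): integer attributes of the processor, such as vocab_size or model_max_length, are collected into the shapes dictionary, and a failure during the scan now only emits a warning instead of crashing the benchmark:

import warnings

from transformers import AutoTokenizer

processor = AutoTokenizer.from_pretrained("gpt2")  # assumed example processor

processor_dict = {}
try:
    # collect every integer attribute exposed by the processor
    processor_dict = {k: getattr(processor, k) for k in dir(processor) if isinstance(getattr(processor, k), int)}
except Exception:
    warnings.warn(f"Could not extract shapes from processor {processor}")

print(processor_dict.get("vocab_size"), processor_dict.get("model_max_length"))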
112 changes: 50 additions & 62 deletions optimum_benchmark/backends/vllm/backend.py
@@ -1,11 +1,12 @@
import asyncio
import os
from tempfile import TemporaryDirectory
from typing import Any, Dict
from typing import Any, Dict, Union

import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from safetensors.torch import save_file
from vllm import LLM, SamplingParams
from vllm import AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine, SamplingParams

from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
@@ -15,6 +16,7 @@

class VLLMBackend(Backend[VLLMConfig]):
NAME: str = "vllm"
pretrained_model: Union[LLMEngine, AsyncLLMEngine]

def __init__(self, config: VLLMConfig) -> None:
super().__init__(config)
@@ -97,34 +99,10 @@ def load_model_with_no_weights(self) -> None:
self.config.model = original_model

def load_model_from_pretrained(self) -> None:
self.pretrained_model = LLM(
model=self.config.model,
# tokenizer
tokenizer=self.config.processor,
tokenizer_mode=self.config.tokenizer_mode,
skip_tokenizer_init=self.config.skip_tokenizer_init,
# device
device=self.config.device,
# parallelism
tensor_parallel_size=self.config.tensor_parallel_size,
# precision
quantization=self.config.quantization,
dtype=self.config.dtype,
# memory
swap_space=self.config.swap_space,
gpu_memory_utilization=self.config.gpu_memory_utilization,
# cuda graphs
enforce_eager=self.config.enforce_eager,
max_context_len_to_capture=self.config.max_context_len_to_capture,
max_seq_len_to_capture=self.config.max_seq_len_to_capture,
# kernels
disable_custom_all_reduce=self.config.disable_custom_all_reduce,
# additional stuff
trust_remote_code=self.config.model_kwargs.get("trust_remote_code", False),
tokenizer_revision=self.config.processor_kwargs.get("revision", None),
revision=self.config.model_kwargs.get("revision", None),
seed=self.config.seed,
)
if self.config.serving_mode == "offline":
self.pretrained_model = LLMEngine.from_engine_args(EngineArgs(**self.config.to_engine_args()))
else:
self.pretrained_model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.config.to_engine_args()))

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.config.task in TEXT_GENERATION_TASKS:
@@ -134,11 +112,31 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

return inputs

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
return self.pretrained_model.generate(
**inputs,
use_tqdm=False,
sampling_params=SamplingParams(
def batch_offline_engine_generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
for i, prompt in enumerate(inputs["prompts"]):
self.pretrained_model.add_request(
inputs=prompt,
request_id=str(i),
params=SamplingParams(
ignore_eos=True,
detokenize=True,
seed=self.config.seed,
n=kwargs.get("num_return_sequences"),
max_tokens=kwargs.get("max_new_tokens"),
min_tokens=kwargs.get("min_new_tokens"),
use_beam_search=kwargs.get("num_beams") > 1,
logits_processors=kwargs.get("logits_processors", None),
),
)

while self.pretrained_model.has_unfinished_requests():
self.pretrained_model.step()

async def single_online_engine_generate(self, prompt: str, request_id: str, kwargs: Dict[str, Any]) -> Any:
stream = await self.pretrained_model.add_request(
inputs=prompt,
request_id=request_id,
params=SamplingParams(
ignore_eos=True,
detokenize=True,
seed=self.config.seed,
@@ -150,33 +148,23 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
),
)

async for _ in stream:
pass

async def batch_online_engine_generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
tasks = [
self.single_online_engine_generate(prompt, str(i), kwargs) for i, prompt in enumerate(inputs["prompts"])
]
await asyncio.gather(*tasks)

def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
return self.pretrained_model.generate(
**inputs,
use_tqdm=False,
sampling_params=SamplingParams(
ignore_eos=True,
detokenize=True,
seed=self.config.seed,
n=kwargs.get("num_return_sequences"),
max_tokens=kwargs.get("max_new_tokens"),
min_tokens=kwargs.get("min_new_tokens"),
use_beam_search=kwargs.get("num_beams") > 1,
logits_processors=kwargs.get("logits_processors", None),
),
)
if self.config.serving_mode == "offline":
self.batch_offline_engine_generate(inputs, kwargs)
else:
asyncio.run(self.batch_online_engine_generate(inputs, kwargs))

def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
return self.pretrained_model.generate(
**inputs,
use_tqdm=False,
sampling_params=SamplingParams(
ignore_eos=True,
detokenize=True,
n=kwargs.get("num_return_sequences"),
max_tokens=kwargs.get("max_new_tokens"),
min_tokens=kwargs.get("min_new_tokens"),
use_beam_search=kwargs.get("num_beams") > 1,
logits_processors=kwargs.get("logits_processors", None),
),
)
if self.config.serving_mode == "offline":
self.batch_offline_engine_generate(inputs, kwargs)
else:
asyncio.run(self.batch_online_engine_generate(inputs, kwargs))
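For context, a minimal standalone sketch of the two vLLM entry points the backend now drives (the model name and sampling values are assumptions, not taken from this commit). The offline path uses the synchronous LLMEngine with an add_request()/step() loop; the online path uses AsyncLLMEngine, whose add_request() returns an async stream that is drained until the request finishes:

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine, SamplingParams

params = SamplingParams(max_tokens=16, ignore_eos=True)

# offline: synchronous engine, stepped manually until all requests are done
# (in practice only one of the two engines is instantiated, as in the backend above)
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))  # assumed example model
engine.add_request(request_id="0", inputs="Hello, my name is", params=params)
while engine.has_unfinished_requests():
    engine.step()  # one scheduling/decoding iteration

# online: asynchronous engine, one coroutine per request
async def generate_one() -> None:
    async_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    stream = await async_engine.add_request(request_id="0", inputs="Hello, my name is", params=params)
    async for _ in stream:  # drain the stream; each item is a partial RequestOutput
        pass

asyncio.run(generate_one())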
66 changes: 42 additions & 24 deletions optimum_benchmark/backends/vllm/config.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from ...import_utils import vllm_version
from ..config import BackendConfig
@@ -11,36 +11,54 @@ class VLLMConfig(BackendConfig):
version: Optional[str] = vllm_version()
_target_: str = "optimum_benchmark.backends.vllm.backend.VLLMBackend"

# optimum-benchmark
# creates a model from scratch with dummy weights
no_weights: bool = False

# tokenizer
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
# decides whether to use the offline or online llm engine
serving_mode: str = "online"

# parallelism
tensor_parallel_size: int = 1
# passed to EngineArgs
engine_args: Dict[str, Any] = field(default_factory=dict)

# precision
dtype: str = "auto"
quantization: Optional[str] = None
def __post_init__(self):
# duplicates that are handled by the backend config directly
if "model" in self.engine_args:
raise ValueError("model should not be passed in `backend.engine_args`, use `backend.model` instead")

# cuda graphs
enforce_eager: bool = False
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
if "tokenizer" in self.engine_args:
raise ValueError("tokenizer should not be passed in `backend.engine_args`, use `backend.processor` instead")

# kernels
disable_custom_all_reduce: bool = False
if "device" in self.engine_args:
raise ValueError("device should not be passed in `backend.engine_args`, use `backend.device` instead")

# memory
gpu_memory_utilization: float = 0.9
swap_space: int = 4
if self.serving_mode not in ["offline", "online"]:
raise ValueError(f"Invalid serving_mode: {self.serving_mode}. Must be 'online' or 'offline'.")

# needed for task/library/model_type inference
self.model_kwargs = {
"revision": self.engine_args.get("revision", "main"),
"trust_remote_code": self.engine_args.get("trust_remote_code", False),
**self.model_kwargs,
}
self.processor_kwargs = {
"revision": self.engine_args.get("tokenizer_revision", "main"),
"trust_remote_code": self.engine_args.get("trust_remote_code", False),
**self.processor_kwargs,
}

def __post_init__(self):
super().__post_init__()

self.device = self.device.lower()
if self.engine_args.get("disable_log_stats", None) is None:
self.engine_args["disable_log_stats"] = True

if self.serving_mode == "online":
if self.engine_args.get("disable_log_requests", None) is None:
self.engine_args["disable_log_requests"] = True

if self.device not in ["cuda", "neuron", "cpu"]:
raise ValueError(f"VLLM Backend only supports 'cpu', 'cuda' and 'neuron' devices, got {self.device}")
def to_engine_args(self) -> Dict[str, Any]:
return dict(
model=self.model,
tokenizer=self.processor,
device=self.device,
**self.engine_args,
)
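For illustration, a hedged sketch of how the new fields compose (checkpoint and engine argument values are assumptions, not taken from this commit): serving_mode selects between LLMEngine and AsyncLLMEngine, and to_engine_args() merges the backend-level model/tokenizer/device with whatever the user put in engine_args:

from optimum_benchmark.backends.vllm.config import VLLMConfig

config = VLLMConfig(
    model="bigscience/bloom-560m",  # assumed example checkpoint
    device="cuda",
    serving_mode="offline",  # "offline" -> LLMEngine, "online" -> AsyncLLMEngine
    engine_args={"gpu_memory_utilization": 0.8, "dtype": "float16"},  # any other EngineArgs fields
)
# note: __post_init__ also infers task/library/model_type from the Hub for the given model

# roughly: {'model': 'bigscience/bloom-560m', 'tokenizer': 'bigscience/bloom-560m',
#           'device': 'cuda', 'disable_log_stats': True,
#           'gpu_memory_utilization': 0.8, 'dtype': 'float16'}
print(config.to_engine_args())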
5 changes: 5 additions & 0 deletions tests/configs/_serving_mode_.yaml
@@ -0,0 +1,5 @@
hydra:
mode: MULTIRUN
sweeper:
params:
backend.serving_mode: online,offline
1 change: 1 addition & 0 deletions tests/configs/cuda_inference_vllm_bloom.yaml
@@ -3,6 +3,7 @@ defaults:
- _base_ # inherits from base config
- _cuda_ # inherits from cuda config
- _inference_ # inherits from inference config
- _serving_mode_ # inherits from serving_mode config
- _bloom_ # inherits from bloom config
- _self_ # hydra 1.1 compatibility
- override backend: vllm
11 changes: 11 additions & 0 deletions tests/configs/cuda_inference_vllm_bloom_tp.yaml
@@ -0,0 +1,11 @@
defaults:
# order of inheritance, last one overrides previous ones
- _base_ # inherits from base config
- _cuda_ # inherits from cuda config
- _inference_ # inherits from inference config
- _serving_mode_ # inherits from serving_mode config
- _bloom_ # inherits from bloom config
- _self_ # hydra 1.1 compatibility
- override backend: vllm

name: cuda_inference_vllm_bloom
