added vllm backend
IlyasMoutawwakil committed May 10, 2024
1 parent d35829e commit ef4a897
Showing 12 changed files with 244 additions and 18 deletions.
2 changes: 2 additions & 0 deletions optimum_benchmark/__init__.py
@@ -8,6 +8,7 @@
PyTXIConfig,
TorchORTConfig,
TRTLLMConfig,
VLLMConfig,
)
from .base import Benchmark
from .config import BenchmarkConfig
@@ -36,6 +37,7 @@
"TrainingConfig",
"TorchORTConfig",
"TRTLLMConfig",
"VLLMConfig",
"TorchrunConfig",
"ExperimentConfig",
"launch",
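With this change the vLLM backend configuration is re-exported at the package root alongside the other backend configs; a minimal import sketch (not part of the diff):

from optimum_benchmark import VLLMConfig  # now importable from the package root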
2 changes: 2 additions & 0 deletions optimum_benchmark/backends/__init__.py
@@ -7,6 +7,7 @@
from .pytorch.config import PyTorchConfig
from .tensorrt_llm.config import TRTLLMConfig
from .torch_ort.config import TorchORTConfig
from .vllm.config import VLLMConfig

__all__ = [
"PyTorchConfig",
@@ -18,4 +19,5 @@
"PyTXIConfig",
"LLMSwarmConfig",
"BackendConfig",
"VLLMConfig",
]
6 changes: 3 additions & 3 deletions optimum_benchmark/backends/base.py
@@ -58,11 +58,11 @@ def __init__(self, config: BackendConfigT):
self.generation_config = None

else:
self.generation_config = get_transformers_generation_config(self.config.model, **self.config.model_kwargs)
self.pretrained_config = get_transformers_pretrained_config(self.config.model, **self.config.model_kwargs)
self.pretrained_processor = get_transformers_pretrained_processor(
self.config.model, **self.config.hub_kwargs
self.config.processor, **self.config.processor_kwargs
)
self.generation_config = get_transformers_generation_config(self.config.model, **self.config.hub_kwargs)
self.pretrained_config = get_transformers_pretrained_config(self.config.model, **self.config.hub_kwargs)
self.model_shapes = extract_transformers_shapes_from_artifacts(
self.pretrained_config, self.pretrained_processor
)
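For context, the processor is now resolved from config.processor with processor_kwargs, while the pretrained config and generation config use config.model with model_kwargs. A rough transformers-level sketch of the resulting behaviour (the real code goes through the repo's get_transformers_* helpers; the model id and kwargs are illustrative):

# Rough equivalent of what Backend.__init__ now does (sketch only)
from transformers import AutoConfig, AutoTokenizer, GenerationConfig

pretrained_processor = AutoTokenizer.from_pretrained("gpt2", revision="main")  # config.processor + processor_kwargs
pretrained_config = AutoConfig.from_pretrained("gpt2", revision="main")        # config.model + model_kwargs
try:
    generation_config = GenerationConfig.from_pretrained("gpt2", revision="main")
except OSError:
    # not every repo ships a generation_config.json
    generation_config = GenerationConfig()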
31 changes: 20 additions & 11 deletions optimum_benchmark/backends/config.py
@@ -11,14 +11,6 @@

LOGGER = getLogger("backend")

# backends share the same hub kwargs
HUB_KWARGS = {
"revision": "main",
"force_download": False,
"local_files_only": False,
"trust_remote_code": False,
}


@dataclass
class BackendConfig(ABC):
@@ -27,9 +19,11 @@ class BackendConfig(ABC):
_target_: str

task: Optional[str] = None
model: Optional[str] = None
library: Optional[str] = None

model: Optional[str] = None
processor: Optional[str] = None

device: Optional[str] = None
device_ids: Optional[str] = None
# yes we use a string here instead of a list
@@ -40,12 +34,29 @@ class BackendConfig(ABC):
inter_op_num_threads: Optional[int] = None
intra_op_num_threads: Optional[int] = None

# model kwargs that are added to its init method/constructor
model_kwargs: Dict[str, Any] = field(default_factory=dict)
# processor kwargs that are added to its init method/constructor
processor_kwargs: Dict[str, Any] = field(default_factory=dict)

# deprecated
hub_kwargs: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
if self.model is None:
raise ValueError("`model` must be specified.")

if self.processor is None:
self.processor = self.model

if self.hub_kwargs:
LOGGER.warning(
"`hub_kwargs` is deprecated and will be removed in future versions."
"Please use `model_kwargs` and `processor_kwargs` seperately."
)
self.model_kwargs = {**self.model_kwargs, **self.hub_kwargs}
self.processor_kwargs = {**self.processor_kwargs, **self.hub_kwargs}

if self.task is None:
self.task = infer_task_from_model_name_or_path(self.model, self.hub_kwargs.get("revision", None))

@@ -90,7 +101,5 @@ def __post_init__(self):
if self.intra_op_num_threads == -1:
self.intra_op_num_threads = cpu_count()

self.hub_kwargs = {**HUB_KWARGS, **self.hub_kwargs}


BackendConfigT = TypeVar("BackendConfigT", bound=BackendConfig)
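A short sketch of the deprecation path for existing configs (illustrative values; assumes the concrete backend config is otherwise valid on the host):

from optimum_benchmark.backends import PyTorchConfig

# Deprecated hub_kwargs are merged into both new kwargs dicts (with a warning),
# and the processor defaults to the model id when not set explicitly.
cfg = PyTorchConfig(
    model="gpt2",                      # illustrative
    task="text-generation",            # set explicitly to skip hub-side task inference
    library="transformers",            # set explicitly to avoid extra hub lookups (assumption)
    device="cpu",
    hub_kwargs={"revision": "main"},   # deprecated
)
assert cfg.processor == "gpt2"
assert cfg.model_kwargs["revision"] == "main"
assert cfg.processor_kwargs["revision"] == "main"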
6 changes: 4 additions & 2 deletions optimum_benchmark/backends/py_txi/backend.py
@@ -23,8 +23,6 @@ class PyTXIBackend(Backend[PyTXIConfig]):
def __init__(self, config: PyTXIConfig) -> None:
super().__init__(config)

self.volume = list(self.config.volumes.keys())[0]

LOGGER.info("\t+ Creating backend temporary directory")
self.tmpdir = TemporaryDirectory()

@@ -44,6 +42,10 @@ def __init__(self, config: PyTXIConfig) -> None:

self.tmpdir.cleanup()

@property
def volume(self) -> str:
return list(self.config.volumes.keys())[0]

def download_pretrained_model(self) -> None:
# directly downloads pretrained model in volume (/data) to change generation config before loading model
with torch.device("meta"):
4 changes: 3 additions & 1 deletion optimum_benchmark/backends/py_txi/config.py
@@ -3,6 +3,8 @@
from logging import getLogger
from typing import Any, Dict, List, Optional, Union

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

from ...import_utils import py_txi_version
from ...system_utils import is_nvidia_system, is_rocm_system
from ...task_utils import TEXT_EMBEDDING_TASKS, TEXT_GENERATION_TASKS
@@ -34,7 +36,7 @@ class PyTXIConfig(BackendConfig):
metadata={"help": "Dictionary of ports to expose from the container."},
)
volumes: Dict[str, Any] = field(
default_factory=lambda: {os.path.expanduser("~/.cache/huggingface/hub"): {"bind": "/data", "mode": "rw"}},
default_factory=lambda: {HUGGINGFACE_HUB_CACHE: {"bind": "/data", "mode": "rw"}},
metadata={"help": "Dictionary of volumes to mount inside the container."},
)
environment: List[str] = field(
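The hardcoded ~/.cache/huggingface/hub path is replaced by the hub constant, which follows the user's actual cache location (including HF_HOME / HF_HUB_CACHE overrides); a tiny sketch of the resulting default volume mapping:

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# Mount the local hub cache read-write as /data inside the container,
# wherever the environment actually places that cache.
volumes = {HUGGINGFACE_HUB_CACHE: {"bind": "/data", "mode": "rw"}}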
Empty file.
151 changes: 151 additions & 0 deletions optimum_benchmark/backends/vllm/backend.py
@@ -0,0 +1,151 @@
import gc
import os
from logging import getLogger
from tempfile import TemporaryDirectory
from typing import Any, Dict

import torch
from safetensors.torch import save_file
from vllm import LLM, SamplingParams

from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from ..transformers_utils import random_init_weights
from .config import VLLMConfig

LOGGER = getLogger("vllm")


class VLLMBackend(Backend[VLLMConfig]):
NAME: str = "vllm"

def __init__(self, config: VLLMConfig) -> None:
super().__init__(config)
self.validate_task()

LOGGER.info("\t+ Creating backend temporary directory")
self.tmpdir = TemporaryDirectory()

if self.config.no_weights:
LOGGER.info("\t+ Loading no weights model")
self.load_model_with_no_weights()
else:
LOGGER.info("\t+ Loading pretrained model")
self.load_model_from_pretrained()

self.tmpdir.cleanup()

def create_no_weights_model(self) -> None:
self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
LOGGER.info("\t+ Creating no weights model directory")
os.makedirs(self.no_weights_model, exist_ok=True)
LOGGER.info("\t+ Creating no weights model state dict")
state_dict = torch.nn.Linear(1, 1).state_dict()
LOGGER.info("\t+ Saving no weights model safetensors")
safetensor = os.path.join(self.no_weights_model, "model.safetensors")
save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"})
LOGGER.info("\t+ Saving no weights model pretrained config")
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
LOGGER.info("\t+ Saving no weights model pretrained processor")
self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
# unlike Transformers, vLLM won't accept any missing tensors so we need to materialize the model
LOGGER.info(f"\t+ Loading no weights model from {self.no_weights_model}")
with random_init_weights():
self.pretrained_model = self.automodel_class.from_pretrained(
self.no_weights_model, **self.config.hub_kwargs, device_map="auto", _fast_init=False
)
LOGGER.info("\t+ Saving full no weights model pretrained model")
self.pretrained_model.save_pretrained(save_directory=self.no_weights_model)
del self.pretrained_model
torch.cuda.empty_cache()
gc.collect()

def load_model_with_no_weights(self) -> None:
self.create_no_weights_model()
original_model, self.config.model = self.config.model, self.no_weights_model
LOGGER.info("\t+ Loading no weights model")
self.load_model_from_pretrained()
self.config.model = original_model

def load_model_from_pretrained(self) -> None:
self.pretrained_model = LLM(
model=self.config.model,
# tokenizer
tokenizer=self.config.processor,
tokenizer_mode=self.config.tokenizer_mode,
skip_tokenizer_init=self.config.skip_tokenizer_init,
# device
device=self.config.device,
# parallelism
tensor_parallel_size=self.config.tensor_parallel_size,
# precision
quantization=self.config.quantization,
dtype=self.config.dtype,
# memory
swap_space=self.config.swap_space,
gpu_memory_utilization=self.config.gpu_memory_utilization,
# cuda graphs
enforce_eager=self.config.enforce_eager,
max_context_len_to_capture=self.config.max_context_len_to_capture,
max_seq_len_to_capture=self.config.max_seq_len_to_capture,
# kernels
disable_custom_all_reduce=self.config.disable_custom_all_reduce,
# additional stuff
trust_remote_code=self.config.model_kwargs.get("trust_remote_code", False),
tokenizer_revision=self.config.processor_kwargs.get("revision", None),
revision=self.config.model_kwargs.get("revision", None),
seed=self.config.seed,
)

def validate_task(self) -> None:
if self.config.task not in ["text-generation"]:
raise ValueError(f"Task {self.config.task} not supported by {self.NAME}")

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.config.task in TEXT_GENERATION_TASKS:
inputs = self.pretrained_processor.batch_decode(inputs["input_ids"])
return {"prompts": inputs}
else:
raise NotImplementedError(f"VLLM does not support task {self.config.task}")

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
return self.pretrained_model.generate(
**inputs,
use_tqdm=False,
sampling_params=SamplingParams(
ignore_eos=True,
n=kwargs.get("num_return_sequences"),
max_tokens=kwargs.get("max_new_tokens"),
min_tokens=kwargs.get("min_new_tokens"),
use_beam_search=kwargs.get("num_beams") > 1,
logits_processors=kwargs.get("logits_processors", None),
),
)

def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
return self.pretrained_model.generate(
**inputs,
use_tqdm=False,
sampling_params=SamplingParams(
ignore_eos=True,
n=kwargs.get("num_return_sequences"),
max_tokens=kwargs.get("max_new_tokens"),
min_tokens=kwargs.get("min_new_tokens"),
use_beam_search=kwargs.get("num_beams") > 1,
logits_processors=kwargs.get("logits_processors", None),
),
)

def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
return self.pretrained_model.generate(
**inputs,
use_tqdm=False,
sampling_params=SamplingParams(
ignore_eos=True,
n=kwargs.get("num_return_sequences"),
max_tokens=kwargs.get("max_new_tokens"),
min_tokens=kwargs.get("min_new_tokens"),
use_beam_search=kwargs.get("num_beams") > 1,
logits_processors=kwargs.get("logits_processors", None),
),
)
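For reference, a standalone sketch of the vLLM calls that forward, prefill and generate all wrap (the three methods share the same body, presumably so the benchmarking scenarios can time them separately). The model id, device and token counts are illustrative, and the API shown assumes a vLLM release contemporary with this commit:

from vllm import LLM, SamplingParams

llm = LLM(model="gpt2", device="cuda", dtype="auto", gpu_memory_utilization=0.9)

outputs = llm.generate(
    # the backend builds prompts by batch-decoding the benchmark's input_ids
    prompts=["Hello, my name is"],
    use_tqdm=False,
    sampling_params=SamplingParams(
        ignore_eos=True,   # decode a fixed number of tokens regardless of EOS
        n=1,               # num_return_sequences
        max_tokens=32,     # max_new_tokens
        min_tokens=32,     # min_new_tokens
    ),
)
print(outputs[0].outputs[0].text)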
46 changes: 46 additions & 0 deletions optimum_benchmark/backends/vllm/config.py
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Optional

from ...import_utils import vllm_version
from ..config import BackendConfig


@dataclass
class VLLMConfig(BackendConfig):
name: str = "vllm"
version: Optional[str] = vllm_version()
_target_: str = "optimum_benchmark.backends.vllm.backend.VLLMBackend"

# optimum-benchmark
no_weights: bool = False

# tokenizer
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False

# parallelism
tensor_parallel_size: int = 1

# precision
dtype: str = "auto"
quantization: Optional[str] = None

# cuda graphs
enforce_eager: bool = False
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192

# kernels
disable_custom_all_reduce: bool = False

# memory
gpu_memory_utilization: float = 0.9
swap_space: int = 4

def __post_init__(self):
super().__post_init__()

self.device = self.device.lower()

if self.device not in ["cuda", "neuron", "cpu"]:
raise ValueError(f"VLLM Backend only supports 'cpu', 'cuda' and 'neuron' devices, got {self.device}")
3 changes: 2 additions & 1 deletion optimum_benchmark/cli.py
Expand Up @@ -25,6 +25,7 @@
TorchrunConfig,
TrainingConfig,
TRTLLMConfig,
VLLMConfig,
launch,
)
from .logging_utils import setup_logging
@@ -75,6 +76,7 @@ def __post_init__(self):
cs.store(group="backend", name=INCConfig.name, node=INCConfig)
cs.store(group="backend", name=PyTXIConfig.name, node=PyTXIConfig)
cs.store(group="backend", name=LLMSwarmConfig.name, node=LLMSwarmConfig)
cs.store(group="backend", name=VLLMConfig.name, node=VLLMConfig)
# scenarios configurations
cs.store(group="scenario", name=TrainingConfig.name, node=TrainingConfig)
cs.store(group="scenario", name=InferenceConfig.name, node=InferenceConfig)
@@ -83,7 +85,6 @@
cs.store(group="launcher", name=InlineConfig.name, node=InlineConfig)
cs.store(group="launcher", name=ProcessConfig.name, node=ProcessConfig)
cs.store(group="launcher", name=TorchrunConfig.name, node=TorchrunConfig)

# deprecated
cs.store(name="experiment", node=ExperimentConfig)
cs.store(group="benchmark", name=TrainingConfig.name, node=DeprecatedTrainingConfig)
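With the new ConfigStore entry, vllm becomes a valid choice for the backend group in Hydra configs; a hedged sketch of what the registration enables (the YAML shown in comments is illustrative, not part of the commit):

from hydra.core.config_store import ConfigStore
from optimum_benchmark import VLLMConfig

cs = ConfigStore.instance()
cs.store(group="backend", name=VLLMConfig.name, node=VLLMConfig)  # name resolves to "vllm"

# A benchmark config (or a CLI override) can then select it, e.g.:
#   defaults:
#     - backend: vllm
#   backend:
#     model: gpt2     # illustrative
#     device: cuda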
10 changes: 10 additions & 0 deletions optimum_benchmark/import_utils.py
@@ -31,6 +31,11 @@
_pyrsmi_available = importlib.util.find_spec("pyrsmi") is not None
_llm_swarm_available = importlib.util.find_spec("llm_swarm") is not None
_zentorch_available = importlib.util.find_spec("zentorch") is not None
_vllm_available = importlib.util.find_spec("vllm") is not None


def is_vllm_available():
return _vllm_available


def is_zentorch_available():
@@ -213,6 +218,11 @@ def llm_swarm_version():
return importlib.metadata.version("llm_swarm")


def vllm_version():
if _vllm_available:
return importlib.metadata.version("vllm")


def get_git_revision_hash(package_name: str) -> Optional[str]:
"""
Returns the git commit SHA of a package installed from a git repository.
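A small usage sketch for the new availability helpers (not part of the diff):

from optimum_benchmark.import_utils import is_vllm_available, vllm_version

# Guard vLLM-specific code paths on the optional dependency and report its version.
if is_vllm_available():
    print(f"vllm=={vllm_version()}")
else:
    print("vllm is not installed")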