Commit ef4a897 (1 parent: d35829e)
Showing 12 changed files with 244 additions and 18 deletions.
@@ -0,0 +1,151 @@
import gc
import os
from logging import getLogger
from tempfile import TemporaryDirectory
from typing import Any, Dict

import torch
from safetensors.torch import save_file
from vllm import LLM, SamplingParams

from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from ..transformers_utils import random_init_weights
from .config import VLLMConfig

LOGGER = getLogger("vllm")


class VLLMBackend(Backend[VLLMConfig]):
    NAME: str = "vllm"

    def __init__(self, config: VLLMConfig) -> None:
        super().__init__(config)
        self.validate_task()

        LOGGER.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()

        if self.config.no_weights:
            LOGGER.info("\t+ Loading no weights model")
            self.load_model_with_no_weights()
        else:
            LOGGER.info("\t+ Loading pretrained model")
            self.load_model_from_pretrained()

        self.tmpdir.cleanup()

    def create_no_weights_model(self) -> None:
        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
        LOGGER.info("\t+ Creating no weights model directory")
        os.makedirs(self.no_weights_model, exist_ok=True)
        LOGGER.info("\t+ Creating no weights model state dict")
        state_dict = torch.nn.Linear(1, 1).state_dict()
        LOGGER.info("\t+ Saving no weights model safetensors")
        safetensor = os.path.join(self.no_weights_model, "model.safetensors")
        save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"})
        LOGGER.info("\t+ Saving no weights model pretrained config")
        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
        LOGGER.info("\t+ Saving no weights model pretrained processor")
        self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
        # unlike Transformers, vLLM won't accept any missing tensors so we need to materialize the model
        LOGGER.info(f"\t+ Loading no weights model from {self.no_weights_model}")
        with random_init_weights():
            self.pretrained_model = self.automodel_class.from_pretrained(
                self.no_weights_model, **self.config.hub_kwargs, device_map="auto", _fast_init=False
            )
        LOGGER.info("\t+ Saving full no weights model pretrained model")
        self.pretrained_model.save_pretrained(save_directory=self.no_weights_model)
        del self.pretrained_model
        torch.cuda.empty_cache()
        gc.collect()

    def load_model_with_no_weights(self) -> None:
        self.create_no_weights_model()
        original_model, self.config.model = self.config.model, self.no_weights_model
        LOGGER.info("\t+ Loading no weights model")
        self.load_model_from_pretrained()
        self.config.model = original_model

    def load_model_from_pretrained(self) -> None:
        self.pretrained_model = LLM(
            model=self.config.model,
            # tokenizer
            tokenizer=self.config.processor,
            tokenizer_mode=self.config.tokenizer_mode,
            skip_tokenizer_init=self.config.skip_tokenizer_init,
            # device
            device=self.config.device,
            # parallelism
            tensor_parallel_size=self.config.tensor_parallel_size,
            # precision
            quantization=self.config.quantization,
            dtype=self.config.dtype,
            # memory
            swap_space=self.config.swap_space,
            gpu_memory_utilization=self.config.gpu_memory_utilization,
            # cuda graphs
            enforce_eager=self.config.enforce_eager,
            max_context_len_to_capture=self.config.max_context_len_to_capture,
            max_seq_len_to_capture=self.config.max_seq_len_to_capture,
            # kernels
            disable_custom_all_reduce=self.config.disable_custom_all_reduce,
            # additional stuff
            trust_remote_code=self.config.model_kwargs.get("trust_remote_code", False),
            tokenizer_revision=self.config.processor_kwargs.get("revision", None),
            revision=self.config.model_kwargs.get("revision", None),
            seed=self.config.seed,
        )

    def validate_task(self) -> None:
        if self.config.task not in ["text-generation"]:
            raise ValueError(f"Task {self.config.task} not supported by {self.NAME}")

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if self.config.task in TEXT_GENERATION_TASKS:
            inputs = self.pretrained_processor.batch_decode(inputs["input_ids"])
            return {"prompts": inputs}
        else:
            raise NotImplementedError(f"VLLM does not support task {self.config.task}")

    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
        return self.pretrained_model.generate(
            **inputs,
            use_tqdm=False,
            sampling_params=SamplingParams(
                ignore_eos=True,
                n=kwargs.get("num_return_sequences"),
                max_tokens=kwargs.get("max_new_tokens"),
                min_tokens=kwargs.get("min_new_tokens"),
                use_beam_search=kwargs.get("num_beams") > 1,
                logits_processors=kwargs.get("logits_processors", None),
            ),
        )

    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        return self.pretrained_model.generate(
            **inputs,
            use_tqdm=False,
            sampling_params=SamplingParams(
                ignore_eos=True,
                n=kwargs.get("num_return_sequences"),
                max_tokens=kwargs.get("max_new_tokens"),
                min_tokens=kwargs.get("min_new_tokens"),
                use_beam_search=kwargs.get("num_beams") > 1,
                logits_processors=kwargs.get("logits_processors", None),
            ),
        )

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
        return self.pretrained_model.generate(
            **inputs,
            use_tqdm=False,
            sampling_params=SamplingParams(
                ignore_eos=True,
                n=kwargs.get("num_return_sequences"),
                max_tokens=kwargs.get("max_new_tokens"),
                min_tokens=kwargs.get("min_new_tokens"),
                use_beam_search=kwargs.get("num_beams") > 1,
                logits_processors=kwargs.get("logits_processors", None),
            ),
        )
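
For readers unfamiliar with how optimum-benchmark drives a backend, here is a minimal usage sketch of the class above. It is not part of the commit: the inherited BackendConfig fields (task, model, device, processor) and the prepare_inputs/generate call order are assumed from the rest of the library, and the model id is only a placeholder.

# Illustrative sketch only -- inherited BackendConfig fields and the exact
# benchmark call order are assumptions, not shown in this commit.
from optimum_benchmark.backends.vllm.backend import VLLMBackend
from optimum_benchmark.backends.vllm.config import VLLMConfig

config = VLLMConfig(
    task="text-generation",
    model="gpt2",          # placeholder model id
    device="cuda",
    no_weights=True,       # benchmark with randomly initialized weights
    dtype="float16",
)
backend = VLLMBackend(config)

# The benchmark hands the backend tokenized inputs; prepare_inputs decodes them
# back to text because vLLM's LLM.generate expects raw prompts.
inputs = backend.prepare_inputs({"input_ids": [[101, 2023, 2003, 1037, 3231, 102]]})
outputs = backend.generate(
    inputs,
    {"num_return_sequences": 1, "max_new_tokens": 32, "min_new_tokens": 32, "num_beams": 1},
)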
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Optional

from ...import_utils import vllm_version
from ..config import BackendConfig


@dataclass
class VLLMConfig(BackendConfig):
    name: str = "vllm"
    version: Optional[str] = vllm_version()
    _target_: str = "optimum_benchmark.backends.vllm.backend.VLLMBackend"

    # optimum-benchmark
    no_weights: bool = False

    # tokenizer
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False

    # parallelism
    tensor_parallel_size: int = 1

    # precision
    dtype: str = "auto"
    quantization: Optional[str] = None

    # cuda graphs
    enforce_eager: bool = False
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192

    # kernels
    disable_custom_all_reduce: bool = False

    # memory
    gpu_memory_utilization: float = 0.9
    swap_space: int = 4

    def __post_init__(self):
        super().__post_init__()

        self.device = self.device.lower()

        if self.device not in ["cuda", "neuron", "cpu"]:
            raise ValueError(f"VLLM Backend only supports 'cpu', 'cuda' and 'neuron' devices, got {self.device}")
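
As a quick illustration of the __post_init__ validation above, the following sketch (again assuming the inherited BackendConfig fields such as task and model, which are not part of this commit) shows the device string being normalized and an unsupported device being rejected.

# Hypothetical check of the device validation; inherited fields are assumptions.
from optimum_benchmark.backends.vllm.config import VLLMConfig

cfg = VLLMConfig(task="text-generation", model="gpt2", device="CUDA")
print(cfg.device)  # "cuda" -- the device string is lower-cased in __post_init__

try:
    VLLMConfig(task="text-generation", model="gpt2", device="mps")
except ValueError as err:
    print(err)  # VLLM Backend only supports 'cpu', 'cuda' and 'neuron' devices, got mps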