Skip to content

Commit

Permalink
Refactor backends and add load tracking (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil authored Jul 15, 2024
1 parent 7999050 commit e291e9b
Show file tree
Hide file tree
Showing 30 changed files with 940 additions and 744 deletions.
63 changes: 63 additions & 0 deletions examples/pytorch_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os

from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging

BENCHMARK_NAME = "pytorch-llama"

WEIGHTS_CONFIGS = {
"float16": {
"torch_dtype": "float16",
"quantization_scheme": None,
"quantization_config": {},
},
# "4bit-awq-gemm": {
# "torch_dtype": "float16",
# "quantization_scheme": "awq",
# "quantization_config": {"bits": 4, "version": "gemm"},
# },
# "4bit-gptq-exllama-v2": {
# "torch_dtype": "float16",
# "quantization_scheme": "gptq",
# "quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
# },
}


def run_benchmark(weight_config: str):
launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn")
backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
no_weights=True,
model="gpt2",
**WEIGHTS_CONFIGS[weight_config],
)
scenario_config = InferenceConfig(
memory=True,
latency=True,
duration=10,
iterations=10,
warmup_runs=10,
input_shapes={"batch_size": 1, "sequence_length": 128},
generate_kwargs={"max_new_tokens": 32, "min_new_tokens": 32},
)

benchmark_config = BenchmarkConfig(
name=BENCHMARK_NAME, launcher=launcher_config, scenario=scenario_config, backend=backend_config
)
benchmark_report = Benchmark.launch(benchmark_config)
benchmark = Benchmark(config=benchmark_config, report=benchmark_report)

filename = f"{BENCHMARK_NAME}-{backend_config.version}-{weight_config}.json"
benchmark.push_to_hub(repo_id="optimum-benchmark/pytorch-llama", filename=filename)
benchmark.save_json(path=f"benchmarks/{filename}")


if __name__ == "__main__":
level = os.environ.get("LOG_LEVEL", "INFO")
to_file = os.environ.get("LOG_TO_FILE", "0") == "1"
setup_logging(level=level, to_file=to_file, prefix="MAIN-PROCESS")

for weight_config in WEIGHTS_CONFIGS:
run_benchmark(weight_config)
23 changes: 15 additions & 8 deletions examples/pytorch_llama_awq.yaml → examples/pytorch_llama.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,31 @@ defaults:
- scenario: inference
- launcher: process
- backend: pytorch
- _base_
- _self_

experiment_name: pytorch_llama_awq
name: pytorch_llama

launcher:
device_isolation: true
device_isolation_action: warn

backend:
model: gpt2
device: cuda
device_ids: 0
no_weights: true
model: TheBloke/Llama-2-70B-AWQ
torch_dtype: float16

scenario:
memory: true
latency: true

warmup_runs: 10
iterations: 10
duration: 10

benchmark:
input_shapes:
batch_size: 1
sequence_length: 128
sequence_length: 256
generate_kwargs:
max_new_tokens: 100
min_new_tokens: 100
max_new_tokens: 32
min_new_tokens: 32
28 changes: 0 additions & 28 deletions examples/pytorch_llama_awq.py

This file was deleted.

65 changes: 43 additions & 22 deletions optimum_benchmark/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
import os
from abc import ABC
from collections import OrderedDict
from logging import getLogger
from typing import Any, ClassVar, Dict, Generic, Optional, Tuple
from typing import Any, ClassVar, Dict, Generic, Optional

import datasets.utils.logging as datasets_logging
import transformers.utils.logging as transformers_logging
from safetensors.torch import save_file
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel, TrainerState, set_seed

from ..task_utils import get_automodel_class_for_task
from ..import_utils import is_torch_available
from .config import BackendConfigT
from .diffusers_utils import extract_diffusers_shapes_from_model, get_diffusers_pretrained_config
from .timm_utils import extract_timm_shapes_from_config, get_timm_pretrained_config
from .diffusers_utils import (
extract_diffusers_shapes_from_model,
get_diffusers_automodel_loader_for_task,
get_diffusers_pretrained_config,
)
from .timm_utils import extract_timm_shapes_from_config, get_timm_automodel_loader, get_timm_pretrained_config
from .transformers_utils import (
PretrainedProcessor,
extract_transformers_shapes_from_artifacts,
get_transformers_automodel_loader_for_task,
get_transformers_generation_config,
get_transformers_pretrained_config,
get_transformers_pretrained_processor,
)

if is_torch_available():
import torch

datasets_logging.set_verbosity_error()
transformers_logging.set_verbosity_error()

Expand Down Expand Up @@ -47,15 +57,15 @@ def __init__(self, config: BackendConfigT):
self.logger.info("\t+ Benchmarking a Diffusers pipeline")
self.pretrained_config = get_diffusers_pretrained_config(self.config.model, **self.config.model_kwargs)
self.model_shapes = extract_diffusers_shapes_from_model(self.config.model, **self.config.model_kwargs)
self.model_type = self.config.task
self.automodel_loader = get_diffusers_automodel_loader_for_task(self.config.task)
self.pretrained_processor = None
self.generation_config = None

elif self.config.library == "timm":
self.logger.info("\t+ Benchmarking a Timm model")
self.pretrained_config = get_timm_pretrained_config(self.config.model)
self.model_shapes = extract_timm_shapes_from_config(self.pretrained_config)
self.model_type = self.pretrained_config.architecture
self.automodel_loader = get_timm_automodel_loader()
self.pretrained_processor = None
self.generation_config = None

Expand All @@ -69,31 +79,42 @@ def __init__(self, config: BackendConfigT):
self.model_shapes = extract_transformers_shapes_from_artifacts(
self.pretrained_config, self.pretrained_processor
)
self.model_type = self.pretrained_config.model_type

self.automodel_class = get_automodel_class_for_task(
model_type=self.model_type, library=self.config.library, task=self.config.task, framework="pt"
)
self.logger.info(f"\t+ Using automodel class {self.automodel_class.__name__}")
self.automodel_loader = get_transformers_automodel_loader_for_task(self.config.task)

def seed(self) -> None:
set_seed(self.config.seed)

def prepare_for_inference(self, **kwargs) -> None:
def create_no_weights_model(self) -> None:
if self.pretrained_config is None:
raise ValueError("Can't create no weights model without a pretrained config")

self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
self.logger.info("\t+ Creating no weights model's directory")
os.makedirs(self.no_weights_model, exist_ok=True)
self.logger.info("\t+ Creating no weights model's state dict")
state_dict = torch.nn.Linear(1, 1).state_dict()
self.logger.info("\t+ Saving no weights model's safetensors")
safetensors = os.path.join(self.no_weights_model, "model.safetensors")
save_file(tensors=state_dict, filename=safetensors, metadata={"format": "pt"})
self.logger.info("\t+ Saving no weights model's config")
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
"""
This method is used to prepare the model for inference.
It can be used to compile the model with certain input/output shapes, for example.
This method is used to prepare and register the input shapes before using them by the model.
It can be used to pad the inputs to the correct shape, or compile it to the correct format.
"""
pass
return input_shapes

def prepare_inputs(
self, inputs: Dict[str, Any], input_shapes: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
This method is used to prepare the inputs before passing them to the model.
It can be used to move the inputs to the correct device, for example.
This method is used to prepare and register the inputs before passing them to the model.
It can be used to move the inputs to the correct device, or rename their keys.
"""
return inputs, input_shapes
return inputs

def load(self) -> None:
raise NotImplementedError("Backend must implement load method")

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
"""
Expand Down
12 changes: 11 additions & 1 deletion optimum_benchmark/backends/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from psutil import cpu_count

from ..system_utils import get_gpu_device_ids, is_nvidia_system, is_rocm_system
from ..task_utils import infer_library_from_model_name_or_path, infer_task_from_model_name_or_path
from ..task_utils import (
infer_library_from_model_name_or_path,
infer_model_type_from_model_name_or_path,
infer_task_from_model_name_or_path,
)

LOGGER = getLogger("backend")

Expand All @@ -20,6 +24,7 @@ class BackendConfig(ABC):

task: Optional[str] = None
library: Optional[str] = None
model_type: Optional[str] = None

model: Optional[str] = None
processor: Optional[str] = None
Expand Down Expand Up @@ -63,6 +68,11 @@ def __post_init__(self):
if self.library is None:
self.library = infer_library_from_model_name_or_path(self.model, self.hub_kwargs.get("revision", None))

if self.model_type is None:
self.model_type = infer_model_type_from_model_name_or_path(
self.model, self.hub_kwargs.get("revision", None)
)

if self.device is None:
self.device = "cuda" if is_nvidia_system() or is_rocm_system() else "cpu"

Expand Down
39 changes: 37 additions & 2 deletions optimum_benchmark/backends/diffusers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,40 @@
from ..import_utils import is_diffusers_available

if is_diffusers_available():
import diffusers # type: ignore
import diffusers
from diffusers import DiffusionPipeline

if hasattr(diffusers, "pipelines") and hasattr(diffusers.pipelines, "auto_pipeline"):
from diffusers.pipelines.auto_pipeline import (
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
)

TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {
"inpainting": AUTO_INPAINT_PIPELINES_MAPPING.copy(),
"text-to-image": AUTO_TEXT2IMAGE_PIPELINES_MAPPING.copy(),
"image-to-image": AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.copy(),
}

for task_name, model_mapping in TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
for model_type, model_class in model_mapping.items():
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name][model_type] = model_class.__name__
else:
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
else:
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}


TASKS_TO_MODEL_LOADERS = {
"inpainting": "AutoPipelineForInpainting",
"text-to-image": "AutoPipelineForText2Image",
"image-to-image": "AutoPipelineForImage2Image",
}


def get_diffusers_pretrained_config(model: str, **kwargs) -> Dict[str, int]:
return diffusers.DiffusionPipeline.load_config(model, **kwargs)
return DiffusionPipeline.load_config(model, **kwargs)


def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]:
Expand Down Expand Up @@ -38,3 +67,9 @@ def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]:
shapes["width"] = -1

return shapes


def get_diffusers_automodel_loader_for_task(task: str):
model_loader_name = TASKS_TO_MODEL_LOADERS[task]
model_loader_class = getattr(diffusers, model_loader_name)
return model_loader_class
21 changes: 8 additions & 13 deletions optimum_benchmark/backends/llm_swarm/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List

import torch
from huggingface_hub import AsyncInferenceClient
Expand All @@ -16,19 +16,18 @@ class LLMSwarmBackend(Backend[LLMSwarmConfig]):

def __init__(self, config: LLMSwarmConfig) -> None:
super().__init__(config)
self.validate_task()

if self.config.task not in TEXT_GENERATION_TASKS:
raise NotImplementedError(f"LLM Swarm does not support task {self.config.task}")

def load(self) -> None:
self.logger.info("\t+ Downloading pretrained model")
self.download_pretrained_model()
self.logger.info("\t+ Preparing generation config")
self.prepare_generation_config()
self.logger.info("\t+ Loading pretrained model")
self.load_model_from_pretrained()

def validate_task(self) -> None:
if self.config.task not in TEXT_GENERATION_TASKS:
raise NotImplementedError(f"LLM Swarm does not support task {self.config.task}")

def load_model_from_pretrained(self) -> None:
self.llm_swarm_config = LLMSwarmCfg(
gpus=self.config.gpus,
Expand All @@ -46,7 +45,7 @@ def load_model_from_pretrained(self) -> None:

def download_pretrained_model(self) -> None:
with torch.device("meta"):
self.automodel_class.from_pretrained(self.config.model, **self.config.model_kwargs)
self.auto_model_loader.from_pretrained(self.config.model, **self.config.model_kwargs)

def prepare_generation_config(self) -> None:
self.generation_config.eos_token_id = -100
Expand All @@ -60,19 +59,15 @@ def prepare_generation_config(self) -> None:
self.logger.info("\t+ Saving new pretrained generation config")
self.generation_config.save_pretrained(save_directory=model_snapshot_path)

def prepare_inputs(
self, inputs: Dict[str, Any], input_shapes: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
inputs, input_shapes = super().prepare_inputs(inputs, input_shapes)

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if "inputs" in inputs:
inputs = {"prompt": self.pretrained_processor.batch_decode(inputs["inputs"].tolist())}
elif "input_ids" in inputs:
inputs = {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
else:
raise ValueError("inputs must contain either input_ids or inputs")

return inputs, input_shapes
return inputs

async def single_client_call(self, prompt: str, kwargs: Dict[str, Any]) -> str:
return await self.client.text_generation(prompt, max_new_tokens=kwargs.get("max_new_tokens", 1))
Expand Down
Loading

0 comments on commit e291e9b

Please sign in to comment.