diff --git a/.github/workflows/update_llm_perf_cpu_pytorch.yaml b/.github/workflows/update_llm_perf_cpu_pytorch.yaml
index cdec51c4d..eadacc26a 100644
--- a/.github/workflows/update_llm_perf_cpu_pytorch.yaml
+++ b/.github/workflows/update_llm_perf_cpu_pytorch.yaml
@@ -4,16 +4,32 @@ on:
   workflow_dispatch:
   schedule:
     - cron: "0 0 * * *"
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+    types:
+      - opened
+      - reopened
+      - synchronize
+      - labeled
+      - unlabeled
 
 concurrency:
   cancel-in-progress: true
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
 
 env:
   IMAGE: ghcr.io/huggingface/optimum-benchmark:latest-cpu
 
 jobs:
   run_benchmarks:
+    if: ${{
+      (github.event_name == 'push') ||
+      (github.event_name == 'workflow_dispatch') ||
+      contains( github.event.pull_request.labels.*.name, 'leaderboard')}}
     strategy:
       fail-fast: false
       matrix:
@@ -49,4 +65,4 @@ jobs:
           pip install packaging && pip install einops scipy optimum codecarbon
           pip install -U transformers huggingface_hub[hf_transfer]
           pip install -e .
-          python llm_perf/update_llm_perf_cpu_pytorch.py
+          python llm_perf/benchmark_runners/update_llm_perf_cpu_pytorch.py
diff --git a/.github/workflows/update_llm_perf_cuda_pytorch.yaml b/.github/workflows/update_llm_perf_cuda_pytorch.yaml
index 7c902b8c3..01aed81e7 100644
--- a/.github/workflows/update_llm_perf_cuda_pytorch.yaml
+++ b/.github/workflows/update_llm_perf_cuda_pytorch.yaml
@@ -4,16 +4,33 @@ on:
   workflow_dispatch:
   schedule:
     - cron: "0 0 * * *"
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+    types:
+      - opened
+      - reopened
+      - synchronize
+      - labeled
+      - unlabeled
 
 concurrency:
   cancel-in-progress: true
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
 
 env:
   IMAGE: ghcr.io/huggingface/optimum-benchmark:latest-cuda
 
 jobs:
   run_benchmarks:
+    if: ${{
+      (github.event_name == 'push') ||
+      (github.event_name == 'workflow_dispatch') ||
+      contains( github.event.pull_request.labels.*.name, 'leaderboard')}}
+
     strategy:
       fail-fast: false
       matrix:
@@ -54,4 +71,4 @@ jobs:
           pip install packaging && pip install flash-attn einops scipy auto-gptq optimum bitsandbytes autoawq codecarbon
           pip install -U transformers huggingface_hub[hf_transfer]
           pip install -e .
-          python llm_perf/update_llm_perf_cuda_pytorch.py
+          python llm_perf/benchmark_runners/update_llm_perf_cuda_pytorch.py
diff --git a/.github/workflows/update_llm_perf_leaderboard.yaml b/.github/workflows/update_llm_perf_leaderboard.yaml
index 10ed80c98..f0a6b7a43 100644
--- a/.github/workflows/update_llm_perf_leaderboard.yaml
+++ b/.github/workflows/update_llm_perf_leaderboard.yaml
@@ -4,13 +4,30 @@ on:
   workflow_dispatch:
   schedule:
     - cron: "0 */6 * * *"
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+    types:
+      - opened
+      - reopened
+      - synchronize
+      - labeled
+      - unlabeled
 
 concurrency:
   cancel-in-progress: true
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
 
 jobs:
   update_llm_perf_leaderboard:
+    if: ${{
+      (github.event_name == 'push') ||
+      (github.event_name == 'workflow_dispatch') ||
+      contains( github.event.pull_request.labels.*.name, 'leaderboard')}}
+
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
diff --git a/llm_perf/benchmark_runners/update_llm_perf_cpu_pytorch.py b/llm_perf/benchmark_runners/update_llm_perf_cpu_pytorch.py
new file mode 100644
index 000000000..c27f5e220
--- /dev/null
+++ b/llm_perf/benchmark_runners/update_llm_perf_cpu_pytorch.py
@@ -0,0 +1,92 @@
+from itertools import product
+from typing import Any, Dict, List
+
+from llm_perf.common.benchmark_runner import LLMPerfBenchmarkManager
+from llm_perf.common.utils import CANONICAL_PRETRAINED_OPEN_LLM_LIST, GENERATE_KWARGS, INPUT_SHAPES
+from optimum_benchmark import PyTorchConfig
+from optimum_benchmark.benchmark.config import BenchmarkConfig
+from optimum_benchmark.launchers.process.config import ProcessConfig
+from optimum_benchmark.scenarios.inference.config import InferenceConfig
+
+
+class CPUPyTorchBenchmarkRunner(LLMPerfBenchmarkManager):
+    def __init__(self):
+        super().__init__(backend="pytorch", device="cpu")
+
+        self.attention_configs = self._get_attention_configs()
+        assert self.subset is not None, "SUBSET environment variable must be set for benchmarking"
+        self.weights_configs = self._get_weights_configs(self.subset)
+
+    def get_list_of_benchmarks_to_run(self) -> List[Dict[str, Any]]:
+        return [
+            {"model": model, "attn_implementation": attn_impl, "weights_config": weights_cfg}
+            for model, attn_impl, weights_cfg in product(
+                CANONICAL_PRETRAINED_OPEN_LLM_LIST, self.attention_configs, self.weights_configs.keys()
+            )
+        ]
+
+    def get_benchmark_name(self, model: str, **kwargs) -> str:
+        weights_config = kwargs["weights_config"]
+        attn_implementation = kwargs["attn_implementation"]
+        return f"{model}-{weights_config}-{attn_implementation}"
+
+    def get_benchmark_config(self, model: str, **kwargs) -> BenchmarkConfig:
+        weights_config = kwargs["weights_config"]
+        attn_implementation = kwargs["attn_implementation"]
+
+        assert (
+            weights_config in self.weights_configs
+        ), f"your config does not contain {weights_config}, adjust your _get_weights_configs to fix this issue"
+
+        torch_dtype = self.weights_configs[weights_config]["torch_dtype"]
+        quant_scheme = self.weights_configs[weights_config]["quant_scheme"]
+        quant_config = self.weights_configs[weights_config]["quant_config"]
+
+        launcher_config = ProcessConfig()
+        scenario_config = InferenceConfig(
+            memory=True,
+            energy=True,
+            latency=True,
+            duration=10,
+            iterations=10,
+            warmup_runs=10,
+            input_shapes=INPUT_SHAPES,
+            generate_kwargs=GENERATE_KWARGS,
+        )
+        backend_config = PyTorchConfig(
+            model=model,
+            device="cpu",
+            no_weights=True,
library="transformers", + task="text-generation", + torch_dtype=torch_dtype, + quantization_scheme=quant_scheme, + quantization_config=quant_config, + attn_implementation=attn_implementation, + model_kwargs={"trust_remote_code": True}, + ) + + return BenchmarkConfig( + name=f"{weights_config}-{attn_implementation}", + scenario=scenario_config, + launcher=launcher_config, + backend=backend_config, + ) + + def _get_weights_configs(self, subset) -> Dict[str, Dict[str, Any]]: + if subset == "unquantized": + return { + "float32": {"torch_dtype": "float32", "quant_scheme": None, "quant_config": {}}, + "float16": {"torch_dtype": "float16", "quant_scheme": None, "quant_config": {}}, + "bfloat16": {"torch_dtype": "bfloat16", "quant_scheme": None, "quant_config": {}}, + } + else: + raise ValueError(f"Unknown subset: {subset}") + + def _get_attention_configs(self) -> List[str]: + return ["eager", "sdpa"] + + +if __name__ == "__main__": + runner = CPUPyTorchBenchmarkRunner() + runner.run_benchmarks() diff --git a/llm_perf/benchmark_runners/update_llm_perf_cuda_pytorch.py b/llm_perf/benchmark_runners/update_llm_perf_cuda_pytorch.py new file mode 100644 index 000000000..82aab3db9 --- /dev/null +++ b/llm_perf/benchmark_runners/update_llm_perf_cuda_pytorch.py @@ -0,0 +1,147 @@ +from itertools import product +from typing import Any, Dict, List + +from llm_perf.common.benchmark_runner import LLMPerfBenchmarkManager +from llm_perf.common.utils import CANONICAL_PRETRAINED_OPEN_LLM_LIST, GENERATE_KWARGS, INPUT_SHAPES +from optimum_benchmark import PyTorchConfig +from optimum_benchmark.benchmark.config import BenchmarkConfig +from optimum_benchmark.launchers.process.config import ProcessConfig +from optimum_benchmark.scenarios.inference.config import InferenceConfig + + +class CUDAPyTorchBenchmarkRunner(LLMPerfBenchmarkManager): + def __init__(self): + super().__init__(backend="pytorch", device="cuda") + + self.attention_configs = self._get_attention_configs() + assert self.subset is not None, "SUBSET environment variable must be set for benchmarking" + self.weights_configs = self._get_weights_configs(self.subset) + + def get_list_of_benchmarks_to_run(self) -> List[Dict[str, Any]]: + return [ + {"model": model, "attn_implementation": attn_impl, "weights_config": weights_cfg} + for model, attn_impl, weights_cfg in product( + CANONICAL_PRETRAINED_OPEN_LLM_LIST, self.attention_configs, self.weights_configs.keys() + ) + ] + + def get_benchmark_name(self, model: str, **kwargs) -> str: + weights_config = kwargs["weights_config"] + attn_implementation = kwargs["attn_implementation"] + return f"{model}-{weights_config}-{attn_implementation}" + + def is_benchmark_supported(self, **kwargs) -> bool: + if kwargs["attn_implementation"] == "flash_attention_2" and kwargs["weights_config"] == "float32": + return False + return True + + def get_benchmark_config(self, model: str, **kwargs) -> BenchmarkConfig: + weights_config = kwargs["weights_config"] + attn_implementation = kwargs["attn_implementation"] + + assert ( + weights_config in self.weights_configs + ), f"your config does contains the {weights_config}, adjust your _get_weights_configs to fix this issue" + + torch_dtype = self.weights_configs[weights_config]["torch_dtype"] + quant_scheme = self.weights_configs[weights_config]["quant_scheme"] + quant_config = self.weights_configs[weights_config]["quant_config"] + + launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="kill") + scenario_config = InferenceConfig( + memory=True, + energy=True, + 
+            latency=True,
+            duration=10,
+            iterations=10,
+            warmup_runs=10,
+            input_shapes=INPUT_SHAPES,
+            generate_kwargs=GENERATE_KWARGS,
+        )
+        backend_config = PyTorchConfig(
+            model=model,
+            device="cuda",
+            device_ids="0",
+            no_weights=True,
+            library="transformers",
+            task="text-generation",
+            torch_dtype=torch_dtype,
+            quantization_scheme=quant_scheme,
+            quantization_config=quant_config,
+            attn_implementation=attn_implementation,
+            model_kwargs={"trust_remote_code": True},
+        )
+
+        return BenchmarkConfig(
+            name=f"{weights_config}-{attn_implementation}",
+            scenario=scenario_config,
+            launcher=launcher_config,
+            backend=backend_config,
+        )
+
+    def _get_weights_configs(self, subset) -> Dict[str, Dict[str, Any]]:
+        if subset == "unquantized":
+            return {
+                "float32": {"torch_dtype": "float32", "quant_scheme": None, "quant_config": {}},
+                "float16": {"torch_dtype": "float16", "quant_scheme": None, "quant_config": {}},
+                "bfloat16": {"torch_dtype": "bfloat16", "quant_scheme": None, "quant_config": {}},
+            }
+        elif subset == "bnb":
+            return {
+                "4bit-bnb": {"torch_dtype": "float16", "quant_scheme": "bnb", "quant_config": {"load_in_4bit": True}},
+                "8bit-bnb": {"torch_dtype": "float16", "quant_scheme": "bnb", "quant_config": {"load_in_8bit": True}},
+            }
+        elif subset == "gptq":
+            return {
+                "4bit-gptq-exllama-v1": {
+                    "torch_dtype": "float16",
+                    "quant_scheme": "gptq",
+                    "quant_config": {"bits": 4, "use_exllama ": True, "version": 1, "model_seqlen": 256},
+                },
+                "4bit-gptq-exllama-v2": {
+                    "torch_dtype": "float16",
+                    "quant_scheme": "gptq",
+                    "quant_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
+                },
+            }
+        elif subset == "awq":
+            return {
+                "4bit-awq-gemm": {
+                    "torch_dtype": "float16",
+                    "quant_scheme": "awq",
+                    "quant_config": {"bits": 4, "version": "gemm"},
+                },
+                "4bit-awq-gemv": {
+                    "torch_dtype": "float16",
+                    "quant_scheme": "awq",
+                    "quant_config": {"bits": 4, "version": "gemv"},
+                },
+                "4bit-awq-exllama-v1": {
+                    "torch_dtype": "float16",
+                    "quant_scheme": "awq",
+                    "quant_config": {
+                        "bits": 4,
+                        "version": "exllama",
+                        "exllama_config": {"version": 1, "max_input_len": 64, "max_batch_size": 1},
+                    },
+                },
+                "4bit-awq-exllama-v2": {
+                    "torch_dtype": "float16",
+                    "quant_scheme": "awq",
+                    "quant_config": {
+                        "bits": 4,
+                        "version": "exllama",
+                        "exllama_config": {"version": 2, "max_input_len": 64, "max_batch_size": 1},
+                    },
+                },
+            }
+        else:
+            raise ValueError(f"Unknown subset: {subset}")
+
+    def _get_attention_configs(self) -> List[str]:
+        return ["eager", "sdpa", "flash_attention_2"]
+
+
+if __name__ == "__main__":
+    runner = CUDAPyTorchBenchmarkRunner()
+    runner.run_benchmarks()
diff --git a/llm_perf/common/__init__.py b/llm_perf/common/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/llm_perf/common/benchmark_runner.py b/llm_perf/common/benchmark_runner.py
new file mode 100644
index 000000000..def30dc20
--- /dev/null
+++ b/llm_perf/common/benchmark_runner.py
@@ -0,0 +1,122 @@
+import os
+import traceback
+from abc import ABC, abstractmethod
+from logging import getLogger
+from typing import Any, Dict, List, Optional
+
+from llm_perf.common.utils import (
+    CANONICAL_PRETRAINED_OPEN_LLM_LIST,
+    OPEN_LLM_LIST,
+    PRETRAINED_OPEN_LLM_LIST,
+)
+from optimum_benchmark import Benchmark, BenchmarkConfig, BenchmarkReport
+from optimum_benchmark.logging_utils import setup_logging
+
+
+class LLMPerfBenchmarkManager(ABC):
+    def __init__(self, backend: str, device: str, subset: Optional[str] = None, machine: Optional[str] = None):
+        self.backend = backend
+        self.device = device
+        self.subset = subset or os.getenv("SUBSET", None)
+        self.machine = machine or os.getenv("MACHINE", None)
+        self.logger = getLogger("llm-perf-backend")
+
+        if self.machine is None and self.subset is None:
+            self.push_repo_id = f"optimum-benchmark/llm-perf-{self.backend}-{self.device}-debug"
+            self.canonical_pretrained_open_llm_list = ["gpt2"]
+            self.subset = "unquantized"
+        elif self.machine is not None and self.subset is not None:
+            self.push_repo_id = f"optimum-benchmark/llm-perf-{self.backend}-{self.device}-{self.subset}-{self.machine}"
+        else:
+            raise ValueError("Either both MACHINE and SUBSET should be set for benchmarking or neither for debugging")
+
+        self.logger.info(f"len(OPEN_LLM_LIST): {len(OPEN_LLM_LIST)}")
+        self.logger.info(f"len(PRETRAINED_OPEN_LLM_LIST): {len(PRETRAINED_OPEN_LLM_LIST)}")
+        self.logger.info(f"len(CANONICAL_PRETRAINED_OPEN_LLM_LIST): {len(CANONICAL_PRETRAINED_OPEN_LLM_LIST)}")
+
+    @abstractmethod
+    def _get_weights_configs(self, subset: str) -> Dict[str, Dict[str, Any]]:
+        raise NotImplementedError("This method should be implemented in the child class")
+
+    @abstractmethod
+    def _get_attention_configs(self) -> List[str]:
+        raise NotImplementedError("This method should be implemented in the child class")
+
+    def is_benchmark_supported(self, **kwargs) -> bool:
+        """
+        Can be overridden by child classes to exclude unsupported configurations
+        """
+        return True
+
+    @abstractmethod
+    def get_list_of_benchmarks_to_run(self) -> List[Dict[str, Any]]:
+        raise NotImplementedError("This method should be implemented in the child class")
+
+    def run_benchmarks(self):
+        os.environ["LOG_TO_FILE"] = "0"
+        os.environ["LOG_LEVEL"] = "INFO"
+        setup_logging(level="INFO", prefix="MAIN-PROCESS")
+
+        benchmarks_to_run = self.get_list_of_benchmarks_to_run()
+
+        self.logger.info(
+            f"Running a total of {len(benchmarks_to_run)} benchmarks, "
+            f"with {len(CANONICAL_PRETRAINED_OPEN_LLM_LIST)} models"
+        )
+
+        for benchmark_name in benchmarks_to_run:
+            assert "model" in benchmark_name, "each benchmark should have a model"
+
+            self.run_benchmark(**benchmark_name)
+
+    def is_benchmark_conducted(self, push_repo_id, subfolder):
+        try:
+            report = BenchmarkReport.from_pretrained(repo_id=push_repo_id, subfolder=subfolder)
+            if "traceback" in report.to_dict():
+                return False
+            else:
+                return True
+        except Exception:
+            return False
+
+    @abstractmethod
+    def get_benchmark_name(self, model: str, **kwargs) -> str:
+        raise NotImplementedError("This method should be implemented in the child class")
+
+    def run_benchmark(self, **kwargs):
+        model = kwargs.pop("model")
+
+        benchmark_name = self.get_benchmark_name(model, **kwargs)
+        subfolder = f"{benchmark_name}/{model.replace('/', '--')}"
+
+        if not self.is_benchmark_supported(**kwargs):
+            self.logger.info(f"Skipping benchmark {benchmark_name} with model {model} since it is not supported")
+            return
+
+        if self.is_benchmark_conducted(self.push_repo_id, subfolder):
+            self.logger.info(f"Skipping benchmark {benchmark_name} with model {model} since it was already conducted")
+            return
+
+        benchmark_config = self.get_benchmark_config(model, **kwargs)
+        benchmark_config.push_to_hub(repo_id=self.push_repo_id, subfolder=subfolder, private=True)
+        self.execute_and_log_benchmark(benchmark_config, subfolder)
+
+    @abstractmethod
+    def get_benchmark_config(self, model: str, **kwargs) -> BenchmarkConfig:
+        raise NotImplementedError("This method should be implemented in the child class")
+
+    def execute_and_log_benchmark(self, benchmark_config: BenchmarkConfig, subfolder: str):
+        try:
+            self.logger.info(f"Running benchmark {benchmark_config.name} with model {benchmark_config.backend.model}")
+            benchmark_report = Benchmark.launch(benchmark_config)
+            benchmark_report.push_to_hub(repo_id=self.push_repo_id, subfolder=subfolder, private=True)
+            benchmark = Benchmark(config=benchmark_config, report=benchmark_report)
+            benchmark.push_to_hub(repo_id=self.push_repo_id, subfolder=subfolder, private=True)
+        except Exception as e:
+            self.logger.error(
+                f"Benchmark {benchmark_config.name} failed with model {benchmark_config.backend.model}, error:\n{e}"
+            )
+            benchmark_report = BenchmarkReport.from_dict({"traceback": traceback.format_exc()})
+            benchmark_report.push_to_hub(repo_id=self.push_repo_id, subfolder=subfolder, private=True)
+            benchmark = Benchmark(config=benchmark_config, report=benchmark_report)
+            benchmark.push_to_hub(repo_id=self.push_repo_id, subfolder=subfolder, private=True)
diff --git a/llm_perf/common/hardware_config.py b/llm_perf/common/hardware_config.py
new file mode 100644
index 000000000..ed28222e2
--- /dev/null
+++ b/llm_perf/common/hardware_config.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from typing import List
+
+import yaml
+
+
+@dataclass
+class HardwareConfig:
+    machine: str
+    hardware: str
+    subsets: List[str]
+    backends: List[str]
+
+    def __repr__(self):
+        return (
+            f"HardwareConfig(machine='{self.machine}', hardware='{self.hardware}', "
+            f"subsets={self.subsets}, backends={self.backends})"
+        )
+
+
+def load_hardware_configs(file_path: str) -> List[HardwareConfig]:
+    with open(file_path, "r") as file:
+        data = yaml.safe_load(file)
+    return [HardwareConfig(**config) for config in data]
diff --git a/llm_perf/utils.py b/llm_perf/common/utils.py
similarity index 90%
rename from llm_perf/utils.py
rename to llm_perf/common/utils.py
index 6a5584284..a78f04664 100644
--- a/llm_perf/utils.py
+++ b/llm_perf/common/utils.py
@@ -1,7 +1,5 @@
 import pandas as pd
 
-from optimum_benchmark.benchmark.report import BenchmarkReport
-
 INPUT_SHAPES = {"batch_size": 1, "sequence_length": 256}
 GENERATE_KWARGS = {"max_new_tokens": 64, "min_new_tokens": 64}
 
@@ -124,14 +122,3 @@
     "togethercomputer/RedPajama-INCITE-Base-3B-v1",
     "togethercomputer/RedPajama-INCITE-Base-7B-v0.1",
 ]
-
-
-def is_benchmark_conducted(push_repo_id, subfolder):
-    try:
-        report = BenchmarkReport.from_pretrained(repo_id=push_repo_id, subfolder=subfolder)
-        if "traceback" in report.to_dict():
-            return False
-        else:
-            return True
-    except Exception:
-        return False
diff --git a/llm_perf/hardware.yml b/llm_perf/hardware.yml
new file mode 100644
index 000000000..1a351b674
--- /dev/null
+++ b/llm_perf/hardware.yml
@@ -0,0 +1,36 @@
+- machine: 1xA10
+  hardware: cuda
+  subsets:
+    - unquantized
+    - awq
+    - bnb
+    - gptq
+  backends:
+    - pytorch
+
+- machine: 1xA100
+  hardware: cuda
+  subsets:
+    - unquantized
+    - awq
+    - bnb
+    - gptq
+  backends:
+    - pytorch
+
+- machine: 1xT4
+  hardware: cuda
+  subsets:
+    - unquantized
+    - awq
+    - bnb
+    - gptq
+  backends:
+    - pytorch
+
+- machine: 32vCPU-C7i
+  hardware: cpu
+  subsets:
+    - unquantized
+  backends:
+    - pytorch
\ No newline at end of file
diff --git a/llm_perf/update_llm_perf_cpu_pytorch.py b/llm_perf/update_llm_perf_cpu_pytorch.py
deleted file mode 100644
index 250355505..000000000
--- a/llm_perf/update_llm_perf_cpu_pytorch.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import os
-import traceback
-from itertools import product
-from logging import getLogger
-
-from llm_perf.utils import (
-    CANONICAL_PRETRAINED_OPEN_LLM_LIST,
-    GENERATE_KWARGS,
-    INPUT_SHAPES,
-    OPEN_LLM_LIST,
-    PRETRAINED_OPEN_LLM_LIST,
-    is_benchmark_conducted,
-)
-from optimum_benchmark import (
-    Benchmark,
-    BenchmarkConfig,
-    BenchmarkReport,
-    InferenceConfig,
-    ProcessConfig,
-    PyTorchConfig,
-)
-from optimum_benchmark.logging_utils import setup_logging
-
-SUBSET = os.getenv("SUBSET", None)
-MACHINE = os.getenv("MACHINE", None)
-BACKEND = "pytorch"
-HARDWARE = "cpu"
-
-if os.getenv("MACHINE", None) is None and os.getenv("SUBSET", None) is None:
-    PUSH_REPO_ID = f"optimum-benchmark/llm-perf-{BACKEND}-{HARDWARE}-debug"
-    CANONICAL_PRETRAINED_OPEN_LLM_LIST = ["gpt2"]  # noqa: F811
-    SUBSET = "unquantized"
-elif os.getenv("MACHINE", None) is not None and os.getenv("SUBSET", None) is not None:
-    PUSH_REPO_ID = f"optimum-benchmark/llm-perf-{BACKEND}-{HARDWARE}-{SUBSET}-{MACHINE}"
-else:
-    raise ValueError("Either both MACHINE and SUBSET should be set for benchmarking or neither for debugging")
-
-ATTENTION_CONFIGS = ["eager", "sdpa"]
-
-
-if SUBSET == "unquantized":
-    WEIGHTS_CONFIGS = {
-        # unquantized
-        "float32": {"torch_dtype": "float32", "quant_scheme": None, "quant_config": {}},
-        "float16": {"torch_dtype": "float16", "quant_scheme": None, "quant_config": {}},
-        "bfloat16": {"torch_dtype": "bfloat16", "quant_scheme": None, "quant_config": {}},
-    }
-else:
-    raise ValueError(f"Subset {SUBSET} not supported")
-
-
-LOGGER = getLogger("llm-perf-backend")
-LOGGER.info(f"len(OPEN_LLM_LIST): {len(OPEN_LLM_LIST)}")
-LOGGER.info(f"len(PRETRAINED_OPEN_LLM_LIST): {len(PRETRAINED_OPEN_LLM_LIST)}")
-LOGGER.info(f"len(CANONICAL_PRETRAINED_OPEN_LLM_LIST): {len(CANONICAL_PRETRAINED_OPEN_LLM_LIST)}")
-
-
-def is_benchmark_supported(weights_config, attn_implementation, hardware):
-    if attn_implementation == "flash_attention_2":
-        return False
-
-    return True
-
-
-def benchmark_cpu_pytorch(model, attn_implementation, weights_config):
-    benchmark_name = f"{weights_config}-{attn_implementation}-{BACKEND}"
-    subfolder = f"{benchmark_name}/{model.replace('/', '--')}"
-
-    torch_dtype = WEIGHTS_CONFIGS[weights_config]["torch_dtype"]
-    quant_scheme = WEIGHTS_CONFIGS[weights_config]["quant_scheme"]
-    quant_config = WEIGHTS_CONFIGS[weights_config]["quant_config"]
-
-    if not is_benchmark_supported(weights_config, attn_implementation, HARDWARE):
-        LOGGER.info(f"Skipping benchmark {benchmark_name} with model {model} since it is not supported")
-        return
-
-    if is_benchmark_conducted(PUSH_REPO_ID, subfolder):
-        LOGGER.info(f"Skipping benchmark {benchmark_name} with model {model} since it was already conducted")
-        return
-
-    launcher_config = ProcessConfig()
-    scenario_config = InferenceConfig(
-        memory=True,
-        energy=True,
-        latency=True,
-        duration=10,
-        iterations=10,
-        warmup_runs=10,
-        input_shapes=INPUT_SHAPES,
-        generate_kwargs=GENERATE_KWARGS,
-    )
-
-    backend_config = PyTorchConfig(
-        model=model,
-        device="cpu",
-        no_weights=True,
-        library="transformers",
-        task="text-generation",
-        torch_dtype=torch_dtype,
-        quantization_scheme=quant_scheme,
-        quantization_config=quant_config,
-        attn_implementation=attn_implementation,
-        model_kwargs={"trust_remote_code": True},
-    )
-
-    benchmark_config = BenchmarkConfig(
-        name=benchmark_name, scenario=scenario_config, launcher=launcher_config, backend=backend_config
-    )
-
-    benchmark_config.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True)
-
-    try:
-        LOGGER.info(f"Running benchmark {benchmark_name} with model {model}")
-        benchmark_report = Benchmark.launch(benchmark_config)
-        benchmark_report.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True)
-        benchmark = Benchmark(config=benchmark_config, report=benchmark_report)
-        benchmark.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True)
-
-    except Exception:
-        LOGGER.error(f"Benchmark {benchmark_name} failed with model {model}")
-        benchmark_report = BenchmarkReport.from_dict({"traceback": traceback.format_exc()})
-        benchmark_report.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True)
-        benchmark = Benchmark(config=benchmark_config, report=benchmark_report)
-        benchmark.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True)
-
-
-if __name__ == "__main__":
-    # for isolated process
-    os.environ["LOG_TO_FILE"] = "0"
-    os.environ["LOG_LEVEL"] = "INFO"
-
-    # for main process
-    setup_logging(level="INFO", prefix="MAIN-PROCESS")
-
-    models_attentions_weights = list(
-        product(CANONICAL_PRETRAINED_OPEN_LLM_LIST, ATTENTION_CONFIGS, WEIGHTS_CONFIGS.keys())
-    )
-
-    LOGGER.info(
-        f"Running a total of {len(models_attentions_weights)} benchmarks, "
-        f"with {len(CANONICAL_PRETRAINED_OPEN_LLM_LIST)} models, "
-        f"{len(ATTENTION_CONFIGS)} attentions implementations "
-        f"and {len(WEIGHTS_CONFIGS)} weights configurations."
-    )
-
-    for model, attn_implementation, weights_config in models_attentions_weights:
-        benchmark_cpu_pytorch(model, attn_implementation, weights_config)
diff --git a/llm_perf/update_llm_perf_cuda_pytorch.py b/llm_perf/update_llm_perf_cuda_pytorch.py
deleted file mode 100644
index 98914f6ad..000000000
--- a/llm_perf/update_llm_perf_cuda_pytorch.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import os
-import traceback
-from itertools import product
-from logging import getLogger
-
-from llm_perf.utils import (
-    CANONICAL_PRETRAINED_OPEN_LLM_LIST,
-    GENERATE_KWARGS,
-    INPUT_SHAPES,
-    OPEN_LLM_LIST,
-    PRETRAINED_OPEN_LLM_LIST,
-    is_benchmark_conducted,
-)
-from optimum_benchmark import Benchmark, BenchmarkConfig, BenchmarkReport, InferenceConfig, ProcessConfig, PyTorchConfig
-from optimum_benchmark.logging_utils import setup_logging
-
-SUBSET = os.getenv("SUBSET", None)
-MACHINE = os.getenv("MACHINE", None)
-
-if os.getenv("MACHINE", None) is None and os.getenv("SUBSET", None) is None:
-    PUSH_REPO_ID = "optimum-benchmark/llm-perf-pytorch-cuda-debug"
-    CANONICAL_PRETRAINED_OPEN_LLM_LIST = ["gpt2"]  # noqa: F811
-    SUBSET = "unquantized"
-elif os.getenv("MACHINE", None) is not None and os.getenv("SUBSET", None) is not None:
-    PUSH_REPO_ID = f"optimum-benchmark/llm-perf-pytorch-cuda-{SUBSET}-{MACHINE}"
-else:
-    raise ValueError("Either both MACHINE and SUBSET should be set for benchmarking or neither for debugging")
-
-ATTENTION_CONFIGS = ["eager", "sdpa", "flash_attention_2"]
-if SUBSET == "unquantized":
-    WEIGHTS_CONFIGS = {
-        # unquantized
-        "float32": {"torch_dtype": "float32", "quant_scheme": None, "quant_config": {}},
-        "float16": {"torch_dtype": "float16", "quant_scheme": None, "quant_config": {}},
-        "bfloat16": {"torch_dtype": "bfloat16", "quant_scheme": None, "quant_config": {}},
-    }
-elif SUBSET == "bnb":
-    WEIGHTS_CONFIGS = {
-        # bnb
-        "4bit-bnb": {"torch_dtype": "float16", "quant_scheme": "bnb", "quant_config": {"load_in_4bit": True}},
-        "8bit-bnb": {"torch_dtype": "float16", "quant_scheme": "bnb", "quant_config": {"load_in_8bit": True}},
-    }
-elif SUBSET == "gptq":
-    WEIGHTS_CONFIGS = {
-        # gptq
-        "4bit-gptq-exllama-v1": {
-            "quant_scheme": "gptq",
-            "torch_dtype": "float16",
-            "quant_config": {"bits": 4, "use_exllama ": True, "version": 1, "model_seqlen": 256},
-        },
-        "4bit-gptq-exllama-v2": {
-            "torch_dtype": "float16",
-            "quant_scheme": "gptq",
-            "quant_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
-        },
-    }
-elif SUBSET == "awq":
-    WEIGHTS_CONFIGS = {
-        # awq
-        "4bit-awq-gemm": {
-            "torch_dtype": "float16",
-            "quant_scheme": "awq",
-            "quant_config": {"bits": 4, "version": "gemm"},
-        },
-        "4bit-awq-gemv": {
-            "torch_dtype": "float16",
-            "quant_scheme": "awq",
-            "quant_config": {"bits": 4, "version": "gemv"},
-        },
-        "4bit-awq-exllama-v1": {
-            "torch_dtype": "float16",
-            "quant_scheme": "awq",
-            "quant_config": {
-                "bits": 4,
-                "version": "exllama",
-                "exllama_config": {"version": 1, "max_input_len": 64, "max_batch_size": 1},
-            },
-        },
-        "4bit-awq-exllama-v2": {
-            "torch_dtype": "float16",
-            "quant_scheme": "awq",
-            "quant_config": {
-                "bits": 4,
-                "version": "exllama",
-                "exllama_config": {"version": 2, "max_input_len": 64, "max_batch_size": 1},
-            },
-        },
-    }
-
-
-LOGGER = getLogger("llm-perf-backend")
-LOGGER.info(f"len(OPEN_LLM_LIST): {len(OPEN_LLM_LIST)}")
-LOGGER.info(f"len(PRETRAINED_OPEN_LLM_LIST): {len(PRETRAINED_OPEN_LLM_LIST)}")
-LOGGER.info(f"len(CANONICAL_PRETRAINED_OPEN_LLM_LIST): {len(CANONICAL_PRETRAINED_OPEN_LLM_LIST)}")
-
-
-def is_benchmark_supported(weights_config, attn_implementation):
-    if attn_implementation == "flash_attention_2" and weights_config == "float32":
-        return False
-
-    return True
-
-
-def benchmark_cuda_pytorch(model, attn_implementation, weights_config):
-    benchmark_name = f"{weights_config}-{attn_implementation}"
-    subfolder = f"{benchmark_name}/{model.replace('/', '--')}"
-
-    torch_dtype = WEIGHTS_CONFIGS[weights_config]["torch_dtype"]
-    quant_scheme = WEIGHTS_CONFIGS[weights_config]["quant_scheme"]
-    quant_config = WEIGHTS_CONFIGS[weights_config]["quant_config"]
-
-    if not is_benchmark_supported(weights_config, attn_implementation):
-        LOGGER.info(f"Skipping benchmark {benchmark_name} with model {model} since it is not supported")
-        return
-
-    if is_benchmark_conducted(PUSH_REPO_ID, subfolder):
-        LOGGER.info(f"Skipping benchmark {benchmark_name} with model {model} since it was already conducted")
-        return
-
-    launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="kill")
-    scenario_config = InferenceConfig(
-        memory=True,
-        energy=True,
-        latency=True,
-        duration=10,
-        iterations=10,
-        warmup_runs=10,
-        input_shapes=INPUT_SHAPES,
-        generate_kwargs=GENERATE_KWARGS,
-    )
-    backend_config = PyTorchConfig(
-        model=model,
-        device="cuda",
-        device_ids="0",
-        no_weights=True,
-        library="transformers",
-        task="text-generation",
-        torch_dtype=torch_dtype,
-        quantization_scheme=quant_scheme,
-        quantization_config=quant_config,
-        attn_implementation=attn_implementation,
-        model_kwargs={"trust_remote_code": True},
-    )
-
-    benchmark_config = BenchmarkConfig(
-        name=benchmark_name, scenario=scenario_config, launcher=launcher_config, backend=backend_config
-    )
-
-    benchmark_config.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True)
-
-    try:
-        LOGGER.info(f"Running benchmark {benchmark_name} with model {model}")
-        benchmark_report = Benchmark.launch(benchmark_config)
-        benchmark_report.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True)
-        benchmark = Benchmark(config=benchmark_config, report=benchmark_report)
-        benchmark.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True)
-
-    except Exception:
-        LOGGER.error(f"Benchmark {benchmark_name} failed with model {model}")
-        benchmark_report = BenchmarkReport.from_dict({"traceback": traceback.format_exc()})
BenchmarkReport.from_dict({"traceback": traceback.format_exc()}) - benchmark_report.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True) - benchmark = Benchmark(config=benchmark_config, report=benchmark_report) - benchmark.push_to_hub(repo_id=PUSH_REPO_ID, subfolder=subfolder, private=True) - - -if __name__ == "__main__": - # for isolated process - os.environ["LOG_TO_FILE"] = "0" - os.environ["LOG_LEVEL"] = "INFO" - - # for main process - setup_logging(level="INFO", prefix="MAIN-PROCESS") - - models_attentions_weights = list( - product(CANONICAL_PRETRAINED_OPEN_LLM_LIST, ATTENTION_CONFIGS, WEIGHTS_CONFIGS.keys()) - ) - - LOGGER.info( - f"Running a total of {len(models_attentions_weights)} benchmarks, " - f"with {len(CANONICAL_PRETRAINED_OPEN_LLM_LIST)} models, " - f"{len(ATTENTION_CONFIGS)} attentions implementations " - f"and {len(WEIGHTS_CONFIGS)} weights configurations." - ) - - for model, attn_implementation, weights_config in models_attentions_weights: - benchmark_cuda_pytorch(model, attn_implementation, weights_config) diff --git a/llm_perf/update_llm_perf_leaderboard.py b/llm_perf/update_llm_perf_leaderboard.py index 4516750ae..619e54224 100644 --- a/llm_perf/update_llm_perf_leaderboard.py +++ b/llm_perf/update_llm_perf_leaderboard.py @@ -5,6 +5,7 @@ from huggingface_hub import create_repo, snapshot_download, upload_file from tqdm import tqdm +from llm_perf.common.hardware_config import load_hardware_configs from optimum_benchmark import Benchmark REPO_TYPE = "dataset" @@ -37,16 +38,17 @@ def update_perf_dfs(): """ Update the performance dataframes for all machines """ - for machine in ["1xA10", "1xA100", "1xT4", "32vCPU-C7i"]: - for backend in ["pytorch"]: - for hardware in ["cuda", "cpu"]: - for subset in ["unquantized", "bnb", "awq", "gptq"]: - try: - gather_benchmarks(subset, machine, backend, hardware) - except Exception: - print( - f"benchmark for subset: {subset}, machine: {machine}, backend: {backend}, hardware: {hardware} not found" - ) + hardware_configs = load_hardware_configs("llm_perf/hardware.yml") + + for hardware_config in hardware_configs: + for subset in hardware_config.subsets: + for backend in hardware_config.backends: + try: + gather_benchmarks(subset, hardware_config.machine, backend, hardware_config.hardware) + except Exception: + print( + f"benchmark for subset: {subset}, machine: {hardware_config.machine}, backend: {backend}, hardware: {hardware_config.hardware} not found" + ) scrapping_script = """