examples
IlyasMoutawwakil committed Jul 12, 2024
1 parent 28e6cab commit 3d5ee69
Showing 4 changed files with 95 additions and 52 deletions.
63 changes: 63 additions & 0 deletions examples/pytorch_llama.py
@@ -0,0 +1,63 @@
import os

from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging

BENCHMARK_NAME = "pytorch-llama"

WEIGHTS_CONFIGS = {
"float16": {
"torch_dtype": "float16",
"quantization_scheme": None,
"quantization_config": {},
},
# "4bit-awq-gemm": {
# "torch_dtype": "float16",
# "quantization_scheme": "awq",
# "quantization_config": {"bits": 4, "version": "gemm"},
# },
# "4bit-gptq-exllama-v2": {
# "torch_dtype": "float16",
# "quantization_scheme": "gptq",
# "quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
    # },
}


def run_benchmark(weight_config: str):
    launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn")
    backend_config = PyTorchConfig(
        device="cuda",
        device_ids="0",
        no_weights=True,
        model="gpt2",
        **WEIGHTS_CONFIGS[weight_config],
    )
    scenario_config = InferenceConfig(
        memory=True,
        latency=True,
        duration=10,
        iterations=10,
        warmup_runs=10,
        input_shapes={"batch_size": 1, "sequence_length": 128},
        generate_kwargs={"max_new_tokens": 32, "min_new_tokens": 32},
    )

    benchmark_config = BenchmarkConfig(
        name=BENCHMARK_NAME, launcher=launcher_config, scenario=scenario_config, backend=backend_config
    )
    benchmark_report = Benchmark.launch(benchmark_config)
    benchmark = Benchmark(config=benchmark_config, report=benchmark_report)

    filename = f"{BENCHMARK_NAME}-{backend_config.version}-{weight_config}.json"
    benchmark.push_to_hub(repo_id="optimum-benchmark/pytorch-llama", filename=filename)
    benchmark.save_json(path=f"benchmarks/{filename}")


if __name__ == "__main__":
    level = os.environ.get("LOG_LEVEL", "INFO")
    to_file = os.environ.get("LOG_TO_FILE", "0") == "1"
    setup_logging(level=level, to_file=to_file, prefix="MAIN-PROCESS")

    for weight_config in WEIGHTS_CONFIGS:
        run_benchmark(weight_config)
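As written, only the "float16" entry is active; the commented-out entries sketch how quantized variants would be described. A minimal way to try one of them, assuming the AWQ kernels (autoawq) are installed, a CUDA GPU is available, and the definitions from examples/pytorch_llama.py above are in scope, is to register the entry and pass its key to run_benchmark:

# Sketch, not part of the commit: enable the 4-bit AWQ configuration from the
# commented-out block above and benchmark only that variant.
WEIGHTS_CONFIGS["4bit-awq-gemm"] = {
    "torch_dtype": "float16",
    "quantization_scheme": "awq",
    "quantization_config": {"bits": 4, "version": "gemm"},
}

setup_logging(level="INFO", to_file=False, prefix="MAIN-PROCESS")
run_benchmark("4bit-awq-gemm")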
23 changes: 15 additions & 8 deletions examples/pytorch_llama_awq.yaml → examples/pytorch_llama.yaml
@@ -3,24 +3,31 @@ defaults:
  - scenario: inference
  - launcher: process
  - backend: pytorch
  - _base_
  - _self_

experiment_name: pytorch_llama_awq
name: pytorch_llama

launcher:
  device_isolation: true
  device_isolation_action: warn

backend:
  model: gpt2
  device: cuda
  device_ids: 0
  no_weights: true
  model: TheBloke/Llama-2-70B-AWQ
  torch_dtype: float16

scenario:
  memory: true
  latency: true

  warmup_runs: 10
  iterations: 10
  duration: 10

benchmark:
  input_shapes:
    batch_size: 1
    sequence_length: 128
    sequence_length: 256
  generate_kwargs:
    max_new_tokens: 100
    min_new_tokens: 100
    max_new_tokens: 32
    min_new_tokens: 32
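The diff above interleaves pre- and post-commit values without +/- markers. For reference, the YAML sections map onto the same config dataclasses used in examples/pytorch_llama.py; a minimal programmatic sketch of the equivalent setup, assuming the gpt2 / sequence_length 128 / 32-token side is the post-commit one:

# Sketch, not part of the commit: the launcher/backend/scenario sections above
# correspond to ProcessConfig/PyTorchConfig/InferenceConfig respectively.
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig

launcher = ProcessConfig(device_isolation=True, device_isolation_action="warn")
backend = PyTorchConfig(model="gpt2", device="cuda", device_ids="0", no_weights=True, torch_dtype="float16")
scenario = InferenceConfig(
    memory=True, latency=True, warmup_runs=10, iterations=10, duration=10,
    input_shapes={"batch_size": 1, "sequence_length": 128},
    generate_kwargs={"max_new_tokens": 32, "min_new_tokens": 32},
)
config = BenchmarkConfig(name="pytorch_llama", launcher=launcher, scenario=scenario, backend=backend)
report = Benchmark.launch(config)  # same entry point as the script above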
28 changes: 0 additions & 28 deletions examples/pytorch_llama_awq.py

This file was deleted.

33 changes: 17 additions & 16 deletions optimum_benchmark/backends/diffusers_utils.py
@@ -7,22 +7,23 @@
if is_diffusers_available():
    import diffusers
    from diffusers import DiffusionPipeline
    from diffusers.pipelines.auto_pipeline import (
        AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
        AUTO_INPAINT_PIPELINES_MAPPING,
        AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
    )

    TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {
        "inpainting": AUTO_INPAINT_PIPELINES_MAPPING.copy(),
        "text-to-image": AUTO_TEXT2IMAGE_PIPELINES_MAPPING.copy(),
        "image-to-image": AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.copy(),
    }

    # classes to class names
    for task_name, model_mapping in TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
        for model_type, model_class in model_mapping.items():
            TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name][model_type] = model_class.__name__

    if hasattr(diffusers, "pipelines") and hasattr(diffusers.pipelines, "auto_pipeline"):
        from diffusers.pipelines.auto_pipeline import (
            AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
            AUTO_INPAINT_PIPELINES_MAPPING,
            AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
        )

        TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {
            "inpainting": AUTO_INPAINT_PIPELINES_MAPPING.copy(),
            "text-to-image": AUTO_TEXT2IMAGE_PIPELINES_MAPPING.copy(),
            "image-to-image": AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.copy(),
        }

        for task_name, model_mapping in TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
            for model_type, model_class in model_mapping.items():
                TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name][model_type] = model_class.__name__

    else:
        TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
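The change wraps the auto-pipeline imports in a feature check so that diffusers releases without diffusers.pipelines.auto_pipeline no longer break the import, falling back to an empty mapping instead. The same defensive pattern in isolation (a sketch with illustrative names, not part of the commit):

# Sketch: probe optional attributes with hasattr before importing from them,
# and fall back to an empty mapping when the installed diffusers is too old.
import importlib


def load_auto_pipeline_mappings():
    try:
        diffusers = importlib.import_module("diffusers")
    except ImportError:
        return {}  # diffusers not installed at all
    if hasattr(diffusers, "pipelines") and hasattr(diffusers.pipelines, "auto_pipeline"):
        auto = diffusers.pipelines.auto_pipeline
        # map each model type to its pipeline class name, as diffusers_utils.py does
        return {
            "inpainting": {k: v.__name__ for k, v in auto.AUTO_INPAINT_PIPELINES_MAPPING.items()},
            "text-to-image": {k: v.__name__ for k, v in auto.AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()},
            "image-to-image": {k: v.__name__ for k, v in auto.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()},
        }
    return {}  # older diffusers without auto_pipeline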
