vLLM Mistral, Memory support, qualitative comparison and improvements #172

Merged
merged 8 commits on Apr 22, 2024
60 changes: 48 additions & 12 deletions bench_vllm/README.md


211 changes: 106 additions & 105 deletions bench_vllm/bench.py
@@ -1,67 +1,98 @@
import argparse
import gc
import logging
import os
import sys
import time
from collections import defaultdict

import numpy as np
import torch
from vllm import LLM
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.model_executor.parallel_utils import parallel_state

logging.getLogger("vllm").setLevel(logging.ERROR)
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
sys.path.append(os.getcwd())

from common.base import BaseBenchmarkClass # noqa
from common.utils import launch_cli, make_report # noqa

class LlamaVLLMBenchmark:
def __init__(self, model_path: str, device: str, precision: str):
# vLLM does not support CPU, issue: https://github.com/vllm-project/vllm/issues/176
# vLLM also does not support Metal, issue: https://github.com/vllm-project/vllm/issues/1441

assert device == "cuda", ValueError("Supported device is cuda only.")
assert precision in ["fp16", "fp32", "int4"], ValueError(
"supported precision are: fp16, fp32 and int4"
class VLLMBenchmark(BaseBenchmarkClass):
def __init__(
self,
model_path: str,
model_name: str,
benchmark_name: str,
precision: str,
device: str,
experiment_name: str,
) -> None:
assert device == "cuda", ValueError("Only supported device is 'cuda'")
assert precision in ["float16", "float32", "int4"], ValueError(
"supported precision are: 'float16', 'float32' and 'int4'"
)

self.model_path, self.precision, self.device = model_path, precision, device
self.results = []
self.precision_map = {"fp16": "float16", "fp32": "float32"}
super().__init__(
model_name=model_name,
model_path=model_path,
benchmark_name=benchmark_name,
experiment_name=experiment_name,
precision=precision,
device=device,
)

def load_model(self):
if self.precision != "int4":
self.model = LLM(model=self.model_path)
self.model.dtype = self.precision_map[precision]
if model_name == "llama":
self.tokenizer_folder = os.path.join(
os.getcwd(), "models", "llama-2-7b-chat-hf"
)
else:
self.tokenizer_folder = os.path.join(
os.getcwd(), "models", "mistral-7b-v0.1-instruct-hf"
)

def load_model_and_tokenizer(self):
if self.precision == "int4":
self.model = LLM(
model=self.model_path, quantization="AWQ", tensor_parallel_size=1
)
else:
self.model = LLM(model=self.model_path)
self.model.dtype = self.precision
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_folder)
return self

def run_model(self, prompt: str, max_tokens: int) -> float:
self.model.max_num_seqs = max_tokens
start = time.time()
output = self.model.generate(prompts=[prompt])
delta = time.time() - start
return len(output[0].outputs[0].token_ids) / delta

def benchmark(
self,
prompt: str,
max_tokens: int,
repetitions: int,
) -> None:
for i in range(repetitions):
logging.info(
f"Running repetition [{str(i+1).zfill(len(str(repetitions)))}/{repetitions}]"
def preprocess(
self, prompt: str, chat_mode: bool = True, for_benchmarks: bool = True
):
if chat_mode:
template = self.get_chat_template_with_instruction(
prompt=prompt, for_benchmarks=for_benchmarks
)
tokens_per_second = self.run_model(prompt, max_tokens)
self.results.append(tokens_per_second)
prompt = self.tokenizer.apply_chat_template(template, tokenize=False)

tokenized_input = self.tokenizer.encode(text=prompt)
return {
"prompt": prompt,
"input_tokens": tokenized_input,
"tensor": None,
"num_input_tokens": len(tokenized_input),
}

def run_model(self, inputs: dict, max_tokens: int, temperature: float) -> dict:
prompt = [inputs["prompt"]]

sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temperature)
output = self.model.generate(prompt, sampling_params)

generated_text = output[0].outputs[0].text
generated_tokens = output[0].outputs[0].token_ids

return {
"output_tokens": generated_tokens,
"num_output_tokens": len(generated_tokens),
"output_prompt": generated_text,
}

def postprocess(self, output: dict) -> str:
return output["output_prompt"]

def on_exit(self):
if self.device == "cuda":
parallel_state.destroy_model_parallel()
del self.model
@@ -74,67 +105,37 @@ def benchmark(


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="vllm Benchmark.")
parser.add_argument(
"--prompt",
type=str,
help="The prompt for the model.",
)
parser.add_argument("--max_tokens", type=int, help="The maximum number of tokens.")
parser.add_argument(
"--repetitions",
type=int,
help="The number of repetitions for the benchmark.",
)
parser.add_argument(
"--device",
help="Device to use for the benchmark.",
)
parser.add_argument(
"--log_file",
type=str,
help="Path to the log file for writing logs (in append mode).",
)
parser.add_argument(
"--models_dir",
type=str,
help="Path to the models directory.",
)
parser = launch_cli(description="vLLM Benchmark.")
args = parser.parse_args()
logging.info(
f"Running benchmark with: max_tokens={args.max_tokens} prompt={args.prompt} "
+ f"repetitions={args.repetitions} device={args.device}"
)
report = defaultdict(lambda: defaultdict(float))

for precision in ("fp32", "fp16", "int4"):
logging.info(f"Running VLLM benchmark on Llama on {precision} precision.")

llama_vllm_bench = LlamaVLLMBenchmark(
f"{args.models_dir}/llama-2-7b-hf"
if precision != "int4"
else f"{args.models_dir}/llama-2-7b-autoawq",
device=args.device,
precision=precision,
).load_model()

llama_vllm_bench.benchmark(
max_tokens=args.max_tokens, prompt=args.prompt, repetitions=args.repetitions
)

report["llama_vllm"][precision] = {
"mean": np.mean(llama_vllm_bench.results),
"std": np.std(llama_vllm_bench.results),
}
model_folder = os.path.join(os.getcwd(), "models")
model_name = (
f"{args.model_name}-2-7b-chat-"
if args.model_name == "llama"
else f"{args.model_name}-7b-v0.1-instruct-"
)

logging.info("Benchmark report")
with open(args.log_file, "a") as file:
for framework, quantizations in report.items():
for quantization, stats in quantizations.items():
logging.info(
f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}"
)
print(
f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}",
file=file,
)
runner_dict = {
"cuda": [
{
"precision": "float32",
"model_path": os.path.join(model_folder, model_name + "hf"),
},
{
"precision": "float16",
"model_path": os.path.join(model_folder, model_name + "hf"),
},
{
"precision": "int4",
"model_path": os.path.join(model_folder, model_name + "autoawq"),
},
]
}

make_report(
args=args,
benchmark_class=VLLMBenchmark,
runner_dict=runner_dict,
benchmark_name="vLLM",
is_bench_pytorch=False,
)
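
For quick reference outside the harness, here is a minimal standalone sketch of the generation flow the refactored class wires together: chat-template preprocessing with the HF tokenizer, then generation through SamplingParams. The model path, prompt, and passing dtype directly to the LLM constructor are illustrative assumptions, not code taken from this PR.

# Minimal sketch of the benchmark's generation path; model path and prompt are placeholders.
import time

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_path = "models/mistral-7b-v0.1-instruct-hf"  # hypothetical local checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Mirror preprocess(chat_mode=True): wrap the raw prompt in the model's chat template.
messages = [{"role": "user", "content": "Write an essay about the transformer model architecture"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)

# float16/float32 load; the int4 path instead constructs LLM(model=..., quantization="AWQ").
llm = LLM(model=model_path, dtype="float16")
sampling_params = SamplingParams(max_tokens=512, temperature=0.1)

start = time.time()
output = llm.generate([prompt], sampling_params)
elapsed = time.time() - start

completion = output[0].outputs[0]
print(completion.text)
print(f"{len(completion.token_ids) / elapsed:.2f} tokens/sec")
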
68 changes: 24 additions & 44 deletions bench_vllm/bench.sh
Original file line number Diff line number Diff line change
@@ -2,37 +2,35 @@

########################################################################################################
# Script: bench.sh
# Description: This script runs benchmarks VLLM Llama2 benchmark.
# Description: This script runs the vLLM benchmark.
#
# Usage: ./bench.sh [OPTIONS]
# OPTIONS:
# -p, --prompt Prompt for benchmarks (default: 'Write an essay about the transformer model architecture')
# -r, --repetitions Number of repetitions for benchmarks (default: 10)
# -m, --max_tokens Maximum number of tokens for benchmarks (default: 512)
# -d, --device Device for benchmarks (possible values: 'metal', 'cuda', and 'cpu', default: 'cuda')
# -lf, --log_file Logging file name.
# -md, --models_dir Models directory.
# -h, --help Show this help message
# -p, --prompt Prompt for benchmarks (default: 'Write an essay about the transformer model architecture')
# -r, --repetitions Number of repetitions for benchmarks (default: 10)
# -m, --max_tokens Maximum number of tokens for benchmarks (default: 512)
# -d, --device Device for benchmarks (possible values: 'metal', 'cuda', and 'cpu', default: 'cuda')
# -n, --model_name The name of the model to benchmark (possible values: 'llama' for using Llama2, 'mistral' for using Mistral 7B v0.1)
# -lf, --log_file Logging file name.
# -h, --help Show this help message
########################################################################################################

set -euo pipefail

CURRENT_DIR="$(pwd)"
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

print_usage() {
echo "Usage: $0 [OPTIONS]"
echo "OPTIONS:"
echo " -p, --prompt Prompt for benchmarks (default: 'Write an essay about the transformer model architecture')"
echo " -r, --repetitions Number of repetitions for benchmarks (default: 10)"
echo " -m, --max_tokens Maximum number of tokens for benchmarks (default: 512)"
echo " -d, --device Device for benchmarks (possible values: 'metal', 'cuda', and 'cpu', default: 'cuda')"
echo " -n, --model_name The name of the model to benchmark (possible values: 'llama' for using Llama2, 'mistral' for using Mistral 7B v0.1)"
echo " -lf, --log_file Logging file name."
echo " -md, --models_dir Models directory."
echo " -h, --help Show this help message"
exit 1
}

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

check_cuda() {
if command -v nvcc &> /dev/null
@@ -71,38 +69,26 @@ check_python() {


setup() {

# Check if Logs folder exists else Make the logs folder
LOGS_FOLDER="$CURRENT_DIR/Logs"

if [ -d "$LOGS_FOLDER" ]; then
echo "Folder '$LOGS_FOLDER' already exists. Skipping."
else
# Create the folder
mkdir "$LOGS_FOLDER"
echo "'$LOGS_FOLDER' created."
fi

local DEVICE="$1"
local MODEL_NAME="${2:-llama}"
echo -e "\nSetting up with $SCRIPT_DIR/setup.sh..."
bash "$SCRIPT_DIR"/setup.sh "$1"
bash "$SCRIPT_DIR/setup.sh" "$DEVICE" "$MODEL_NAME"
}

run_benchmarks() {
local PROMPT="$1"
local REPETITIONS="$2"
local MAX_TOKENS="$3"
local DEVICE="$4"
local LOG_FILENAME="$5"
local MODELS_DIR="$6"
local MODEL_NAME="$5"

# shellcheck disable=SC1091
source "$SCRIPT_DIR/venv/bin/activate"
"$PYTHON_CMD" "$SCRIPT_DIR"/bench.py \
--prompt "$PROMPT" \
--repetitions "$REPETITIONS" \
--max_tokens "$MAX_TOKENS" \
--log_file "$LOG_FILENAME" \
--models_dir "$MODELS_DIR" \
--model_name "$MODEL_NAME" \
--device "$DEVICE"
}

@@ -127,25 +113,20 @@ while [ "$#" -gt 0 ]; do
"cuda" | "metal" | "cpu")
;;
*)
echo "Invalid value for --device. Please use 'cuda', 'metal' or 'cpu'."
echo "Invalid value for --device. Please use 'cuda', 'cpu' or 'metal'."
print_usage
;;
esac
if [ "$DEVICE" == "metal" ] || [ "$DEVICE" == "cpu" ]; then
echo "$DEVICE not supported"
exit 1
fi
if [ "$DEVICE" == "cuda" ]; then
check_cuda
else
echo "Not supported for $DEVICE"
exit 1
fi
shift 2
;;
-lf|--log_file)
LOG_FILENAME="$2"
shift 2
;;
-md|--models_dir)
MODELS_DIR="$2"
-n|--model_name)
MODEL_NAME="$2"
shift 2
;;
-h|--help)
@@ -160,14 +141,13 @@ done

check_platform
check_python
setup "$DEVICE"

# Set default values if not provided
PROMPT="${PROMPT:-"Write an essay about the transformer model architecture"}"
REPETITIONS="${REPETITIONS:-10}"
MAX_TOKENS="${MAX_TOKENS:-512}"
DEVICE="${DEVICE:-cuda}"
LOG_FILENAME="${LOG_FILENAME:-"$LOGS_FOLDER/benchmark_vllm_$(date +'%Y%m%d%H%M%S').log"}"
MODELS_DIR="${MODELS_DIR:-"./models"}"
MODEL_NAME="${MODEL_NAME:-"llama"}"

run_benchmarks "$PROMPT" "$REPETITIONS" "$MAX_TOKENS" "$DEVICE" "$LOG_FILENAME" "$MODELS_DIR"
setup "$DEVICE" "$MODEL_NAME"
run_benchmarks "$PROMPT" "$REPETITIONS" "$MAX_TOKENS" "$DEVICE" "$MODEL_NAME"