Commit

Merge branch 'dev' into trtllm-mistral
Anindyadeep authored Apr 22, 2024
2 parents cb6ea71 + 2033b74 commit deae4c1
Showing 12 changed files with 440 additions and 497 deletions.
48 changes: 38 additions & 10 deletions bench_exllamav2/README.md

Large diffs are not rendered by default.

210 changes: 121 additions & 89 deletions bench_exllamav2/bench.py
@@ -1,115 +1,147 @@
-import argparse
 import logging
+import os
 import sys
-import time
-from collections import defaultdict

-import numpy as np
 import torch
 from exllamav2 import ExLlamaV2, ExLlamaV2Cache
 from exllamav2.config import ExLlamaV2Config
 from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
 from exllamav2.tokenizer.tokenizer import ExLlamaV2Tokenizer
+from transformers import AutoTokenizer

 logging.getLogger("llama_cpp").setLevel(logging.ERROR)
 logging.basicConfig(
     stream=sys.stdout,
     level=logging.INFO,
     format="%(asctime)s - %(levelname)s - %(message)s",
 )
+sys.path.append(os.getcwd())
+
+from common.base import BaseBenchmarkClass  # noqa
+from common.utils import launch_cli, make_report  # noqa

-class ExllamaV2Benchmark:
-    def __init__(self, model_path: str) -> None:
-        self.model_path, self.results = model_path, []
-
-    def load_model(self):
+class ExLlamaV2Benchmark(BaseBenchmarkClass):
+    def __init__(
+        self,
+        model_path: str,
+        model_name: str,
+        benchmark_name: str,
+        precision: str,
+        device: str,
+        experiment_name: str,
+    ) -> None:
+        assert precision in ["int8", "int4"], ValueError(
+            "Available precision: 'int8', 'int4'"
+        )
+        super().__init__(
+            model_name=model_name,
+            model_path=model_path,
+            benchmark_name=benchmark_name,
+            experiment_name=experiment_name,
+            precision=precision,
+            device=device,
+        )
+
+    def load_model_and_tokenizer(self):
+        # set up model config
         self.config = ExLlamaV2Config()
         self.config.model_dir = self.model_path
         self.config.prepare()

-        self.model = ExLlamaV2(self.config)
-        self.cache = ExLlamaV2Cache(self.model, lazy=True)
-        self.model.load_autosplit(self.cache)
-        self.tokenizer = ExLlamaV2Tokenizer(self.config)
+        # set up model and cache
+        self._model = ExLlamaV2(self.config)
+        self.cache = ExLlamaV2Cache(self._model, lazy=True)
+        self._model.load_autosplit(self.cache)
+        self.tokenizer_exllama = ExLlamaV2Tokenizer(self.config)
+        self.model = ExLlamaV2BaseGenerator(
+            self._model, self.cache, self.tokenizer_exllama
+        )
+        self.model.warmup()

-        self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)
+        # set up the huggingface tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

+        # set up exllamav2 settings
         self.settings = ExLlamaV2Sampler.Settings()
         self.settings.temperature = 0.85
         self.settings.top_k = 50
         self.settings.top_p = 0.8
         self.settings.token_repetition_penalty = 1.05
-        self.settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
-        self.generator.warmup()
+        self.settings.disallow_tokens(
+            self.tokenizer_exllama, [self.tokenizer_exllama.eos_token_id]
+        )
         return self

-    @torch.inference_mode()
-    def run_model(self, prompt: str, max_tokens: int) -> float:
-        start = time.time()
-        _ = self.generator.generate_simple(prompt, self.settings, max_tokens, seed=1234)
-        delta = time.time() - start
-        return len(self.generator.sequence_ids[0]) / delta

-    def benchmark(self, prompt: str, max_tokens: int, repetitions: int) -> None:
-        for i in range(repetitions):
-            logging.info(
-                f"Running repetition [{str(i+1).zfill(len(str(repetitions)))}/{repetitions}]"
+    def preprocess(
+        self, prompt: str, chat_mode: bool = True, for_benchmarks: bool = True
+    ):
+        if chat_mode:
+            template = self.get_chat_template_with_instruction(
+                prompt=prompt, for_benchmarks=for_benchmarks
             )
-            tokens_per_second = self.run_model(prompt, max_tokens)
-            self.results.append(tokens_per_second)
+            prompt = self.tokenizer.apply_chat_template(template, tokenize=False)
+        tokenized_input = self.tokenizer.encode(text=prompt)
+        return {
+            "prompt": prompt,
+            "input_tokens": tokenized_input,
+            "tensor": None,
+            "num_input_tokens": len(tokenized_input),
+        }
+
+    def run_model(self, inputs: dict, max_tokens: int, temperature: float) -> dict:
+        # first set up the settings
+        self.settings.token_repetition_penalty = 1.01
+        self.settings.temperature = temperature
+        self.settings.top_k = 50
+        self.settings.top_p = 0.1
+
+        # now run the model
+        prompt = inputs["prompt"]
+        output_text = self.model.generate_simple(
+            prompt,
+            self.settings,
+            max_tokens,
+            seed=1234,
+            completion_only=True,
+            decode_special_tokens=True,
+        )
+
+        tokenized_output = self.tokenizer.encode(output_text)
+        return {
+            "output_text": output_text,
+            "output_tokens": tokenized_output,
+            "num_output_tokens": len(tokenized_output),
+        }

+    def postprocess(self, output: dict) -> str:
+        return output["output_text"]

+    def on_exit(self):
+        if self.device == "cuda":
+            del self.model
+            torch.cuda.synchronize()
+        else:
+            del self.model


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="llama.cpp Benchmark Llama model.")
-    parser.add_argument(
-        "--prompt",
-        type=str,
-        help="The prompt for the model.",
-    )
-    parser.add_argument("--max_tokens", type=int, help="The maximum number of tokens.")
-    parser.add_argument(
-        "--repetitions",
-        type=int,
-        help="The number of repetitions for the benchmark.",
-    )
-    parser.add_argument(
-        "--log_file",
-        type=str,
-        help="Path to the log file for writing logs (in append mode).",
-    )
-    parser.add_argument(
-        "--models_dir",
-        type=str,
-        help="Path to the models directory.",
-    )
+    parser = launch_cli(description="ExLlamaV2 Benchmark.")
     args = parser.parse_args()
-    logging.info(
-        f"Running benchmark with: max_tokens={args.max_tokens} prompt={args.prompt} "
-        + f"repetitions={args.repetitions} device=cuda"
+
+    model_folder = os.path.join(os.getcwd(), "models")
+    model_name = (
+        f"{args.model_name}-2-7b-chat-exllamav2-"
+        if args.model_name == "llama"
+        else f"{args.model_name}-7b-v0.1-instruct-exllamav2-"
     )
-    report = defaultdict(lambda: defaultdict(float))
-    for quantize in ("q8", "q4"):
-        logging.info(f"Running ExllamaV2 benchmark with {quantize}")
-        llamacpp_bench = ExllamaV2Benchmark(
-            f"{args.models_dir}/llama-2-7b-exllamav2-{quantize}"
-        ).load_model()
-        llamacpp_bench.benchmark(
-            max_tokens=args.max_tokens, prompt=args.prompt, repetitions=args.repetitions
-        )
-        q = "int8" if quantize == "q8" else "int4"
-        report["exllamav2"][q] = {
-            "mean": np.mean(llamacpp_bench.results),
-            "std": np.std(llamacpp_bench.results),
-        }

-    logging.info("Benchmark report")
-    with open(args.log_file, "a") as file:
-        for framework, quantizations in report.items():
-            for quantization, stats in quantizations.items():
-                logging.info(
-                    f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}"
-                )
-                print(
-                    f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}",
-                    file=file,
-                )
+    runner_dict = {
+        "cuda": [
+            {
+                "precision": "int4",
+                "model_path": os.path.join(model_folder, model_name + "4.0-bit"),
+            },
+            {
+                "precision": "int8",
+                "model_path": os.path.join(model_folder, model_name + "8.0-bit"),
+            },
+        ]
+    }
+
+    make_report(
+        args=args,
+        benchmark_class=ExLlamaV2Benchmark,
+        runner_dict=runner_dict,
+        benchmark_name="ExLlamaV2",
+        is_bench_pytorch=False,
+    )
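
To make the new control flow easier to follow, here is a rough sketch of how the refactored class could be exercised on its own, outside the launch_cli/make_report harness (which lives in common/ and is not part of this diff). The model path, experiment name, and the idea of importing and driving the class directly are illustrative assumptions, not something the commit itself does.

# Hypothetical standalone driver for the refactored ExLlamaV2Benchmark class.
# Assumes it is run from the repository root (so the `common` package resolves)
# and that a quantized model already exists at the placeholder path below.
from bench_exllamav2.bench import ExLlamaV2Benchmark

bench = ExLlamaV2Benchmark(
    model_path="models/llama-2-7b-chat-exllamav2-4.0-bit",  # placeholder path
    model_name="llama",
    benchmark_name="ExLlamaV2",
    precision="int4",  # the class only accepts "int4" or "int8"
    device="cuda",
    experiment_name="exllamav2-smoke-test",  # illustrative name
)
bench.load_model_and_tokenizer()

# preprocess -> run_model -> postprocess is the per-prompt flow the harness drives
inputs = bench.preprocess(prompt="Write an essay about the transformer model architecture")
outputs = bench.run_model(inputs, max_tokens=128, temperature=0.1)
print(outputs["num_output_tokens"], "output tokens")
print(bench.postprocess(outputs))
bench.on_exit()  # releases the model (and synchronizes CUDA on GPU)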
61 changes: 21 additions & 40 deletions bench_exllamav2/bench.sh
@@ -2,37 +2,36 @@

 ########################################################################################################
 # Script: bench.sh
-# Description: This script runs benchmarks Exllamav2 Llama-2 benchmark.
+# Description: This script runs benchmarks ExLlamaV2 benchmark.
 #
 # Usage: ./bench.sh [OPTIONS]
 # OPTIONS:
-#   -p, --prompt        Prompt for benchmarks (default: 'Write an essay about the transformer model architecture')
-#   -r, --repetitions   Number of repetitions for benchmarks (default: 10)
-#   -m, --max_tokens    Maximum number of tokens for benchmarks (default: 512)
-#   -d, --device        Device for benchmarks (possible values: 'metal', 'cuda', and 'cpu', default: 'cuda')
-#   -lf, --log_file     Logging file name.
-#   -md, --models_dir   Models directory.
-#   -h, --help          Show this help message
+#   -p, --prompt        Prompt for benchmarks (default: 'Write an essay about the transformer model architecture')
+#   -r, --repetitions   Number of repetitions for benchmarks (default: 10)
+#   -m, --max_tokens    Maximum number of tokens for benchmarks (default: 512)
+#   -d, --device        Device for benchmarks (possible values: 'metal', 'cuda', and 'cpu', default: 'cuda')
+#   -n, --model_name    The name of the model to benchmark (possible values: 'llama' for using Llama2, 'mistral' for using Mistral 7B v0.1)
+#   -lf, --log_file     Logging file name.
+#   -h, --help          Show this help message
 ########################################################################################################

 set -euo pipefail

-CURRENT_DIR="$(pwd)"
-SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

 print_usage() {
     echo "Usage: $0 [OPTIONS]"
     echo "OPTIONS:"
     echo "  -p, --prompt        Prompt for benchmarks (default: 'Write an essay about the transformer model architecture')"
     echo "  -r, --repetitions   Number of repetitions for benchmarks (default: 10)"
     echo "  -m, --max_tokens    Maximum number of tokens for benchmarks (default: 512)"
     echo "  -d, --device        Device for benchmarks (possible values: 'metal', 'cuda', and 'cpu', default: 'cuda')"
+    echo "  -n, --model_name    The name of the model to benchmark (possible values: 'llama' for using Llama2, 'mistral' for using Mistral 7B v0.1)"
     echo "  -lf, --log_file     Logging file name."
-    echo "  -md, --models_dir   Models directory."
     echo "  -h, --help          Show this help message"
     exit 1
 }

+SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

 check_cuda() {
     if command -v nvcc &> /dev/null
     then
@@ -69,42 +68,29 @@ check_python() {
 }

 setup() {
-
-    # Check if Logs folder exists else Make the logs folder
-    LOGS_FOLDER="$CURRENT_DIR/Logs"
-
-    if [ -d "$LOGS_FOLDER" ]; then
-        echo "Folder '$LOGS_FOLDER' already exists. Skipping."
-    else
-        # Create the folder
-        mkdir "$LOGS_FOLDER"
-        echo "'$LOGS_FOLDER' created."
-    fi
-
+    local MODEL_NAME="${1:-llama}"
     echo -e "\nSetting up with $SCRIPT_DIR/setup.sh..."
-    bash "$SCRIPT_DIR"/setup.sh
+    bash "$SCRIPT_DIR/setup.sh" "$MODEL_NAME"
 }

 run_benchmarks() {
     local PROMPT="$1"
     local REPETITIONS="$2"
     local MAX_TOKENS="$3"
     local DEVICE="$4"
-    local LOG_FILENAME="$5"
-    local MODELS_DIR="$6"
+    local MODEL_NAME="$5"

     # shellcheck disable=SC1091
     source "$SCRIPT_DIR/venv/bin/activate"
     "$PYTHON_CMD" "$SCRIPT_DIR"/bench.py \
         --prompt "$PROMPT" \
         --repetitions "$REPETITIONS" \
         --max_tokens "$MAX_TOKENS" \
-        --log_file "$LOG_FILENAME" \
-        --models_dir "$MODELS_DIR"
+        --model_name "$MODEL_NAME" \
+        --device "$DEVICE"
 }


 # Parse command-line arguments
 while [ "$#" -gt 0 ]; do
     case "$1" in
         -p|--prompt)
@@ -137,12 +123,8 @@ while [ "$#" -gt 0 ]; do
             fi
             shift 2
             ;;
-        -lf|--log_file)
-            LOG_FILENAME="$2"
-            shift 2
-            ;;
-        -md|--models_dir)
-            MODELS_DIR="$2"
+        -n|--model_name)
+            MODEL_NAME="$2"
             shift 2
             ;;
         -h|--help)
@@ -157,14 +139,13 @@ done

 check_platform
 check_python
-setup
+setup "$MODEL_NAME"

 # Set default values if not provided
 PROMPT="${PROMPT:-"Write an essay about the transformer model architecture"}"
 REPETITIONS="${REPETITIONS:-10}"
 MAX_TOKENS="${MAX_TOKENS:-512}"
 DEVICE="${DEVICE:-'cuda'}"
-LOG_FILENAME="${LOG_FILENAME:-"$LOGS_FOLDER/benchmark_exllamav2_$(date +'%Y%m%d%H%M%S').log"}"
-MODELS_DIR="${MODELS_DIR:-"./models"}"
+MODEL_NAME="${MODEL_NAME:-"llama"}"

-run_benchmarks "$PROMPT" "$REPETITIONS" "$MAX_TOKENS" "$DEVICE" "$LOG_FILENAME" "$MODELS_DIR"
+run_benchmarks "$PROMPT" "$REPETITIONS" "$MAX_TOKENS" "$DEVICE" "$MODEL_NAME"
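
With these changes the script is driven by a model selector instead of a models directory; based on the options documented above, an invocation would look like ./bench.sh -n mistral -r 5 -m 256 -d cuda (flag values here are only an illustration; the defaults fall back to the Llama 2 model, 10 repetitions, 512 tokens, and CUDA).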