Add Llama2 inference benchmark under a new "benchmarks" section (#435)
* chore: added text-generation benchmark scripts
* doc: added llama2 benchmark
* Apply suggestions from code review

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
1 parent 1303aa4 · commit 2709183 · Showing 9 changed files with 317 additions and 1 deletion.
@@ -0,0 +1,12 @@

# Usage

```shell
python llama2.py
```

This will produce several JSON files.

```shell
python gen_barcharts.py <JSON files>
```

This will create three bar chart images for encoding times, latency and throughput.
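
For example, with the default configurations in `llama2.py`, the JSON files follow the `<model_name>.json` pattern and can be passed explicitly. This is a sketch; the exact file names depend on which configurations actually ran:

```shell
python llama2.py
python gen_barcharts.py Llama-2-7BL.json Llama-2-7BT.json Llama-2-13BL.json Llama-2-13BT.json
```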
@@ -0,0 +1,92 @@

```python
import argparse
import json
import os
import time

import torch
from transformers import AutoConfig, AutoTokenizer, set_seed

from optimum.neuron import NeuronModelForCausalLM


def generate(model, input_ids, length):
    start = time.time()
    with torch.inference_mode():
        output_tokens = model.generate(input_ids, do_sample=False, min_length=length, max_length=length)
    end = time.time()
    return output_tokens, (end - start)


def run(model_id, inc_length, max_length, json_path=None):
    prompts = ["One of my fondest memory"]
    config = AutoConfig.from_pretrained(model_id)
    batch_size = config.neuron["batch_size"]
    if len(prompts) < batch_size:
        prompts = prompts + [prompts[-1]] * (batch_size - len(prompts))
    model = NeuronModelForCausalLM.from_pretrained(model_id, export=False, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Specify padding options for decoder-only architecture
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    # Encode the first input tokens
    tokens = tokenizer(prompts, return_tensors="pt", padding=True)
    bootstrap_input_ids = tokens.input_ids
    # Generate the first set of inputs
    input_ids, latency = generate(model, bootstrap_input_ids, inc_length)
    input_length = input_ids.size()[-1]
    neuron_config = getattr(model.config, "neuron")
    benchmark = {"neuron_config": neuron_config, "results": []}
    while input_length < max_length:
        # Generate a single token, just to evaluate the context encoding time
        _, encoding_time = generate(model, input_ids, input_length + 1)
        result = {
            "input_length": input_length,
            "batch_size": batch_size,
            "encoding_time": encoding_time,
            "generations": [],
        }
        for sequence_length in range(input_length + inc_length, max_length + 1, inc_length):
            output_ids, latency = generate(model, input_ids, sequence_length)
            throughput = batch_size * sequence_length / latency
            result["generations"].append(
                {
                    "sequence_length": sequence_length,
                    "new_tokens": sequence_length - input_length,
                    "latency": latency,
                    "generation_time": latency - encoding_time,
                    "throughput": throughput,
                }
            )
        # Reuse the first generated tokens for the next step
        input_length += inc_length
        input_ids = output_ids[:, :input_length]
        benchmark["results"].append(result)
    if json_path is not None:
        with open(json_path, "w") as fp:
            json.dump(benchmark, fp, indent=4)
    return benchmark


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model", type=str, help="A neuron model in a local directory.")
    parser.add_argument("--inc-length", type=int, default=128, help="The number of tokens in each increment.")
    parser.add_argument("--max-length", type=int, default=2048, help="The maximum number of generated tokens.")
    parser.add_argument("--seed", type=int, default=None, help="Pass a seed for reproducibility.")
    args = parser.parse_args()
    if args.seed is not None:
        set_seed(args.seed)
    model_name = os.path.basename(os.path.normpath(args.model))
    benchmark = run(args.model, args.inc_length, args.max_length, json_path=f"{model_name}.json")
    # Dump encoding times
    results = benchmark["results"]
    print(f"{benchmark['neuron_config']}")
    print("Encoding times")
    print([result["input_length"] for result in results])
    print([f"{result['encoding_time']:.2f}" for result in results])
    # Just look at the first set of generations
    generations = results[0]["generations"]
    print(f"Latency and throughput for {args.inc_length} input tokens")
    print([generation["new_tokens"] for generation in generations])
    print([f"{generation['latency']:.2f}" for generation in generations])
    print([f"{generation['throughput']:.2f}" for generation in generations])
```
@@ -0,0 +1,98 @@

```python
import argparse
import glob
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def save_bar_chart(title, labels, ylabel, series, save_path):
    x = np.arange(len(labels))  # the label locations
    width = 0.15  # the width of the bars
    multiplier = 0

    fig, ax = plt.subplots(layout="constrained")
    fig.set_figwidth(10)

    max_value = 0

    for attribute, measurement in series.items():
        max_value = max(max_value, max(measurement))
        offset = width * multiplier
        rects = ax.bar(x + offset, measurement, width, label=attribute)
        ax.bar_label(rects, padding=5)
        multiplier += 1

    # Add some text for labels, title and custom x-axis tick labels, etc.
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x + width, labels)
    ax.legend(loc="upper left", ncols=3)
    ax.set_ylim(0, max_value * 1.2)

    plt.savefig(save_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("inputs", type=str, nargs="*", help="A list of benchmark results files (.json).")
    args = parser.parse_args()
    inputs = args.inputs
    if len(inputs) == 0:
        inputs = glob.glob("*.json")
    benchmarks = {}
    for input in inputs:
        model_name = Path(input).stem
        with open(input) as f:
            benchmarks[model_name] = json.load(f)
    model_names = benchmarks.keys()
    # Generate encoding barchart
    input_length = []
    encoding_times = {}
    for name in model_names:
        results = benchmarks[name]["results"]
        cur_input_length = [result["input_length"] for result in results]
        if len(input_length) == 0:
            input_length = cur_input_length
        else:
            assert cur_input_length == input_length, f"{name} does not have the same number of results"
        encoding_times[name] = [round(result["encoding_time"], 1) for result in results]
    save_bar_chart(
        title="Encoding time per input token",
        labels=input_length,
        series=encoding_times,
        ylabel="Encoding time (s)",
        save_path="encoding_times.png",
    )
    # Generate latency and throughput barcharts (for the first input length only)
    new_tokens = []
    latencies = {}
    throughputs = {}
    for name in model_names:
        generations = benchmarks[name]["results"][0]["generations"]
        cur_new_tokens = [generation["new_tokens"] for generation in generations]
        if len(new_tokens) == 0:
            new_tokens = cur_new_tokens
        else:
            assert cur_new_tokens == new_tokens, f"{name} does not have the same number of results"
        latencies[name] = [round(generation["latency"], 1) for generation in generations]
        throughputs[name] = [round(generation["throughput"], 0) for generation in generations]
    save_bar_chart(
        title="End-to-end latency per generated tokens for 256 input tokens",
        labels=new_tokens,
        series=latencies,
        ylabel="Latency (s)",
        save_path="latencies.png",
    )
    save_bar_chart(
        title="Throughput per generated tokens for 256 input tokens",
        labels=new_tokens,
        series=throughputs,
        ylabel="Throughput (tokens/s)",
        save_path="throughputs.png",
    )


if __name__ == "__main__":
    main()
```
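
As a quick way to exercise `save_bar_chart` in isolation, here is a minimal sketch with illustrative (not measured) values; it assumes the function above is in scope:

```python
# Sketch: render a bar chart from hand-written values (illustrative only).
toy_series = {
    "Llama-2-7BL": [0.5, 0.9, 1.4],  # hypothetical encoding times in seconds
    "Llama-2-7BT": [0.6, 1.1, 1.7],
}
save_bar_chart(
    title="Encoding time per input token",
    labels=[256, 512, 768],
    ylabel="Encoding time (s)",
    series=toy_series,
    save_path="encoding_times_demo.png",
)
```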
@@ -0,0 +1,33 @@

```python
import os
from tempfile import TemporaryDirectory

from transformers import AutoTokenizer

from benchmark import run
from optimum.neuron import NeuronModelForCausalLM


# Each entry maps a benchmark name to [model id, batch size, sequence length].
model_configurations = {
    "Llama-2-7BL": ["meta-llama/Llama-2-7b-chat-hf", 1, 2048],
    "Llama-2-7BT": ["meta-llama/Llama-2-7b-chat-hf", 4, 2048],
}

# Each Neuron device exposes two NeuronCores: only add the 13B configurations
# when at least four cores are available.
num_cores = len(os.listdir("/sys/class/neuron_device/")) * 2
if num_cores >= 4:
    extra_model_configurations = {
        "Llama-2-13BL": ["meta-llama/Llama-2-13b-chat-hf", 1, 2048],
        "Llama-2-13BT": ["meta-llama/Llama-2-13b-chat-hf", 4, 2048],
    }
    model_configurations = {**model_configurations, **extra_model_configurations}

for model_name, model_configuration in model_configurations.items():
    model_id, batch_size, seq_length = model_configuration
    # Export the model to Neuron with the requested static shapes.
    model = NeuronModelForCausalLM.from_pretrained(
        model_id, export=True, batch_size=batch_size, sequence_length=seq_length, auto_cast_type="fp16"
    )
    with TemporaryDirectory() as tmpdir:
        model.save_pretrained(tmpdir)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.save_pretrained(tmpdir)
        json_path = f"{model_name}.json"
        # Benchmark in 256-token increments up to 1024 total tokens.
        run(tmpdir, 256, 1024, json_path=json_path)
```
@@ -0,0 +1,77 @@

<!---
Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Llama performance on AWS Inferentia2 (Latency & Throughput)

How fast is Llama on Inferentia2? Let's figure it out!

For this benchmark we will use the Llama 2 7B and 13B models with different configurations:

| Model type                  | num cores | batch_size |
|-----------------------------|-----------|------------|
| Llama2 7B - L (latency)     | 24        | 1          |
| Llama2 7B - T (throughput)  | 24        | 4          |
| Llama2 13B - L (latency)    | 24        | 1          |
| Llama2 13B - T (throughput) | 24        | 4          |

*Note: all models are compiled with a maximum sequence length of 2048.*

All models are compiled to use the full extent of cores available on the `inf2.48xlarge` instance.

*Note: please refer to the [inferentia2 product page](https://aws.amazon.com/ec2/instance-types/inf2/) for details on the available instances.*

We created two "latency"-oriented configurations for the `llama2 7B` and `llama2 13B` models that can serve only one request at a time, but at full speed, and two "throughput"-oriented configurations that serve up to four requests in parallel.

To evaluate the models, we generate tokens up to a total sequence length of 1024, starting from 256 input tokens (i.e. we generate 256, 512 and 768 new tokens).

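Concretely, this corresponds to calling `run(model, 256, 1024, ...)` in the accompanying scripts; a minimal sketch of the resulting sweep for the first input length (mirroring the loop in `benchmark.py`):

```python
# Sketch: sequence lengths evaluated for inc_length=256 and max_length=1024,
# starting from a 256-token input.
inc_length, max_length = 256, 1024
input_length = 256
for sequence_length in range(input_length + inc_length, max_length + 1, inc_length):
    print(sequence_length, sequence_length - input_length)
# -> (512, 256), (768, 512), (1024, 768): 256, 512 and 768 generated tokens
```
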
## Encoding time (time to first token)

The encoding time, or time to first token, is the time required to process the input tokens and generate the first output token.
It is a very important metric, as it corresponds to the latency directly perceived by the user when streaming generated tokens.

We test the encoding time for increasing context sizes: 256 input tokens correspond roughly to a typical Q/A usage, while 768 is more typical of a Retrieval Augmented Generation (RAG) use case.

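In the accompanying `benchmark.py` script, this metric is approximated by timing the generation of a single extra token from the current input; the relevant call is sketched below:

```python
# Time to first token: ask the model for exactly one new token and time the call.
_, encoding_time = generate(model, input_ids, input_length + 1)
```
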
Encoding time is expressed in **seconds**.

![Llama2 inferentia2 encoding-time](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/encoding-times.png "Encoding time")

We can see that all deployed models exhibit excellent response times, even for long contexts.

## End-to-end Latency

The end-to-end latency corresponds to the total time to reach a sequence length of 1024 tokens.
It therefore includes the encoding and generation time.

Latency is expressed in **seconds**.

![Llama2 inferentia2 end-to-end latency](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/latencies.png "Latency")

All models deployed on the high-end instance exhibit good latency, even those configured to optimize throughput.

## Throughput

We adopt the same convention as other benchmarks to evaluate throughput: we divide the total number of tokens (both input and output) by the end-to-end latency.
In other words, we divide `batch_size * sequence_length` by the end-to-end latency to obtain the number of tokens per second.

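For instance, with purely illustrative numbers (not measured values) for a throughput-oriented configuration:

```python
# Hypothetical worked example: 4 sequences of 1024 tokens produced in 16 s end-to-end.
batch_size = 4
sequence_length = 1024
end_to_end_latency = 16.0  # seconds (illustrative, not a measured value)
throughput = batch_size * sequence_length / end_to_end_latency
print(throughput)  # 256.0 tokens/second
```
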
Throughput is expressed in **tokens/second**.

![Llama2 inferentia2 throughput](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/throughputs.png "Throughput")

Again, the models deployed on the high-end instance achieve very good throughput, even those optimized for latency.