Add Llama2 inference benchmark under a new "benchmarks" section (#435)
* chore: added text-generation benchmark scripts

* doc: added llama2 benchmark

* Apply suggestions from code review

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>

---------

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
dacorvo and philschmid authored Jan 24, 2024
1 parent 1303aa4 commit 2709183
Showing 9 changed files with 317 additions and 1 deletion.
12 changes: 12 additions & 0 deletions benchmark/text-generation/README.md
@@ -0,0 +1,12 @@
# Usage

```shell
python llama2.py
```
This will produce one JSON file per benchmarked model configuration.

```shell
python gen_barcharts.py <JSON files>
```

This will create three bar chart images for encoding times, latencies, and throughputs.
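
For example, assuming `llama2.py` was run on an instance with at least 4 Neuron cores (so that both the 7B and 13B configurations were benchmarked), the resulting files can be passed explicitly; the file names follow the model keys defined in `llama2.py`:

```shell
python gen_barcharts.py Llama-2-7BL.json Llama-2-7BT.json Llama-2-13BL.json Llama-2-13BT.json
```

If no files are given, `gen_barcharts.py` falls back to every `*.json` file in the current directory.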
92 changes: 92 additions & 0 deletions benchmark/text-generation/benchmark.py
@@ -0,0 +1,92 @@
import argparse
import json
import os
import time

import torch
from transformers import AutoConfig, AutoTokenizer, set_seed

from optimum.neuron import NeuronModelForCausalLM


def generate(model, input_ids, length):
start = time.time()
with torch.inference_mode():
output_tokens = model.generate(input_ids, do_sample=False, min_length=length, max_length=length)
end = time.time()
return output_tokens, (end - start)


def run(model_id, inc_length, max_length, json_path=None):
prompts = ["One of my fondest memory"]
config = AutoConfig.from_pretrained(model_id)
batch_size = config.neuron["batch_size"]
if len(prompts) < batch_size:
prompts = prompts + [prompts[-1]] * (batch_size - len(prompts))
model = NeuronModelForCausalLM.from_pretrained(model_id, export=False, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Specify padding options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
# Encode the first input tokens
tokens = tokenizer(prompts, return_tensors="pt", padding=True)
bootstrap_input_ids = tokens.input_ids
# Generate the first set of inputs
input_ids, latency = generate(model, bootstrap_input_ids, inc_length)
input_length = input_ids.size()[-1]
neuron_config = getattr(model.config, "neuron")
benchmark = {"neuron_config": neuron_config, "results": []}
while input_length < max_length:
# Generate a single input, just to evaluate the context encoding time
_, encoding_time = generate(model, input_ids, input_length + 1)
result = {
"input_length": input_length,
"batch_size": batch_size,
"encoding_time": encoding_time,
"generations": [],
}
for sequence_length in range(input_length + inc_length, max_length + 1, inc_length):
output_ids, latency = generate(model, input_ids, sequence_length)
throughput = batch_size * sequence_length / latency
result["generations"].append(
{
"sequence_length": sequence_length,
"new_tokens": sequence_length - input_length,
"latency": latency,
"generation_time": latency - encoding_time,
"throughput": throughput,
}
)
# Reuse the first generated tokens for the next step
input_length += inc_length
input_ids = output_ids[:, :input_length]
benchmark["results"].append(result)
if json_path is not None:
with open(json_path, "w") as fp:
json.dump(benchmark, fp, indent=4)
return benchmark


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("model", type=str, help="A neuron model in a local directory.")
parser.add_argument("--inc-length", type=int, default=128, help="The number of tokens in each increment.")
parser.add_argument("--max-length", type=int, default=2048, help="The maximum number of generated tokens.")
parser.add_argument("--seed", type=int, default=None, help="Pass a seed for reproducibility.")
args = parser.parse_args()
if args.seed is not None:
set_seed(args.seed)
model_name = os.path.basename(os.path.normpath(args.model))
benchmark = run(args.model, args.inc_length, args.max_length, json_path=f"{model_name}.json")
# Dump encoding times
results = benchmark["results"]
print(f"{benchmark['neuron_config']}")
print("Encoding times")
print([result["input_length"] for result in results])
print([f"{result['encoding_time']:.2f}" for result in results])
# Just look at the first set of generations
generations = results[0]["generations"]
print(f"Latency and throughput for {args.inc_length} input tokens")
print([generation["new_tokens"] for generation in generations])
print([f"{generation['latency']:.2f}" for generation in generations])
print([f"{generation['throughput']:.2f}" for generation in generations])
98 changes: 98 additions & 0 deletions benchmark/text-generation/gen_barcharts.py
@@ -0,0 +1,98 @@
import argparse
import glob
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def save_bar_chart(title, labels, ylabel, series, save_path):
x = np.arange(len(labels)) # the label locations
width = 0.15 # the width of the bars
multiplier = 0

fig, ax = plt.subplots(layout="constrained")
fig.set_figwidth(10)

max_value = 0

for attribute, measurement in series.items():
max_value = max(max_value, max(measurement))
offset = width * multiplier
rects = ax.bar(x + offset, measurement, width, label=attribute)
ax.bar_label(rects, padding=5)
multiplier += 1

# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel(ylabel)
ax.set_title(title)
ax.set_xticks(x + width, labels)
ax.legend(loc="upper left", ncols=3)
ax.set_ylim(0, max_value * 1.2)

plt.savefig(save_path)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("inputs", type=str, nargs="*", help="A list of benchmark results files (.json).")
args = parser.parse_args()
inputs = args.inputs
if len(inputs) == 0:
inputs = glob.glob("*.json")
benchmarks = {}
for input in inputs:
model_name = Path(input).stem
with open(input) as f:
benchmarks[model_name] = json.load(f)
model_names = benchmarks.keys()
# Generate encoding barchart
input_length = []
encoding_times = {}
for name in model_names:
results = benchmarks[name]["results"]
cur_input_length = [result["input_length"] for result in results]
if len(input_length) == 0:
input_length = cur_input_length
else:
assert cur_input_length == input_length, f"{name} does not have the same number of results"
encoding_times[name] = [round(result["encoding_time"], 1) for result in results]
save_bar_chart(
title="Encoding time per input token",
labels=input_length,
series=encoding_times,
ylabel="Encoding time (s)",
save_path="encoding_times.png",
)
# Generate latency and throughput barcharts (for the first input length only)
new_tokens = []
latencies = {}
throughputs = {}
for name in model_names:
generations = benchmarks[name]["results"][0]["generations"]
cur_new_tokens = [generation["new_tokens"] for generation in generations]
if len(new_tokens) == 0:
new_tokens = cur_new_tokens
else:
assert cur_new_tokens == new_tokens, f"{name} does not have the same number of results"
latencies[name] = [round(generation["latency"], 1) for generation in generations]
throughputs[name] = [round(generation["throughput"], 0) for generation in generations]
save_bar_chart(
title="End-to-end latency per generated tokens for 256 input tokens",
labels=new_tokens,
series=latencies,
ylabel="Latency (s)",
save_path="latencies.png",
)
save_bar_chart(
title="Throughput per generated tokens for 256 input tokens",
labels=new_tokens,
series=throughputs,
ylabel="Throughput (tokens/s)",
save_path="throughputs.png",
)


if __name__ == "__main__":
main()
33 changes: 33 additions & 0 deletions benchmark/text-generation/llama2.py
@@ -0,0 +1,33 @@
import os
from tempfile import TemporaryDirectory

from transformers import AutoTokenizer

from benchmark import run
from optimum.neuron import NeuronModelForCausalLM


model_configurations = {
"Llama-2-7BL": ["meta-llama/Llama-2-7b-chat-hf", 1, 2048],
"Llama-2-7BT": ["meta-llama/Llama-2-7b-chat-hf", 4, 2048],
}

num_cores = len(os.listdir("/sys/class/neuron_device/")) * 2
if num_cores >= 4:
extra_model_configurations = {
"Llama-2-13BL": ["meta-llama/Llama-2-13b-chat-hf", 1, 2048],
"Llama-2-13BT": ["meta-llama/Llama-2-13b-chat-hf", 4, 2048],
}
model_configurations = {**model_configurations, **extra_model_configurations}

for model_name, model_configuration in model_configurations.items():
model_id, batch_size, seq_length = model_configuration
model = NeuronModelForCausalLM.from_pretrained(
model_id, export=True, batch_size=batch_size, sequence_length=seq_length, auto_cast_type="fp16"
)
with TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(tmpdir)
json_path = f"{model_name}.json"
run(tmpdir, 256, 1024, json_path=json_path)
(The remaining three changed files are binary images, presumably the bar charts under `docs/assets/benchmarks/inferentia-llama2/` referenced in the documentation page below; they cannot be displayed in the diff view.)
6 changes: 5 additions & 1 deletion docs/source/_toctree.yml
@@ -17,7 +17,7 @@
- local: tutorials/llama2-13b-chatbot
title: Create your own chatbot with llama-2-13B on AWS Inferentia
- local: tutorials/fine_tune_llama_7b
title: Fine-tune Llama 2 7B on AWS Trainium
title: Fine-tune Llama 2 7B on AWS Trainium
- local: tutorials/sentence_transformers
title: Sentence Transformers on AWS Inferentia
title: Tutorials
@@ -51,5 +51,9 @@
- local: package_reference/modeling
title: Neuron Models
title: Reference
- sections:
- local: benchmarks/inferentia-llama2
title: Llama on AWS Inferentia2
title: Benchmarks
title: Optimum Neuron
isExpanded: true
77 changes: 77 additions & 0 deletions docs/source/benchmarks/inferentia-llama2.mdx
@@ -0,0 +1,77 @@
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Llama performance on AWS Inferentia2 (Latency & Throughput)

How fast is Llama on Inferentia2? Let's find out!

For this benchmark we will use the Llama 2 7B and 13B models with different configurations:

| Model type | num cores | batch_size |
|----------------------------|-----------|------------|
| Llama2 7B - L (latency) | 24 | 1 |
| Llama2 7B - T (throughput) | 24 | 4 |
| Llama2 13B - L (latency) | 24 | 1 |
| Llama2 13B - T (throughput)| 24 | 4 |

*Note: all models are compiled with a maximum sequence length of 2048.*

All models are compiled to use all the Neuron cores available on the `inf2.48xlarge` instance.

*Note: please refer to the [inferentia2 product page](https://aws.amazon.com/ec2/instance-types/inf2/) for details on the available instances.*

We created two "latency"-oriented configurations for the `llama2 7B` and `llama2 13B` models that serve only one request at a time, but at full speed, and two "throughput"-oriented configurations that serve up to four requests in parallel.

To evaluate the models, we generate tokens up to a total sequence length of 1024, starting from
256 input tokens (i.e. we generate 256, 512 and 768 tokens).
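
The numbers on this page are produced with the scripts under `benchmark/text-generation` in this repository. As a minimal sketch, the Llama 2 7B "throughput" configuration is exported and measured roughly as follows (this mirrors `llama2.py`, except that a named directory, `llama-2-7b-throughput`, replaces its temporary directory, and the output file name is only an example):

```python
from transformers import AutoTokenizer

from benchmark import run  # benchmark/text-generation/benchmark.py
from optimum.neuron import NeuronModelForCausalLM

# Export Llama 2 7B with the "throughput" configuration (batch_size=4, sequence_length=2048, fp16).
model = NeuronModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    export=True,
    batch_size=4,
    sequence_length=2048,
    auto_cast_type="fp16",
)
model.save_pretrained("llama-2-7b-throughput")
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf").save_pretrained("llama-2-7b-throughput")

# Generate from 256 input tokens up to a total sequence length of 1024 (in increments of 256 tokens)
# and dump the measurements to a JSON file.
run("llama-2-7b-throughput", inc_length=256, max_length=1024, json_path="Llama-2-7BT.json")
```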

## Encoding time (time to first token)

The encoding time or time to first token is the time required to process the input tokens and generate the first output token.
It is a very important metric, as it corresponds to the latency directly perceived by the user when streaming generated tokens.

We test the encoding time for increasing context sizes: 256 input tokens corresponds roughly to a typical Q/A usage,
while 768 input tokens is more typical of a Retrieval Augmented Generation (RAG) use-case.

Encoding time is expressed in **seconds**.
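
For reference, `benchmark.py` isolates this metric by timing a generation of exactly one token beyond the current context; a minimal sketch of that timing pattern (the `model` and `input_ids` arguments are placeholders for a loaded neuron model and its tokenized prompt):

```python
import time

import torch


def time_to_first_token(model, input_ids):
    # Time the generation of a single extra token on top of the context,
    # as benchmark.py does with generate(model, input_ids, input_length + 1).
    length = input_ids.shape[-1] + 1
    start = time.time()
    with torch.inference_mode():
        model.generate(input_ids, do_sample=False, min_length=length, max_length=length)
    return time.time() - start
```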

![Llama2 inferentia2 encoding-time](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/encoding-times.png "Encoding time")

We can see that all deployed models exhibit excellent response times, even for long contexts.

## End-to-end Latency

The end-to-end latency corresponds to the total time to reach a sequence length of 1024 tokens.

It therefore includes the encoding and generation time.

Latency is expressed in **seconds**.

![Llama2 inferentia2 end-to-end latency](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/latencies.png "Latency")

All models deployed on the high-end instance exhibit good latency, even those configured to optimize throughput.

## Throughput

We adopt the same convention as other benchmarks to evaluate the throughput, by dividing the total number of
input and output tokens by the end-to-end latency.
In other words, we divide `batch_size * sequence_length` by the end-to-end latency to obtain the number of tokens per second.

Throughput is expressed in **tokens/second**.
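
This mirrors the computation in `benchmark.py`; a small illustration (the numbers are made up, not measured results):

```python
# Throughput as computed in benchmark.py: total tokens (input + output) per second.
def throughput(batch_size: int, sequence_length: int, latency_s: float) -> float:
    return batch_size * sequence_length / latency_s


# Made-up example: a batch of 4 sequences of 1024 tokens generated in 20 seconds.
print(throughput(batch_size=4, sequence_length=1024, latency_s=20.0))  # 204.8 tokens/s
```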

![Llama2 inferentia2 throughput](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/throughputs.png "Throughput")

Again, the models deployed on the high-end instance achieve very good throughput, even those optimized for latency.
