[vicuna.py] Rework benchmark statistics calculation (#1992)
- Move statistics out of the main loop
- Add 'end-to-end' numbers
- Switch the main display unit from s to ms
- Start measuring time at 0

The new print format looks like this:
```
Number of iterations: 5
Num tokens: 1 (prompt), 512 (generated), 513 (total)
Prefill: avg. 0.01 ms (stdev 0.00), avg. 97.99 tokens/s
Decode: avg. 4840.44 ms (stdev 28.80), avg. 97.99 tokens/s
Decode end-2-end: avg. 85.78 tokens/s (w/o prompt), avg. 95.98 tokens/s (w/ prompt)
```
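
For orientation, here is a minimal sketch (with made-up timings, not the figures above) of how these per-run numbers are derived:
```python
# Hypothetical measurements for a single run: 1 prompt token, 3 generated tokens.
prefill_time_ms = 10.0
token_times_ms = [10.0, 12.0, 11.0]

decode_time_ms = sum(token_times_ms)  # 33.0 ms
decode_speed = len(token_times_ms) / (decode_time_ms / 1000.0)  # ~90.9 tokens/s
e2e_time_ms = prefill_time_ms + decode_time_ms  # 43.0 ms
e2e_decode_speed = len(token_times_ms) / (e2e_time_ms / 1000.0)  # ~69.8 tokens/s (w/o prompt)
e2e_processing_speed = (1 + len(token_times_ms)) / (e2e_time_ms / 1000.0)  # ~93.0 tokens/s (w/ prompt)
```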
kuhar authored Nov 23, 2023
1 parent da50a16 commit 2da31c4
116 changes: 97 additions & 19 deletions apps/language_models/scripts/vicuna.py
@@ -1,4 +1,5 @@
import argparse
from dataclasses import dataclass
import json
import re
import gc
@@ -1937,6 +1938,87 @@ def create_prompt(model_name, history):
return msg


def milliseconds_to_seconds(ms: float) -> float:
return ms / 1000.0


@dataclass
class BenchmarkRunInfo:
num_prompt_tokens: int
prefill_time_ms: float
token_times_ms: list[float]

def get_prefill_speed(self) -> float:
seconds = milliseconds_to_seconds(self.prefill_time_ms)
if seconds == 0.0:
return float('inf')
return self.num_prompt_tokens / seconds

def num_generated_tokens(self) -> int:
return len(self.token_times_ms)

def get_decode_time_ms(self) -> float:
return sum(self.token_times_ms)

def get_decode_speed(self) -> float:
seconds = milliseconds_to_seconds(self.get_decode_time_ms())
if seconds == 0.0:
return float('inf')
return self.num_generated_tokens() / seconds

def get_e2e_time_ms(self) -> float:
return self.prefill_time_ms + self.get_decode_time_ms()

def get_e2e_decode_speed(self) -> float:
seconds = milliseconds_to_seconds(self.get_e2e_time_ms())
if seconds == 0.0:
return float('inf')
return self.num_generated_tokens() / seconds

def get_e2e_token_processing_speed(self) -> float:
seconds = milliseconds_to_seconds(self.get_e2e_time_ms())
if seconds == 0.0:
return float('inf')
return (self.num_prompt_tokens + self.num_generated_tokens()) / seconds

def print(self) -> None:
total_tokens = self.num_prompt_tokens + self.num_generated_tokens()
print(f"Num tokens: {self.num_prompt_tokens} (prompt), {self.num_generated_tokens()} (generated), {total_tokens} (total)")
print(f"Prefill: {self.prefill_time_ms:.2f} ms, {self.get_prefill_speed():.2f} tokens/s")
print(f"Decode: {self.get_decode_time_ms():.2f} ms, {self.get_decode_speed():.2f} tokens/s")
print(f"Decode end-2-end: {self.get_e2e_decode_speed():.2f} tokens/s (w/o prompt), {self.get_e2e_token_processing_speed():.2f} tokens/s (w/ prompt)")


def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None:
num_iterations = len(run_infos)
print(f"Number of iterations: {num_iterations}")
if num_iterations == 0:
return

if num_iterations == 1:
run_infos[0].print()
return

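# Token counts are reported from the first run only; every iteration is assumed to use the same prompt and generation length.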
total_tokens = run_infos[0].num_prompt_tokens + run_infos[0].num_generated_tokens()
print(f"Num tokens: {run_infos[0].num_prompt_tokens} (prompt), {run_infos[0].num_generated_tokens()} (generated), {total_tokens} (total)")

def avg_and_stdev(data):
x = list(data)
return mean(x), stdev(x)

avg_prefill_ms, stdev_prefill = avg_and_stdev(x.prefill_time_ms for x in run_infos)
avg_prefill_speed = mean(x.get_prefill_speed() for x in run_infos)
print(f"Prefill: avg. {avg_prefill_ms:.2f} ms (stdev {stdev_prefill:.2f}), avg. {avg_prefill_speed:.2f} tokens/s")

avg_decode_ms, stdev_decode = avg_and_stdev(x.get_decode_time_ms() for x in run_infos)
avg_decode_speed = mean(x.get_decode_speed() for x in run_infos)
print(f"Decode: avg. {avg_decode_ms:.2f} ms (stdev {stdev_decode:.2f}), avg. {avg_decode_speed:.2f} tokens/s")

avg_e2e_decode_speed = mean(x.get_e2e_decode_speed() for x in run_infos)
avg_e2e_processing_speed = mean(x.get_e2e_token_processing_speed() for x in run_infos)
print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} tokens/s (w/ prompt)")
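
A minimal usage sketch of the two helpers above, with invented run values; it assumes `mean` and `stdev` are imported from `statistics` at module level, which `print_aggregate_stats` relies on:
```python
from statistics import mean, stdev  # required by avg_and_stdev / print_aggregate_stats

# Two invented runs with identical token counts, as the aggregate printer expects.
runs = [
    BenchmarkRunInfo(num_prompt_tokens=1, prefill_time_ms=9.5,
                     token_times_ms=[10.0, 11.0, 12.0]),
    BenchmarkRunInfo(num_prompt_tokens=1, prefill_time_ms=10.5,
                     token_times_ms=[11.0, 12.0, 13.0]),
]
print_aggregate_stats(runs)  # iteration count, token counts, then avg/stdev lines
```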


if __name__ == "__main__":
args, unknown = parser.parse_known_args()

@@ -2035,8 +2117,7 @@ def create_prompt(model_name, history):

iteration = 0

prefill_times = []
avg_decode_speed = []
benchmark_run_infos = []

while True:
# TODO: Add break condition from user input
@@ -2052,35 +2133,32 @@ def create_prompt(model_name, history):
prompt = args.system_prompt + user_prompt
history = [[user_prompt, ""]]

token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
prompt_token_count = len(vic.tokenizer(prompt).input_ids)
total_time_ms = 0.0 # No epsilon needed; the BenchmarkRunInfo speed getters guard against zero durations.
prefill_time_ms = 0
is_first = True
token_times_ms = []

for text, msg, exec_time in vic.generate(prompt, cli=True):
if msg is None:
if is_first:
prefill_time = exec_time
# Note that the prefill time is reported in seconds, while the per-token decode times are in ms.
prefill_time_ms = exec_time * 1000
is_first = False
else:
total_time_ms += exec_time
token_count += 1
token_times_ms.append(exec_time)
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
prefill_times.append(prefill_time)
avg_decode_speed.append(tokens_per_sec)

print("\nResponse:", text.strip())
print(f"\nNum tokens: {token_count}")
print(f"Prefill: {prefill_time:.2f} seconds")
print(f"Decode: {tokens_per_sec:.2f} tokens/s")
print(f"\nResponse:\n{text.strip()}\n")
run_info = BenchmarkRunInfo(prompt_token_count, prefill_time_ms, token_times_ms)
run_info.print()
benchmark_run_infos.append(run_info)

else:
sys.exit(
"Unexpected message from the vicuna generate call, exiting."
)

if args.enable_microbenchmark:
print("\n### Final Statistics ###")
print("Number of iterations:", iteration - 1)
print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")
print_aggregate_stats(benchmark_run_infos)
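
One subtlety in the loop above: per the in-line note, `vic.generate` is assumed to report the prefill time in seconds but each decode step in milliseconds, so only the first measurement is scaled. A hypothetical helper, purely illustrative and not part of this change, makes the normalization explicit:
```python
def normalize_timings(exec_times: list[float]) -> tuple[float, list[float]]:
    """Split raw generator timings into (prefill_ms, per-token ms).

    Assumes the first entry is the prefill time in seconds and the rest
    are per-token decode times already in milliseconds.
    """
    prefill_time_ms = exec_times[0] * 1000  # s -> ms
    token_times_ms = list(exec_times[1:])   # already in ms
    return prefill_time_ms, token_times_ms
```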
