Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
plusbang committed Nov 1, 2024
1 parent d409d9d commit 4da8c8c
Showing 1 changed file with 10 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,26 +105,20 @@ def generate(
key.to(torch.float16).numpy().tofile(os.path.join(temp_dir, f"key_cache_{layer}.bin"))
val.to(torch.float16).numpy().tofile(os.path.join(temp_dir, f"value_cache_{layer}.bin"))

token = input_id.to(torch.int32).item()
output_tokens.append(torch.tensor([token]))
if streamer is not None:
streamer.put(torch.tensor([token]))

if "eos_token_id" not in new_generate_kwargs:
eos = 0xffffffff
else:
eos = new_generate_kwargs["eos_token_id"]

time_t1 = time.perf_counter()
idx += 1

# start generate_serve by Thread
thread = threading.Thread(target=generate_serve,
args=(self.kv_len, self.num_head,
self.head_dim, self.num_layers,
self.vocab_size,
self.transpose_value_cache,
new_tokens - 2))
new_tokens - 1))
thread.start()

in_pipe_path = "\\\\.\\pipe\\llminputpipe"
Expand All @@ -146,7 +140,7 @@ def generate(
else:
break

time_start = time.perf_counter()
time_t2 = time.perf_counter()

bdata = str.encode(str(temp_dir))
invalidInputError(len(bdata) <= 2000,
Expand All @@ -162,6 +156,8 @@ def generate(
break
token = int.from_bytes(data, sys.byteorder)
idx += 1
if idx == 1:
time_t3 = time.perf_counter()
if token == eos:
break
output_tokens.append(torch.tensor([token]))
Expand All @@ -177,13 +173,14 @@ def generate(
time_end = time.perf_counter()

if do_print:
print(f" Start the thread and connect the pipe time: {(time_start - time_t1):.2f} s")
print(f" Start the thread and connect the pipe time: {(time_t2 - time_t1):.2f} s")
print(f" Number of input tokens: {input_length}")
print(f" Generated tokens: {idx}")
print(f" First token generation time: {(time_t1 - time_start_all):.2f} s")
print(f" Generation average latency: {(time_end - time_start) * 1000 /(idx - 1):.2f} ms, "
f"({(idx - 1)/(time_end - time_start):.2f} token/s)")
print(f" Generation time: {(time_end - time_start_all - (time_start - time_t1)):.2f} s\n")
print(" First token generation time: "
f"{(time_t3 - time_start_all - (time_t2 - time_t1)):.2f} s")
print(f" Generation average latency: {(time_end - time_t3) * 1000 /(idx - 1):.2f} ms, "
f"({(idx - 1)/(time_end - time_t3):.2f} token/s)")
print(f" Generation time: {(time_end - time_start_all - (time_t2 - time_t1)):.2f} s\n")
return output


Expand Down

0 comments on commit 4da8c8c

Please sign in to comment.