Skip to content

Commit

Permalink
Rename the func arg save_json to avoid name collision. (#1837)
Browse files Browse the repository at this point in the history
* Rename the func arg save_json to avoid name collision.

* black formatted.
  • Loading branch information
godot73 authored Sep 19, 2023
1 parent ded74d0 commit b5fbb1a
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions tank/examples/opt/opt_perf_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def save_json(data, filename):


def collect_huggingface_logits(
model_name: str, max_seq_len: int, save_json: bool
model_name: str, max_seq_len: int, to_save_json: bool
) -> Tuple[float, float]:
# Load
t0 = time.time()
Expand All @@ -194,11 +194,11 @@ def collect_huggingface_logits(
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_huggingface_model(model_wrapper, tokens)
if save_json:
if to_save_json:
results.append([PROMPTS[idx], logits[0].tolist()])
run_time = time.time() - t0
print("--- Took {} seconds to run Huggingface.".format(run_time))
if save_json:
if to_save_json:
save_json(results, "/tmp/huggingface.json")
run_memory_info = get_memory_info()
return {
Expand All @@ -215,7 +215,10 @@ def collect_huggingface_logits(


def collect_shark_logits(
model_name: str, max_seq_len: int, recompile_shark: bool, save_json: bool
model_name: str,
max_seq_len: int,
recompile_shark: bool,
to_save_json: bool,
) -> Tuple[float, float]:
# Load
t0 = time.time()
Expand Down Expand Up @@ -246,11 +249,11 @@ def collect_shark_logits(
print("prompt: {}".format(PROMPTS[idx]))
logits = run_shark_model(model_wrapper, tokens)
lst = [e.tolist() for e in logits]
if save_json:
if to_save_json:
results.append([PROMPTS[idx], lst])
run_time = time.time() - t0
print("--- Took {} seconds to run Shark.".format(run_time))
if save_json:
if to_save_json:
save_json(results, "/tmp/shark.json")
platform_postfix = "-compile" if recompile_shark else "-precompiled"
run_memory_info = get_memory_info()
Expand Down

0 comments on commit b5fbb1a

Please sign in to comment.