diff --git a/tank/examples/opt/opt_perf_comparison.py b/tank/examples/opt/opt_perf_comparison.py
index 46cb7072f0..025a2cf40a 100644
--- a/tank/examples/opt/opt_perf_comparison.py
+++ b/tank/examples/opt/opt_perf_comparison.py
@@ -168,7 +168,7 @@ def save_json(data, filename):
 
 
 def collect_huggingface_logits(
-    model_name: str, max_seq_len: int, save_json: bool
+    model_name: str, max_seq_len: int, to_save_json: bool
 ) -> Tuple[float, float]:
     # Load
     t0 = time.time()
@@ -194,11 +194,11 @@ def collect_huggingface_logits(
     for idx, tokens in enumerate(tokenized_prompts):
         print("prompt: {}".format(PROMPTS[idx]))
         logits = run_huggingface_model(model_wrapper, tokens)
-        if save_json:
+        if to_save_json:
             results.append([PROMPTS[idx], logits[0].tolist()])
     run_time = time.time() - t0
     print("--- Took {} seconds to run Huggingface.".format(run_time))
-    if save_json:
+    if to_save_json:
         save_json(results, "/tmp/huggingface.json")
     run_memory_info = get_memory_info()
     return {
@@ -215,7 +215,10 @@ def collect_huggingface_logits(
 
 
 def collect_shark_logits(
-    model_name: str, max_seq_len: int, recompile_shark: bool, save_json: bool
+    model_name: str,
+    max_seq_len: int,
+    recompile_shark: bool,
+    to_save_json: bool,
 ) -> Tuple[float, float]:
     # Load
     t0 = time.time()
@@ -246,11 +249,11 @@ def collect_shark_logits(
         print("prompt: {}".format(PROMPTS[idx]))
         logits = run_shark_model(model_wrapper, tokens)
         lst = [e.tolist() for e in logits]
-        if save_json:
+        if to_save_json:
             results.append([PROMPTS[idx], lst])
     run_time = time.time() - t0
     print("--- Took {} seconds to run Shark.".format(run_time))
-    if save_json:
+    if to_save_json:
         save_json(results, "/tmp/shark.json")
     platform_postfix = "-compile" if recompile_shark else "-precompiled"
     run_memory_info = get_memory_info()
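For context on why the rename matters: the boolean parameter `save_json` shadows the module-level `save_json()` helper inside both functions, so the guarded call `save_json(results, ...)` attempts to call a `bool` and raises a `TypeError` whenever the flag is set. A minimal sketch of the failure mode, using a hypothetical stand-in for the helper rather than the real one from `opt_perf_comparison.py`:

```python
import json


def save_json(data, filename):
    # Stand-in for the module-level helper: dump data to a JSON file.
    with open(filename, "w") as f:
        json.dump(data, f)


def before_fix(save_json: bool):
    # The parameter shadows the save_json() function above, so when the
    # flag is True this call raises TypeError: 'bool' object is not callable.
    if save_json:
        save_json(["results"], "/tmp/out.json")


def after_fix(to_save_json: bool):
    # With the parameter renamed, save_json resolves to the module-level
    # helper again and the call succeeds.
    if to_save_json:
        save_json(["results"], "/tmp/out.json")
```

Here `before_fix(True)` crashes while `after_fix(True)` writes the file, which is exactly the distinction the `save_json` → `to_save_json` rename in the patch restores.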