diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc
index 861776ec66..8b1f587f2a 100644
--- a/inference/incr_decoding/incr_decoding.cc
+++ b/inference/incr_decoding/incr_decoding.cc
@@ -290,10 +290,14 @@ void FlexFlow::top_level_task(Task const *task,
                                          /*ignore_comments */ true);
     std::vector<GenerationRequest> requests;
     for (auto &prompt : prompt_json) {
-      std::string text = prompt.get<std::string>();
-      printf("Prompt[%d]: %s\n", total_num_requests, text.c_str());
+      std::string text = prompt["prompt"].get<std::string>();
+      double slo_ratio = prompt["slo_ratio"].get<double>();
+      printf("Prompt[%d] with slo %.3f: %s\n",
+             total_num_requests,
+             slo_ratio,
+             text.c_str());
       total_num_requests++;
-      requests.push_back(GenerationRequest(text, 1.0));
+      requests.push_back(GenerationRequest(text, slo_ratio));
     }
     ConstantEmissionMachine emission_machine(1.0);
     std::vector<GenerationResult> result =
diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc
index 7dc10ab388..25e3eb843c 100644
--- a/inference/spec_infer/spec_infer.cc
+++ b/inference/spec_infer/spec_infer.cc
@@ -471,10 +471,14 @@ void FlexFlow::top_level_task(Task const *task,
 
     std::vector<GenerationRequest> requests;
     for (auto &prompt : prompt_json) {
-      std::string text = prompt.get<std::string>();
-      printf("Prompt[%d]: %s\n", total_num_requests, text.c_str());
+      std::string text = prompt["prompt"].get<std::string>();
+      double slo_ratio = prompt["slo_ratio"].get<double>();
+      printf("Prompt[%d] with slo %.3f: %s\n",
+             total_num_requests,
+             slo_ratio,
+             text.c_str());
       total_num_requests++;
-      requests.push_back(GenerationRequest(text, 1.0));
+      requests.push_back(GenerationRequest(text, slo_ratio));
     }
     ConstantEmissionMachine emission_machine(1.0);
     tree_model.generate(requests, emission_machine);