From 30b3e45d1ba809f4f4aaed6e52ca709b36437f9c Mon Sep 17 00:00:00 2001
From: thxCode
Date: Fri, 19 Jul 2024 11:55:24 +0800
Subject: [PATCH] fix: conti-batching speculative sampling

Signed-off-by: thxCode
---
 llama-box/main.cpp | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/llama-box/main.cpp b/llama-box/main.cpp
index 91fc1a1..12e1465 100644
--- a/llama-box/main.cpp
+++ b/llama-box/main.cpp
@@ -2507,7 +2507,8 @@ struct server_context {
                 for (int32_t j = 0; j < sz_draft + 1; ++j) {
                     // greedy verification only
                     bool accept = false;
-                    tok = llama_sampling_sample(slot.ctx_sampling, ctx, nullptr, j);
+                    tok = llama_sampling_sample(slot.ctx_sampling, ctx, nullptr,
+                                                slot.i_batch - i + j);
                     llama_sampling_accept(slot.ctx_sampling, ctx, tok, true);
                     slot.push_token_into_result(tok, result, ctx);
                     if (j < sz_draft && tok == slot.sampled_draft[j]) {
@@ -2766,13 +2767,14 @@ int main(int argc, char **argv) {
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
-    LOG_INFO("system info", {
-        {"n_threads", params.n_threads},
-        {"n_threads_batch", params.n_threads_batch},
-        {"total_threads", std::thread::hardware_concurrency()},
-        {"system_info", llama_print_system_info()},
-    });
+    LOG_INFO("build info", {{"version", LLAMA_BOX_GIT_VERSION},
+                            {"commit", LLAMA_BOX_GIT_COMMIT},
+                            {"llama_cpp_build", LLAMA_BUILD_NUMBER},
+                            {"llama_cpp_commit", LLAMA_COMMIT}});
+    LOG_INFO("system info", {{"n_threads", params.n_threads},
+                             {"n_threads_batch", params.n_threads_batch},
+                             {"total_threads", std::thread::hardware_concurrency()},
+                             {"system_info", llama_print_system_info()}});
 
     httplib::Server svr;
     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
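
Note on the first hunk (commentary below the patch, not applied by git):
under continuous batching, llama-box decodes the drafted tokens of all
active slots in one shared llama_batch, so the logits row for a slot's
j-th verification token must be offset by where that slot's rows start
in the batch; sampling at row j alone reads the first slot's logits for
every slot. The C++ sketch below only illustrates that indexing under an
assumed layout; sz_draft, i, and the two-slot batch shape are
assumptions for illustration, not code taken from main.cpp.

// Minimal sketch of the batch-row indexing, assuming each slot
// contributes sz_draft + 1 contiguous rows to one shared batch and
// that i counts back from a slot's last row (i_batch) to its first.
#include <cstdio>

int main() {
    const int sz_draft = 3;                 // drafted tokens per slot (assumed)
    const int rows_per_slot = sz_draft + 1; // draft rows + 1 target row
    const int i_batch[2] = {rows_per_slot - 1, 2 * rows_per_slot - 1};

    for (int s = 0; s < 2; ++s) {
        const int i = sz_draft; // offset from last row back to first (assumed)
        for (int j = 0; j < sz_draft + 1; ++j) {
            // the buggy index j always lands in slot 0's rows;
            // i_batch[s] - i + j lands in slot s's own rows
            printf("slot %d, token %d: buggy row = %d, fixed row = %d\n",
                   s, j, j, i_batch[s] - i + j);
        }
    }
    return 0;
}

For slot 1 the buggy index visits rows 0..3 (slot 0's logits) while the
fixed index visits rows 4..7, which would explain why speculative
verification only misbehaved once more than one slot decoded at a time;
with a single slot, i_batch - i is 0 and the two indices coincide.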