From 10ae46f335c04914a9d616f2c6f627e2713370e5 Mon Sep 17 00:00:00 2001
From: erhant
Date: Fri, 6 Sep 2024 10:38:48 +0300
Subject: [PATCH] add thread env arg to ollama check

---
 src/config/ollama.rs | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/config/ollama.rs b/src/config/ollama.rs
index 10cf574..3f495ef 100644
--- a/src/config/ollama.rs
+++ b/src/config/ollama.rs
@@ -1,7 +1,10 @@
 use std::time::Duration;
 
 use ollama_workflows::{
-    ollama_rs::{generation::completion::request::GenerationRequest, Ollama},
+    ollama_rs::{
+        generation::{completion::request::GenerationRequest, options::GenerationOptions},
+        Ollama,
+    },
     Model,
 };
 
@@ -174,12 +177,26 @@ impl OllamaConfig {
             return false;
         };
 
+        let mut generation_request =
+            GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string());
+
+        // FIXME: temporary workaround, can take num threads from outside
+        if let Ok(num_thread) = std::env::var("OLLAMA_NUM_THREAD") {
+            generation_request = generation_request.options(
+                GenerationOptions::default().num_thread(
+                    num_thread
+                        .parse()
+                        .expect("num threads should be a positive integer"),
+                ),
+            );
+        }
+
         // then, run a sample generation with timeout and measure tps
         tokio::select! {
             _ = tokio::time::sleep(timeout) => {
                 log::warn!("Ignoring model {}: Workflow timed out", model);
            },
-            result = ollama.generate(GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string())) => {
+            result = ollama.generate(generation_request) => {
                 match result {
                     Ok(response) => {
                         let tps = (response.eval_count.unwrap_or_default() as f64)
@@ -189,7 +206,6 @@ impl OllamaConfig {
                         if tps >= min_tps {
                             log::info!("Model {} passed the test with tps: {}", model, tps);
                             return true;
-
                         }
 
                         log::warn!(