diff --git a/src/config/ollama.rs b/src/config/ollama.rs index 10cf574..3f495ef 100644 --- a/src/config/ollama.rs +++ b/src/config/ollama.rs @@ -1,7 +1,10 @@ use std::time::Duration; use ollama_workflows::{ - ollama_rs::{generation::completion::request::GenerationRequest, Ollama}, + ollama_rs::{ + generation::{completion::request::GenerationRequest, options::GenerationOptions}, + Ollama, + }, Model, }; @@ -174,12 +177,26 @@ impl OllamaConfig { return false; }; + let mut generation_request = + GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string()); + + // FIXME: temporary workaround, can take num threads from outside + if let Ok(num_thread) = std::env::var("OLLAMA_NUM_THREAD") { + generation_request = generation_request.options( + GenerationOptions::default().num_thread( + num_thread + .parse() + .expect("num threads should be a positive integer"), + ), + ); + } + // then, run a sample generation with timeout and measure tps tokio::select! { _ = tokio::time::sleep(timeout) => { log::warn!("Ignoring model {}: Workflow timed out", model); }, - result = ollama.generate(GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string())) => { + result = ollama.generate(generation_request) => { match result { Ok(response) => { let tps = (response.eval_count.unwrap_or_default() as f64) @@ -189,7 +206,6 @@ impl OllamaConfig { if tps >= min_tps { log::info!("Model {} passed the test with tps: {}", model, tps); return true; - } log::warn!(