Skip to content

Commit

Permalink
add thread env arg to ollama check
Browse files Browse the repository at this point in the history
  • Loading branch information
erhant committed Sep 6, 2024
1 parent 2206702 commit 10ae46f
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/config/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::time::Duration;

use ollama_workflows::{
ollama_rs::{generation::completion::request::GenerationRequest, Ollama},
ollama_rs::{
generation::{completion::request::GenerationRequest, options::GenerationOptions},
Ollama,
},
Model,
};

Expand Down Expand Up @@ -174,12 +177,26 @@ impl OllamaConfig {
return false;
};

let mut generation_request =
GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string());

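// FIXME: temporary workaround that lets the thread count be supplied from outside via the
// `OLLAMA_NUM_THREAD` env var; note that a value that fails to parse will panic below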
if let Ok(num_thread) = std::env::var("OLLAMA_NUM_THREAD") {
generation_request = generation_request.options(
GenerationOptions::default().num_thread(
num_thread
.parse()
.expect("num threads should be a positive integer"),
),
);
}

// then, run a sample generation with timeout and measure tps
tokio::select! {
_ = tokio::time::sleep(timeout) => {
log::warn!("Ignoring model {}: Workflow timed out", model);
},
result = ollama.generate(GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string())) => {
result = ollama.generate(generation_request) => {
match result {
Ok(response) => {
let tps = (response.eval_count.unwrap_or_default() as f64)
Expand All @@ -189,7 +206,6 @@ impl OllamaConfig {
if tps >= min_tps {
log::info!("Model {} passed the test with tps: {}", model, tps);
return true;

}

log::warn!(
Expand Down

0 comments on commit 10ae46f

Please sign in to comment.