From 263894a74006e6b1b20190284030a97bcfe5c368 Mon Sep 17 00:00:00 2001
From: erhant
Date: Thu, 5 Sep 2024 20:34:10 +0300
Subject: [PATCH] added tps check

---
 Makefile             |   4 ++
 docs/NODE_GUIDE.md   |   4 +-
 src/config/mod.rs    |   8 +++-
 src/config/models.rs |  12 ++---
 src/config/ollama.rs | 110 ++++++++++++++++++------------------------
 5 files changed, 65 insertions(+), 73 deletions(-)

diff --git a/Makefile b/Makefile
index 618bf31..fd07db0 100644
--- a/Makefile
+++ b/Makefile
@@ -37,6 +37,10 @@ profile-mem:
 version:
 	@cargo pkgid | cut -d@ -f2
 
+.PHONY: ollama-cpu # | Run Ollama CPU container
+ollama-cpu:
+	docker run -p=11434:11434 -v=${HOME}/.ollama:/root/.ollama ollama/ollama
+
 ###############################################################################
 .PHONY: test # | Run tests
 test:
diff --git a/docs/NODE_GUIDE.md b/docs/NODE_GUIDE.md
index 8f368e9..2906c5e 100644
--- a/docs/NODE_GUIDE.md
+++ b/docs/NODE_GUIDE.md
@@ -192,8 +192,6 @@ For the models that you choose (see list of models just below [here](#1-choose-m
 ollama pull llama3.1:latest
 ```
 
-> [!TIP]
-
 #### Optional Services
 
 Based on presence of API keys, [Ollama Workflows](https://github.com/andthattoo/ollama-workflows/) may use more superior services instead of free alternatives, e.g. [Serper](https://serper.dev/) instead of [DuckDuckGo](https://duckduckgo.com/) or [Jina](https://jina.ai/) without rate-limit instead of with rate-limit. Add these within your `.env` as:
@@ -213,7 +211,7 @@ Based on the resources of your machine, you must decide which models that you wi
 
 #### Ollama Models
 
-- `adrienbrault/nous-hermes2theta-llama3-8b:q8_0`
+- `finalend/hermes-3-llama-3.1:8b-q8_0`
 - `phi3:14b-medium-4k-instruct-q4_1`
 - `phi3:14b-medium-128k-instruct-q4_1`
 - `phi3.5:3.8b`
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 75c8339..d9db96c 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -11,6 +11,12 @@ use openai::OpenAIConfig;
 
 use std::{env, time::Duration};
 
+/// Timeout duration for checking model performance during a generation.
+const CHECK_TIMEOUT_DURATION: Duration = Duration::from_secs(80);
+
+/// Minimum tokens per second (TPS) for checking model performance during a generation.
+const CHECK_TPS: f64 = 5.0;
+
 #[derive(Debug, Clone)]
 pub struct DriaComputeNodeConfig {
     /// Wallet secret/private key.
@@ -139,7 +145,7 @@ impl DriaComputeNodeConfig {
         // ensure that the models are pulled / pull them if not
         let good_ollama_models = self
             .ollama_config
-            .check(ollama_models, Duration::from_secs(30))
+            .check(ollama_models, CHECK_TIMEOUT_DURATION, CHECK_TPS)
             .await?;
         good_models.extend(
             good_ollama_models
diff --git a/src/config/models.rs b/src/config/models.rs
index aa3dfba..e89b39a 100644
--- a/src/config/models.rs
+++ b/src/config/models.rs
@@ -136,14 +136,14 @@ mod tests {
         assert_eq!(cfg.models.len(), 0);
 
         let cfg = ModelConfig::new_from_csv(Some(
-            "phi3:3.8b,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(),
+            "gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(),
         ));
         assert_eq!(cfg.models.len(), 2);
     }
 
     #[test]
     fn test_model_matching() {
-        let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,phi3:3.8b".to_string()));
+        let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string()));
         assert_eq!(
             cfg.get_matching_model("openai".to_string()).unwrap().1,
             Model::GPT3_5Turbo,
         );
 
         assert_eq!(
-            cfg.get_matching_model(Model::default().to_string())
+            cfg.get_matching_model("llama3.1:latest".to_string())
                 .unwrap()
                 .1,
-            Model::default(),
+            Model::Llama3_1_8B,
             "Should find existing model"
         );
@@ -172,7 +172,7 @@
     #[test]
     fn test_get_any_matching_model() {
-        let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,phi3:3.8b".to_string()));
+        let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string()));
         let result = cfg.get_any_matching_model(vec![
             "i-dont-exist".to_string(),
             "llama3.1:latest".to_string(),
         ]);
         assert_eq!(
             result.unwrap().1,
-            Model::default(),
+            Model::Llama3_1_8B,
             "Should find existing model"
         );
     }
diff --git a/src/config/ollama.rs b/src/config/ollama.rs
index 4e82baa..10cf574 100644
--- a/src/config/ollama.rs
+++ b/src/config/ollama.rs
@@ -1,6 +1,9 @@
 use std::time::Duration;
 
-use ollama_workflows::{ollama_rs::Ollama, Executor, Model, ProgramMemory, Workflow};
+use ollama_workflows::{
+    ollama_rs::{generation::completion::request::GenerationRequest, Ollama},
+    Model,
+};
 
 const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1";
 const DEFAULT_OLLAMA_PORT: u16 = 11434;
@@ -8,6 +11,9 @@ const DEFAULT_OLLAMA_PORT: u16 = 11434;
 /// Some models such as small embedding models, are hardcoded into the node.
 const HARDCODED_MODELS: [&str; 1] = ["hellord/mxbai-embed-large-v1:f16"];
 
+/// Prompt to be used to see Ollama performance.
+const TEST_PROMPT: &str = "Please write a poem about Kapadokya.";
+
 /// Ollama-specific configurations.
 #[derive(Debug, Clone)]
 pub struct OllamaConfig {
@@ -66,12 +72,13 @@ impl OllamaConfig {
     pub async fn check(
         &self,
         external_models: Vec<Model>,
-        test_workflow_timeout: Duration,
+        timeout: Duration,
+        min_tps: f64,
     ) -> Result<Vec<Model>, String> {
         log::info!(
             "Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)",
             if self.auto_pull { "on" } else { "off" },
-            test_workflow_timeout.as_secs()
+            timeout.as_secs()
         );
 
         let ollama = Ollama::new(&self.host, self.port);
@@ -108,7 +115,7 @@
             }
 
             if self
-                .test_workflow(model.clone(), test_workflow_timeout)
+                .test_performance(&ollama, &model, timeout, min_tps)
                 .await
             {
                 good_models.push(model);
@@ -149,71 +156,48 @@
     ///
     /// This is to see if a given system can execute Ollama workflows for their chosen models,
     /// e.g. if they have enough RAM/CPU and such.
-    pub async fn test_workflow(&self, model: Model, timeout: Duration) -> bool {
-        // this is the test workflow that we will run
-        // TODO: when Workflow's have `Clone`, we can remove the repetitive parsing here
-        let workflow = serde_json::from_value::<Workflow>(serde_json::json!({
-            "name": "Simple",
-            "description": "This is a simple workflow",
-            "config":{
-                "max_steps": 5,
-                "max_time": 100,
-                "max_tokens": 100,
-                "tools": []
-            },
-            "tasks":[
-                {
-                    "id": "A",
-                    "name": "Random Poem",
-                    "description": "Writes a poem about Kapadokya.",
-                    "prompt": "Please write a poem about Kapadokya.",
-                    "inputs":[],
-                    "operator": "generation",
-                    "outputs":[
-                        {
-                            "type": "write",
-                            "key": "poem",
-                            "value": "__result"
-                        }
-                    ]
-                },
-                {
-                    "id": "__end",
-                    "name": "end",
-                    "description": "End of the task",
-                    "prompt": "End of the task",
-                    "inputs": [],
-                    "operator": "end",
-                    "outputs": []
-                }
-            ],
-            "steps":[
-                {
-                    "source":"A",
-                    "target":"end"
-                }
-            ],
-            "return_value":{
-                "input":{
-                    "type": "read",
-                    "key": "poem"
-                }
-            }
-        }))
-        .expect("Preset workflow should be parsed");
-
+    pub async fn test_performance(
+        &self,
+        ollama: &Ollama,
+        model: &Model,
+        timeout: Duration,
+        min_tps: f64,
+    ) -> bool {
         log::info!("Testing model {}", model);
-        let executor = Executor::new_at(model.clone(), &self.host, self.port);
-        let mut memory = ProgramMemory::new();
+
+        // first generate a dummy embedding to load the model into memory (warm-up)
+        if let Err(err) = ollama
+            .generate_embeddings(model.to_string(), "foobar".to_string(), Default::default())
+            .await
+        {
+            log::error!("Failed to generate embedding for model {}: {}", model, err);
+            return false;
+        };
+
+        // then, run a sample generation with timeout and measure tps
         tokio::select! {
             _ = tokio::time::sleep(timeout) => {
                 log::warn!("Ignoring model {}: Workflow timed out", model);
             },
-            result = executor.execute(None, workflow, &mut memory) => {
+            result = ollama.generate(GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string())) => {
                 match result {
-                    Ok(_) => {
-                        log::info!("Accepting model {}", model);
-                        return true;
+                    Ok(response) => {
+                        let tps = (response.eval_count.unwrap_or_default() as f64)
+                            / (response.eval_duration.unwrap_or(1) as f64)
+                            * 1_000_000_000f64;
+
+                        if tps >= min_tps {
+                            log::info!("Model {} passed the test with tps: {}", model, tps);
+                            return true;
+
+                        }
+
+                        log::warn!(
+                            "Ignoring model {}: tps too low ({:.3} < {:.3})",
+                            model,
+                            tps,
+                            min_tps
+                        );
                     }
                     Err(e) => {
                         log::warn!("Ignoring model {}: Workflow failed with error {}", model, e);