added tps check
erhant committed Sep 5, 2024
1 parent 71c1078 commit 263894a
Showing 5 changed files with 65 additions and 73 deletions.
4 changes: 4 additions & 0 deletions Makefile
@@ -37,6 +37,10 @@ profile-mem:
version:
@cargo pkgid | cut -d@ -f2

.PHONY: ollama-cpu # | Run Ollama CPU container
ollama-cpu:
docker run -p=11434:11434 -v=${HOME}/.ollama:/root/.ollama ollama/ollama

###############################################################################
.PHONY: test # | Run tests
test:
4 changes: 1 addition & 3 deletions docs/NODE_GUIDE.md
@@ -192,8 +192,6 @@ For the models that you choose (see list of models just below [here](#1-choose-m
ollama pull llama3.1:latest
```

> [!TIP]
#### Optional Services

Based on the presence of API keys, [Ollama Workflows](https://github.com/andthattoo/ollama-workflows/) may use better services instead of their free alternatives, e.g. [Serper](https://serper.dev/) instead of [DuckDuckGo](https://duckduckgo.com/), or [Jina](https://jina.ai/) without a rate limit instead of the rate-limited tier. Add these within your `.env` as:
@@ -213,7 +211,7 @@ Based on the resources of your machine, you must decide which models that you wi

#### Ollama Models

- `adrienbrault/nous-hermes2theta-llama3-8b:q8_0`
- `finalend/hermes-3-llama-3.1:8b-q8_0`
- `phi3:14b-medium-4k-instruct-q4_1`
- `phi3:14b-medium-128k-instruct-q4_1`
- `phi3.5:3.8b`
8 changes: 7 additions & 1 deletion src/config/mod.rs
@@ -11,6 +11,12 @@ use openai::OpenAIConfig;

use std::{env, time::Duration};

/// Timeout duration for checking model performance during a generation.
const CHECK_TIMEOUT_DURATION: Duration = Duration::from_secs(80);

/// Minimum tokens per second (TPS) for checking model performance during a generation.
const CHECK_TPS: f64 = 5.0;

#[derive(Debug, Clone)]
pub struct DriaComputeNodeConfig {
/// Wallet secret/private key.
@@ -139,7 +145,7 @@ impl DriaComputeNodeConfig {
// ensure that the models are pulled / pull them if not
let good_ollama_models = self
.ollama_config
.check(ollama_models, Duration::from_secs(30))
.check(ollama_models, CHECK_TIMEOUT_DURATION, CHECK_TPS)
.await?;
good_models.extend(
good_ollama_models
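As a rough aid to reading the two new constants in `src/config/mod.rs`: with a 5.0 TPS floor and an 80-second window, a model running right at the minimum throughput could emit on the order of a few hundred tokens before the check times out. A small, self-contained sketch of that arithmetic (illustrative only, not part of the commit):

```rust
use std::time::Duration;

// Mirrors the constants introduced in src/config/mod.rs above.
const CHECK_TIMEOUT_DURATION: Duration = Duration::from_secs(80);
const CHECK_TPS: f64 = 5.0;

fn main() {
    // A model generating at exactly the minimum throughput could emit roughly
    // this many tokens before the performance check times out.
    let max_tokens = CHECK_TIMEOUT_DURATION.as_secs_f64() * CHECK_TPS;
    println!("~{max_tokens} tokens fit within the check window"); // ~400
}
```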
12 changes: 6 additions & 6 deletions src/config/models.rs
@@ -136,25 +136,25 @@ mod tests {
assert_eq!(cfg.models.len(), 0);

let cfg = ModelConfig::new_from_csv(Some(
"phi3:3.8b,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(),
"gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(),
));
assert_eq!(cfg.models.len(), 2);
}

#[test]
fn test_model_matching() {
let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,phi3:3.8b".to_string()));
let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string()));
assert_eq!(
cfg.get_matching_model("openai".to_string()).unwrap().1,
Model::GPT3_5Turbo,
"Should find existing model"
);

assert_eq!(
cfg.get_matching_model(Model::default().to_string())
cfg.get_matching_model("llama3.1:latest".to_string())
.unwrap()
.1,
Model::default(),
Model::Llama3_1_8B,
"Should find existing model"
);

@@ -172,7 +172,7 @@

#[test]
fn test_get_any_matching_model() {
let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,phi3:3.8b".to_string()));
let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string()));
let result = cfg.get_any_matching_model(vec![
"i-dont-exist".to_string(),
"llama3.1:latest".to_string(),
@@ -181,7 +181,7 @@
]);
assert_eq!(
result.unwrap().1,
Model::default(),
Model::Llama3_1_8B,
"Should find existing model"
);
}
110 changes: 47 additions & 63 deletions src/config/ollama.rs
@@ -1,13 +1,19 @@
use std::time::Duration;

use ollama_workflows::{ollama_rs::Ollama, Executor, Model, ProgramMemory, Workflow};
use ollama_workflows::{
ollama_rs::{generation::completion::request::GenerationRequest, Ollama},
Model,
};

const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1";
const DEFAULT_OLLAMA_PORT: u16 = 11434;

/// Some models, such as small embedding models, are hardcoded into the node.
const HARDCODED_MODELS: [&str; 1] = ["hellord/mxbai-embed-large-v1:f16"];

/// Prompt used to measure Ollama generation performance.
const TEST_PROMPT: &str = "Please write a poem about Kapadokya.";

/// Ollama-specific configurations.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
@@ -66,12 +72,13 @@ impl OllamaConfig {
pub async fn check(
&self,
external_models: Vec<Model>,
test_workflow_timeout: Duration,
timeout: Duration,
min_tps: f64,
) -> Result<Vec<Model>, String> {
log::info!(
"Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)",
if self.auto_pull { "on" } else { "off" },
test_workflow_timeout.as_secs()
timeout.as_secs()
);

let ollama = Ollama::new(&self.host, self.port);
@@ -108,7 +115,7 @@ impl OllamaConfig {
}

if self
.test_workflow(model.clone(), test_workflow_timeout)
.test_performance(&ollama, &model, timeout, min_tps)
.await
{
good_models.push(model);
@@ -149,71 +156,48 @@ impl OllamaConfig {
///
/// This is to see if a given system can execute Ollama workflows for their chosen models,
/// e.g. if they have enough RAM/CPU and such.
pub async fn test_workflow(&self, model: Model, timeout: Duration) -> bool {
// this is the test workflow that we will run
// TODO: when Workflow's have `Clone`, we can remove the repetitive parsing here
let workflow = serde_json::from_value::<Workflow>(serde_json::json!({
"name": "Simple",
"description": "This is a simple workflow",
"config":{
"max_steps": 5,
"max_time": 100,
"max_tokens": 100,
"tools": []
},
"tasks":[
{
"id": "A",
"name": "Random Poem",
"description": "Writes a poem about Kapadokya.",
"prompt": "Please write a poem about Kapadokya.",
"inputs":[],
"operator": "generation",
"outputs":[
{
"type": "write",
"key": "poem",
"value": "__result"
}
]
},
{
"id": "__end",
"name": "end",
"description": "End of the task",
"prompt": "End of the task",
"inputs": [],
"operator": "end",
"outputs": []
}
],
"steps":[
{
"source":"A",
"target":"end"
}
],
"return_value":{
"input":{
"type": "read",
"key": "poem"
}
}
}))
.expect("Preset workflow should be parsed");

pub async fn test_performance(
&self,
ollama: &Ollama,
model: &Model,
timeout: Duration,
min_tps: f64,
) -> bool {
log::info!("Testing model {}", model);
let executor = Executor::new_at(model.clone(), &self.host, self.port);
let mut memory = ProgramMemory::new();

// first generate a dummy embedding to load the model into memory (warm-up)
if let Err(err) = ollama
.generate_embeddings(model.to_string(), "foobar".to_string(), Default::default())
.await
{
log::error!("Failed to generate embedding for model {}: {}", model, err);
return false;
};

// then, run a sample generation with timeout and measure tps
tokio::select! {
_ = tokio::time::sleep(timeout) => {
log::warn!("Ignoring model {}: Workflow timed out", model);
},
result = executor.execute(None, workflow, &mut memory) => {
result = ollama.generate(GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string())) => {
match result {
Ok(_) => {
log::info!("Accepting model {}", model);
return true;
Ok(response) => {
let tps = (response.eval_count.unwrap_or_default() as f64)
/ (response.eval_duration.unwrap_or(1) as f64)
* 1_000_000_000f64;

if tps >= min_tps {
log::info!("Model {} passed the test with tps: {}", model, tps);
return true;
}

log::warn!(
"Ignoring model {}: tps too low ({:.3} < {:.3})",
model,
tps,
min_tps
);
}
Err(e) => {
log::warn!("Ignoring model {}: Workflow failed with error {}", model, e);
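The throughput figure above is derived from the response's `eval_count` (number of generated tokens) and `eval_duration`, which Ollama reports in nanoseconds, hence the scaling by 1e9 to obtain tokens per second. A minimal standalone sketch of that arithmetic, using made-up sample numbers (the helper name is ours, not from the codebase):

```rust
/// Tokens per second from Ollama generation statistics.
/// `eval_count` is the number of generated tokens; `eval_duration_ns` is in nanoseconds.
fn tokens_per_second(eval_count: u64, eval_duration_ns: u64) -> f64 {
    // Guard against a zero duration, mirroring the `unwrap_or(1)` in the diff above.
    (eval_count as f64) / (eval_duration_ns.max(1) as f64) * 1_000_000_000f64
}

fn main() {
    // Hypothetical sample: 120 tokens generated over 15 seconds -> 8.0 TPS,
    // which clears the CHECK_TPS = 5.0 floor introduced by this commit.
    let tps = tokens_per_second(120, 15_000_000_000);
    assert!(tps >= 5.0);
    println!("{tps:.1} tokens/s"); // 8.0 tokens/s
}
```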
