From d6ff15de1b77ca775fbccdfe3ed6e163d278b5e8 Mon Sep 17 00:00:00 2001 From: erhant Date: Tue, 19 Nov 2024 11:50:46 +0300 Subject: [PATCH] feat: added openrouter support (#146) * added openrouter * smol update * added blacklisting on wrong protocol --- .env.example | 3 + Cargo.lock | 113 ++++++++++----------- Cargo.toml | 2 +- Dockerfile | 2 +- compute/src/utils/message.rs | 6 ++ p2p/src/client.rs | 8 ++ p2p/src/lib.rs | 4 +- workflows/src/config.rs | 18 +++- workflows/src/providers/mod.rs | 3 + workflows/src/providers/openrouter.rs | 140 ++++++++++++++++++++++++++ 10 files changed, 235 insertions(+), 64 deletions(-) create mode 100644 workflows/src/providers/openrouter.rs diff --git a/.env.example b/.env.example index 475f51c..2643565 100644 --- a/.env.example +++ b/.env.example @@ -28,6 +28,9 @@ OPENAI_API_KEY= ## Gemini (if used, required) ## GEMINI_API_KEY= +## Open Router (if used, required) ## +OPENROUTER_API_KEY= + ## Ollama (if used, optional) ## # do not change this, it is used by Docker OLLAMA_HOST=http://host.docker.internal diff --git a/Cargo.lock b/Cargo.lock index 17b6a5a..91adc68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -913,7 +913,7 @@ dependencies = [ [[package]] name = "dkn-compute" -version = "0.2.21" +version = "0.2.22" dependencies = [ "async-trait", "base64 0.22.1", @@ -945,7 +945,7 @@ dependencies = [ [[package]] name = "dkn-p2p" -version = "0.2.21" +version = "0.2.22" dependencies = [ "env_logger 0.11.5", "eyre", @@ -957,7 +957,7 @@ dependencies = [ [[package]] name = "dkn-workflows" -version = "0.2.21" +version = "0.2.22" dependencies = [ "dotenvy", "env_logger 0.11.5", @@ -2165,9 +2165,9 @@ dependencies = [ [[package]] name = "if-watch" -version = "3.2.0" +version = "3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6b0422c86d7ce0e97169cc42e04ae643caf278874a7a3c87b8150a220dc7e1e" +checksum = "cdf9d64cfcf380606e64f9a0bcf493616b65331199f984151a6fa11a7b3cde38" dependencies = [ "async-io", "core-foundation", @@ -2176,8 +2176,12 @@ dependencies = [ "if-addrs", "ipnet", "log", + "netlink-packet-core", + "netlink-packet-route", + "netlink-proto", + "netlink-sys", "rtnetlink", - "system-configuration 0.5.1", + "system-configuration", "tokio 1.41.1", "windows", ] @@ -2841,7 +2845,7 @@ dependencies = [ "thiserror 1.0.69", "tracing", "yamux 0.12.1", - "yamux 0.13.3", + "yamux 0.13.4", ] [[package]] @@ -3171,21 +3175,20 @@ dependencies = [ [[package]] name = "netlink-packet-core" -version = "0.4.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345b8ab5bd4e71a2986663e88c56856699d060e78e152e6e9d7966fcd5491297" +checksum = "72724faf704479d67b388da142b186f916188505e7e0b26719019c525882eda4" dependencies = [ "anyhow", "byteorder", - "libc", "netlink-packet-utils", ] [[package]] name = "netlink-packet-route" -version = "0.12.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9ea4302b9759a7a88242299225ea3688e63c85ea136371bb6cf94fd674efaab" +checksum = "053998cea5a306971f88580d0829e90f270f940befd7cf928da179d4187a5a66" dependencies = [ "anyhow", "bitflags 1.3.2", @@ -3209,9 +3212,9 @@ dependencies = [ [[package]] name = "netlink-proto" -version = "0.10.0" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65b4b14489ab424703c092062176d52ba55485a89c076b4f9db05092b7223aa6" +checksum = "86b33524dc0968bfad349684447bfce6db937a9ac3332a1fe60c0c5a5ce63f21" dependencies = [ "bytes 1.8.0", "futures", @@ -3243,9 +3246,9 @@ checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" [[package]] name = "nix" -version = "0.24.3" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069" +checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" dependencies = [ "bitflags 1.3.2", "cfg-if 1.0.0", @@ -3356,7 +3359,7 @@ dependencies = [ [[package]] name = "ollama-workflows" version = "0.1.0" -source = "git+https://github.com/andthattoo/ollama-workflows#75ead48d237d1f408a82f20eef2cd350f217e592" +source = "git+https://github.com/andthattoo/ollama-workflows#12f622c1532ff167a4e11d8504f35f2f209b9312" dependencies = [ "async-trait", "base64 0.22.1", @@ -4116,7 +4119,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper", - "system-configuration 0.6.1", + "system-configuration", "tokio 1.41.1", "tokio-native-tls", "tokio-rustls", @@ -4190,14 +4193,17 @@ dependencies = [ [[package]] name = "rtnetlink" -version = "0.10.1" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "322c53fd76a18698f1c27381d58091de3a043d356aa5bd0d510608b565f469a0" +checksum = "7a552eb82d19f38c3beed3f786bd23aa434ceb9ac43ab44419ca6d67a7e186c0" dependencies = [ "futures", "log", + "netlink-packet-core", "netlink-packet-route", + "netlink-packet-utils", "netlink-proto", + "netlink-sys", "nix", "thiserror 1.0.69", "tokio 1.41.1", @@ -4235,9 +4241,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.40" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -4482,9 +4488,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa 1.0.11", "memchr", @@ -4798,17 +4804,6 @@ dependencies = [ "syn 2.0.87", ] -[[package]] -name = "system-configuration" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "system-configuration-sys 0.5.0", -] - [[package]] name = "system-configuration" version = "0.6.1" @@ -4817,17 +4812,7 @@ checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.6.0", "core-foundation", - "system-configuration-sys 0.6.0", -] - -[[package]] -name = "system-configuration-sys" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" -dependencies = [ - "core-foundation-sys", - "libc", + "system-configuration-sys", ] [[package]] @@ -5505,29 +5490,30 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.51.1" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca229916c5ee38c2f2bc1e9d8f04df975b4bd93f9955dc69fabb5d91270045c9" +checksum = "efc5cf48f83140dcaab716eeaea345f9e93d0018fb81162753a3f76c3397b538" dependencies = [ - "windows-core 0.51.1", - "windows-targets 0.48.5", + "windows-core 0.53.0", + "windows-targets 0.52.6", ] [[package]] name = "windows-core" -version = "0.51.1" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] name = "windows-core" -version = "0.52.0" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "9dcc5b895a6377f1ab9fa55acedab1fd5ac0db66ad1e6c7f47e28a22e446a5dd" dependencies = [ + "windows-result 0.1.2", "windows-targets 0.52.6", ] @@ -5537,11 +5523,20 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" dependencies = [ - "windows-result", + "windows-result 0.2.0", "windows-strings", "windows-targets 0.52.6", ] +[[package]] +name = "windows-result" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-result" version = "0.2.0" @@ -5557,7 +5552,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" dependencies = [ - "windows-result", + "windows-result 0.2.0", "windows-targets 0.52.6", ] @@ -5835,9 +5830,9 @@ dependencies = [ [[package]] name = "yamux" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31b5e376a8b012bee9c423acdbb835fc34d45001cfa3106236a624e4b738028" +checksum = "17610762a1207ee816c6fadc29220904753648aba0a9ed61c7b8336e80a559c4" dependencies = [ "futures", "log", diff --git a/Cargo.toml b/Cargo.toml index feb46d3..0a5e655 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ default-members = ["compute"] [workspace.package] edition = "2021" -version = "0.2.21" +version = "0.2.22" license = "Apache-2.0" readme = "README.md" diff --git a/Dockerfile b/Dockerfile index d7c5265..6621725 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=$BUILDPLATFORM rust:1.75 as builder +FROM --platform=$BUILDPLATFORM rust:1.75 AS builder # https://docs.docker.com/engine/reference/builder/#automatic-platform-args-in-the-global-scope # diff --git a/compute/src/utils/message.rs b/compute/src/utils/message.rs index f05f026..5781ac6 100644 --- a/compute/src/utils/message.rs +++ b/compute/src/utils/message.rs @@ -23,6 +23,9 @@ pub struct DKNMessage { /// /// NOTE: This can be obtained via Identify protocol version pub(crate) version: String, + /// Identity protocol string of the Dria Compute Node + #[serde(default)] + pub(crate) identity: String, /// The timestamp of the message, in nanoseconds /// /// NOTE: This can be obtained via DataTransform in GossipSub @@ -46,6 +49,9 @@ impl DKNMessage { payload: BASE64_STANDARD.encode(data), topic: topic.to_string(), version: DRIA_COMPUTE_NODE_VERSION.to_string(), + identity: dkn_p2p::P2P_IDENTITY_PREFIX + .trim_end_matches('/') + .to_string(), timestamp: get_current_time_nanos(), } } diff --git a/p2p/src/client.rs b/p2p/src/client.rs index 8cc6740..7c8b32f 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -254,6 +254,7 @@ impl DriaP2PClient { log::warn!("Local node is listening on {}", address); } SwarmEvent::ExternalAddrConfirmed { address } => { + // this is usually the external address via relay log::info!("External address confirmed: {}", address); } event => log::trace!("Unhandled Swarm Event: {:?}", event), @@ -275,6 +276,13 @@ impl DriaP2PClient { info.protocol_version, self.identity_protocol ); + + // blacklist peers with different protocol + self.swarm + .behaviour_mut() + .gossipsub + .blacklist_peer(&peer_id); + return; } diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index 97eacec..410edd1 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -7,10 +7,10 @@ mod client; pub use client::DriaP2PClient; /// Prefix for Kademlia protocol, must start with `/`! -pub(crate) const P2P_KADEMLIA_PREFIX: &str = "/dria/kad/"; +pub const P2P_KADEMLIA_PREFIX: &str = "/dria/kad/"; /// Prefix for Identity protocol string. -pub(crate) const P2P_IDENTITY_PREFIX: &str = "dria/"; +pub const P2P_IDENTITY_PREFIX: &str = "dria/"; // re-exports pub use libp2p; diff --git a/workflows/src/config.rs b/workflows/src/config.rs index da528d1..5a4f199 100644 --- a/workflows/src/config.rs +++ b/workflows/src/config.rs @@ -1,6 +1,6 @@ use crate::{ apis::{JinaConfig, SerperConfig}, - providers::{GeminiConfig, OllamaConfig, OpenAIConfig}, + providers::{GeminiConfig, OllamaConfig, OpenAIConfig, OpenRouterConfig}, split_csv_line, Model, ModelProvider, }; use eyre::{eyre, Result}; @@ -19,6 +19,9 @@ pub struct DriaWorkflowsConfig { /// Gemini configurations, e.g. API key, in case Gemini is used. /// Otherwise, can be ignored. pub gemini: GeminiConfig, + /// OpenRouter configurations, e.g. API key, in case OpenRouter is used. + /// Otherwise, can be ignored. + pub openrouter: OpenRouterConfig, /// Serper configurations, e.g. API key, in case Serper is given in environment. /// Otherwise, can be ignored. pub serper: SerperConfig, @@ -39,6 +42,7 @@ impl DriaWorkflowsConfig { models: models_and_providers, ollama: OllamaConfig::new(), openai: OpenAIConfig::new(), + openrouter: OpenRouterConfig::new(), gemini: GeminiConfig::new(), serper: SerperConfig::new(), jina: JinaConfig::new(), @@ -230,6 +234,18 @@ impl DriaWorkflowsConfig { ); } + // if OpenRouter is a provider, check that the API key is set + if unique_providers.contains(&ModelProvider::OpenRouter) { + let provider_models = self.get_models_for_provider(ModelProvider::OpenRouter); + good_models.extend( + self.openrouter + .check(provider_models) + .await? + .into_iter() + .map(|m| (ModelProvider::OpenRouter, m)), + ); + } + // update good models if good_models.is_empty() { Err(eyre!("No good models found, please check logs for errors.")) diff --git a/workflows/src/providers/mod.rs b/workflows/src/providers/mod.rs index 46fa08a..809f41d 100644 --- a/workflows/src/providers/mod.rs +++ b/workflows/src/providers/mod.rs @@ -6,3 +6,6 @@ pub use openai::OpenAIConfig; mod gemini; pub use gemini::GeminiConfig; + +mod openrouter; +pub use openrouter::OpenRouterConfig; diff --git a/workflows/src/providers/openrouter.rs b/workflows/src/providers/openrouter.rs new file mode 100644 index 0000000..c3e755d --- /dev/null +++ b/workflows/src/providers/openrouter.rs @@ -0,0 +1,140 @@ +use eyre::{eyre, Context, Result}; +use ollama_workflows::Model; +use reqwest::Client; +use std::env; + +use crate::utils::safe_read_env; + +const ENV_VAR_NAME: &str = "OPENROUTER_API_KEY"; + +/// OpenRouter-specific configurations. +#[derive(Debug, Clone, Default)] +pub struct OpenRouterConfig { + /// API key, if available. + api_key: Option, +} + +impl OpenRouterConfig { + /// Looks at the environment variables for OpenRouter API key. + pub fn new() -> Self { + Self { + api_key: safe_read_env(env::var(ENV_VAR_NAME)), + } + } + + /// Sets the API key for OpenRouter. + pub fn with_api_key(mut self, api_key: String) -> Self { + self.api_key = Some(api_key); + self + } + + /// Checks if the API key exists. + pub async fn check(&self, external_models: Vec) -> Result> { + log::info!("Checking OpenRouter API key"); + + // check API key + let Some(api_key) = &self.api_key else { + return Err(eyre!("OpenRouter API key not found")); + }; + + // make a dummy request with existing models + let mut available_models = Vec::new(); + for requested_model in external_models { + // make a dummy request + if let Err(err) = self.dummy_request(api_key.as_str(), &requested_model).await { + log::warn!( + "Model {} failed dummy request, ignoring it: {}", + requested_model, + err + ); + continue; + } + + available_models.push(requested_model) + } + + // log results + if available_models.is_empty() { + log::warn!("OpenRouter checks are finished, no available models found.",); + } else { + log::info!( + "OpenRouter checks are finished, using models: {:#?}", + available_models + ); + } + + Ok(available_models) + } + + /// Makes a dummy request to the OpenRouter API to check if the model is available & has credits. + async fn dummy_request(&self, api_key: &str, model: &Model) -> Result<()> { + log::debug!("Making a dummy request with: {}", model); + let client = Client::new(); + let request = client + .post("https://openrouter.ai/api/v1/chat/completions") + .header("Authorization", format!("Bearer {}", api_key)) + .header("Content-Type", "application/json") + .body( + serde_json::json!({ + "model": model.to_string(), + "messages": [ + { + "role": "user", + "content": "What is 2+2?" + } + ] + }) + .to_string(), + ) + .build() + .wrap_err("failed to build request")?; + + let response = client + .execute(request) + .await + .wrap_err("failed to send request")?; + + // ensure response is ok + if !response.status().is_success() { + return Err(eyre!( + "Failed to make OpenRouter chat request:\n{}", + response + .text() + .await + .unwrap_or("Could not get error text as well".to_string()) + )); + } + log::debug!("Dummy request successful for model {}", model); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[ignore = "requires OpenRouter API key"] + async fn test_openai_check() { + let _ = dotenvy::dotenv(); // read api key + assert!(env::var(ENV_VAR_NAME).is_ok(), "should have api key"); + env::set_var("RUST_LOG", "none,dkn_workflows=debug"); + let _ = env_logger::try_init(); + + let models = vec![Model::GPT4Turbo, Model::GPT4o, Model::GPT4oMini]; + let config = OpenRouterConfig::new(); + let res = config.check(models.clone()).await; + assert_eq!(res.unwrap(), models); + + env::set_var(ENV_VAR_NAME, "i-dont-work"); + let config = OpenRouterConfig::new(); + let res = config.check(vec![]).await; + assert!(res.is_err()); + + env::remove_var(ENV_VAR_NAME); + let config = OpenRouterConfig::new(); + let res = config.check(vec![]).await; + assert!(res.is_err()); + } +}