From d18710ce75f3dc459df969bb1988073bb0c1706c Mon Sep 17 00:00:00 2001 From: erhant Date: Fri, 18 Oct 2024 15:51:15 +0300 Subject: [PATCH] feat: Gemini (#135) * gemini setup for api check * docs and tests, todo small fix * fixed api checks, better tests * bump version --- .env.example | 3 + Cargo.lock | 221 +++++++++++++++++++++++++----- Cargo.toml | 2 +- workflows/README.md | 9 +- workflows/src/apis/jina.rs | 3 +- workflows/src/apis/serper.rs | 3 +- workflows/src/config.rs | 39 ++++-- workflows/src/providers/gemini.rs | 139 +++++++++++++++++++ workflows/src/providers/mod.rs | 3 + workflows/src/providers/openai.rs | 32 +++-- 10 files changed, 387 insertions(+), 67 deletions(-) create mode 100644 workflows/src/providers/gemini.rs diff --git a/.env.example b/.env.example index 773c10a..475f51c 100644 --- a/.env.example +++ b/.env.example @@ -25,6 +25,9 @@ DKN_BOOTSTRAP_NODES= ## Open AI (if used, required) ## OPENAI_API_KEY= +## Gemini (if used, required) ## +GEMINI_API_KEY= + ## Ollama (if used, optional) ## # do not change this, it is used by Docker OLLAMA_HOST=http://host.docker.internal diff --git a/Cargo.lock b/Cargo.lock index b255acf..c17ef71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -80,6 +80,21 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.15" @@ -530,6 +545,12 @@ dependencies = [ "serde", ] +[[package]] +name = "cargo-husky" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad" + [[package]] name = "cc" version = "1.1.30" @@ -575,6 +596,20 @@ dependencies = [ "zeroize", ] +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets 0.52.6", +] + [[package]] name = "cipher" version = "0.4.4" @@ -980,7 +1015,7 @@ dependencies = [ [[package]] name = "dkn-compute" -version = "0.2.16" +version = "0.2.17" dependencies = [ "async-trait", "base64 0.22.1", @@ -1012,7 +1047,7 @@ dependencies = [ [[package]] name = "dkn-p2p" -version = "0.2.16" +version = "0.2.17" dependencies = [ "env_logger 0.11.5", "eyre", @@ -1024,7 +1059,7 @@ dependencies = [ [[package]] name = "dkn-workflows" -version = "0.2.16" +version = "0.2.17" dependencies = [ "dotenvy", "env_logger 0.11.5", @@ -1165,6 +1200,19 @@ dependencies = [ "termcolor", ] +[[package]] +name = "env_logger" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +dependencies = [ + "humantime", + "is-terminal", + "log", + "regex", + "termcolor", +] + [[package]] name = "env_logger" version = "0.11.5" @@ -1485,6 +1533,25 @@ dependencies = [ "byteorder", ] +[[package]] +name = "gem-rs" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e69a0d5679f63170c28494d7c1e955d0d179605fb88baab6db20ae0e4bcd85" +dependencies = [ + "base64 0.22.1", + "chrono", + "futures", + "log", + "pretty_env_logger", + "reqwest 0.12.8", + "reqwest-streams", + "serde", + "serde_json", + "sha256", + "tokio 1.40.0", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1952,9 +2019,9 @@ dependencies = [ [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes 1.7.2", "futures-channel", @@ -1976,9 +2043,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ "bytes 1.7.2", "futures-channel", @@ -2003,7 +2070,7 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-util", "rustls", "rustls-native-certs", @@ -2034,7 +2101,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes 1.7.2", - "hyper 0.14.30", + "hyper 0.14.31", "native-tls", "tokio 1.40.0", "tokio-native-tls", @@ -2048,7 +2115,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes 1.7.2", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-util", "native-tls", "tokio 1.40.0", @@ -2067,7 +2134,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.0", "pin-project-lite 0.2.14", "socket2 0.5.7", "tokio 1.40.0", @@ -2075,6 +2142,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core 0.52.0", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -2141,7 +2231,7 @@ dependencies = [ "bytes 1.7.2", "futures", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "log", "rand 0.8.5", "tokio 1.40.0", @@ -2220,6 +2310,17 @@ version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi 0.4.0", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -2318,9 +2419,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libp2p" @@ -3093,7 +3194,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-util", "log", "rand 0.8.5", @@ -3372,12 +3473,13 @@ dependencies = [ [[package]] name = "ollama-workflows" version = "0.1.0" -source = "git+https://github.com/andthattoo/ollama-workflows#2c70764b0f040a78e622811975bcb5c48e056341" +source = "git+https://github.com/andthattoo/ollama-workflows#f1639c9c1efe4454e98f291ce507f46ec6f1d5c8" dependencies = [ "async-trait", "colored", "dotenv", "env_logger 0.9.3", + "gem-rs", "html2text", "langchain-rust", "log", @@ -3426,9 +3528,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.66" +version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ "bitflags 2.6.0", "cfg-if 1.0.0", @@ -3467,9 +3569,9 @@ dependencies = [ [[package]] name = "openssl-sys" -version = "0.9.103" +version = "0.9.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" dependencies = [ "cc", "libc", @@ -3792,6 +3894,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "pretty_env_logger" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "865724d4dbe39d9f3dd3b52b88d859d66bcb2d6a0acfd5ea68a65fb66d4bdc1c" +dependencies = [ + "env_logger 0.10.2", + "log", +] + [[package]] name = "proc-macro-hack" version = "0.5.20+deprecated" @@ -3800,9 +3912,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9" dependencies = [ "unicode-ident", ] @@ -3832,9 +3944,9 @@ dependencies = [ [[package]] name = "pulldown-cmark" -version = "0.12.1" +version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "666f0f59e259aea2d72e6012290c09877a780935cc3c18b1ceded41f3890d59c" +checksum = "f86ba2052aebccc42cbbb3ed234b8b13ce76f75c3551a303cb2bcffcff12bb14" dependencies = [ "bitflags 2.6.0", "memchr", @@ -4120,7 +4232,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-tls 0.5.0", "ipnet", "js-sys", @@ -4161,7 +4273,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-rustls", "hyper-tls 0.6.0", "hyper-util", @@ -4214,6 +4326,23 @@ dependencies = [ "thiserror", ] +[[package]] +name = "reqwest-streams" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee84cc47a7a0ac7562173c8f421c058e4c72089d6e662f32e2cb4bcc8e6e9201" +dependencies = [ + "async-trait", + "bytes 1.7.2", + "cargo-husky", + "futures", + "reqwest 0.12.8", + "serde", + "serde_json", + "tokio 1.40.0", + "tokio-util 0.7.12", +] + [[package]] name = "resolv-conf" version = "0.7.0" @@ -4320,9 +4449,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.14" +version = "0.23.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" +checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" dependencies = [ "once_cell", "ring 0.17.8", @@ -4365,9 +4494,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" @@ -4599,9 +4728,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.129" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "6dbcf9b78a125ee667ae19388837dd12294b858d101fdd393cb9d5501ef09eb2" dependencies = [ "itoa 1.0.11", "memchr", @@ -4664,6 +4793,19 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha256" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18278f6a914fa3070aa316493f7d2ddfb9ac86ebc06fa3b83bffda487e9065b0" +dependencies = [ + "async-trait", + "bytes 1.7.2", + "hex", + "sha2 0.10.8", + "tokio 1.40.0", +] + [[package]] name = "sha3" version = "0.10.8" @@ -5425,9 +5567,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom 0.2.15", ] @@ -5643,7 +5785,7 @@ version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca229916c5ee38c2f2bc1e9d8f04df975b4bd93f9955dc69fabb5d91270045c9" dependencies = [ - "windows-core", + "windows-core 0.51.1", "windows-targets 0.48.5", ] @@ -5656,6 +5798,15 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-registry" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index 5a54e10..79cbd5b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ default-members = ["compute"] [workspace.package] edition = "2021" -version = "0.2.16" +version = "0.2.17" license = "Apache-2.0" readme = "README.md" diff --git a/workflows/README.md b/workflows/README.md index 7c873c2..03dd841 100644 --- a/workflows/README.md +++ b/workflows/README.md @@ -3,8 +3,10 @@ We make use of Ollama Workflows in DKN; however, we also want to make sure that the chosen models are valid and is performant enough (i.e. have enough TPS). This crate handles the configurations of models to be used, and implements various service checks. -- **OpenAI**: We check that the chosen models are enabled for the user's profile by fetching their models with their API key. We filter out the disabled models. -- **Ollama**: We provide a sample workflow to measure TPS and then pick models that are above some TPS threshold. While calculating TPS, there is also a timeout so that beyond that timeout the TPS is not even considered and the model becomes invalid. +There are two types of services: + +- [`providers`](./src/providers/): these provide models that are directly used as `Model` enums in Workflows; they are only checked if a model that belongs to them is used. +- [`apis`](./src/apis/): these provide additional services used by workflows; they are only checked if their API key exists. ## Installation @@ -20,10 +22,11 @@ Note that the underlying [Ollama Workflows](https://github.com/andthattoo/ollama DKN Workflows make use of several environment variables, respecting the providers. -- `OPENAI_API_KEY` is used for OpenAI requests - `OLLAMA_HOST` is used to connect to Ollama server - `OLLAMA_PORT` is used to connect to Ollama server - `OLLAMA_AUTO_PULL` indicates whether we should pull missing models automatically or not +- `OPENAI_API_KEY` is used for OpenAI requests +- `GEMINI_API_KEY` is used for Gemini requests - `SERPER_API_KEY` is optional API key to use **Serper**, for better Workflow executions - `JINA_API_KEY` is optional API key to use **Jina**, for better Workflow executions diff --git a/workflows/src/apis/jina.rs b/workflows/src/apis/jina.rs index 3ff0a71..362f0a4 100644 --- a/workflows/src/apis/jina.rs +++ b/workflows/src/apis/jina.rs @@ -79,7 +79,8 @@ mod tests { #[ignore = "requires Jina API key"] async fn test_jina_check() { let _ = dotenvy::dotenv(); - assert!(env::var(ENV_VAR_NAME).is_ok()); + assert!(env::var(ENV_VAR_NAME).is_ok(), "should have api key"); + let res = JinaConfig::new().check_optional().await; assert!(res.is_ok(), "should pass with api key"); diff --git a/workflows/src/apis/serper.rs b/workflows/src/apis/serper.rs index 0d916ec..0612160 100644 --- a/workflows/src/apis/serper.rs +++ b/workflows/src/apis/serper.rs @@ -86,7 +86,8 @@ mod tests { #[ignore = "requires Serper API key"] async fn test_serper_check() { let _ = dotenvy::dotenv(); - assert!(env::var(ENV_VAR_NAME).is_ok()); + assert!(env::var(ENV_VAR_NAME).is_ok(), "should have api key"); + let res = SerperConfig::new().check_optional().await; assert!(res.is_ok(), "should pass with api key"); diff --git a/workflows/src/config.rs b/workflows/src/config.rs index 2124088..da528d1 100644 --- a/workflows/src/config.rs +++ b/workflows/src/config.rs @@ -1,6 +1,6 @@ use crate::{ apis::{JinaConfig, SerperConfig}, - providers::{OllamaConfig, OpenAIConfig}, + providers::{GeminiConfig, OllamaConfig, OpenAIConfig}, split_csv_line, Model, ModelProvider, }; use eyre::{eyre, Result}; @@ -16,6 +16,9 @@ pub struct DriaWorkflowsConfig { /// OpenAI configurations, e.g. API key, in case OpenAI is used. /// Otherwise, can be ignored. pub openai: OpenAIConfig, + /// Gemini configurations, e.g. API key, in case Gemini is used. + /// Otherwise, can be ignored. + pub gemini: GeminiConfig, /// Serper configurations, e.g. API key, in case Serper is given in environment. /// Otherwise, can be ignored. pub serper: SerperConfig, @@ -34,8 +37,9 @@ impl DriaWorkflowsConfig { Self { models: models_and_providers, - openai: OpenAIConfig::new(), ollama: OllamaConfig::new(), + openai: OpenAIConfig::new(), + gemini: GeminiConfig::new(), serper: SerperConfig::new(), jina: JinaConfig::new(), } @@ -192,29 +196,40 @@ impl DriaWorkflowsConfig { // if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them) if unique_providers.contains(&ModelProvider::Ollama) { - let ollama_models = self.get_models_for_provider(ModelProvider::Ollama); - - // ensure that the models are pulled / pull them if not - let good_ollama_models = self.ollama.check(ollama_models).await?; + let provider_models = self.get_models_for_provider(ModelProvider::Ollama); good_models.extend( - good_ollama_models + self.ollama + .check(provider_models) + .await? .into_iter() .map(|m| (ModelProvider::Ollama, m)), ); } - // if OpenAI is a provider, check that the API key is set + // if OpenAI is a provider, check that the API key is set & models are available if unique_providers.contains(&ModelProvider::OpenAI) { - let openai_models = self.get_models_for_provider(ModelProvider::OpenAI); - - let good_openai_models = self.openai.check(openai_models).await?; + let provider_models = self.get_models_for_provider(ModelProvider::OpenAI); good_models.extend( - good_openai_models + self.openai + .check(provider_models) + .await? .into_iter() .map(|m| (ModelProvider::OpenAI, m)), ); } + // if Gemini is a provider, check that the API key is set & models are available + if unique_providers.contains(&ModelProvider::Gemini) { + let provider_models = self.get_models_for_provider(ModelProvider::Gemini); + good_models.extend( + self.gemini + .check(provider_models) + .await? + .into_iter() + .map(|m| (ModelProvider::Gemini, m)), + ); + } + // update good models if good_models.is_empty() { Err(eyre!("No good models found, please check logs for errors.")) diff --git a/workflows/src/providers/gemini.rs b/workflows/src/providers/gemini.rs new file mode 100644 index 0000000..de7be78 --- /dev/null +++ b/workflows/src/providers/gemini.rs @@ -0,0 +1,139 @@ +use eyre::{eyre, Context, Result}; +use ollama_workflows::Model; +use reqwest::Client; +use serde::Deserialize; +use std::env; + +use crate::utils::safe_read_env; + +/// [`models.list`](https://ai.google.dev/api/models#method:-models.list) endpoint +const GEMINI_MODELS_API: &str = "https://generativelanguage.googleapis.com/v1beta/models"; +const ENV_VAR_NAME: &str = "GEMINI_API_KEY"; + +/// [Model](https://ai.google.dev/api/models#Model) API object, fields omitted. +#[derive(Debug, Clone, Deserialize)] +#[allow(non_snake_case)] +#[allow(unused)] +struct GeminiModel { + name: String, + version: String, + // other fields are ignored from API response +} + +#[derive(Debug, Clone, Deserialize)] +#[allow(non_snake_case)] +#[allow(unused)] +struct GeminiModelsResponse { + models: Vec, +} + +/// OpenAI-specific configurations. +#[derive(Debug, Clone, Default)] +pub struct GeminiConfig { + /// API key, if available. + api_key: Option, +} + +impl GeminiConfig { + /// Looks at the environment variables for Gemini API key. + pub fn new() -> Self { + Self { + api_key: safe_read_env(env::var(ENV_VAR_NAME)), + } + } + + /// Sets the API key for Gemini. + pub fn with_api_key(mut self, api_key: String) -> Self { + self.api_key = Some(api_key); + self + } + + /// Check if requested models exist & are available in the OpenAI account. + pub async fn check(&self, models: Vec) -> Result> { + log::info!("Checking Gemini requirements"); + + // check API key + let Some(api_key) = &self.api_key else { + return Err(eyre!("Gemini API key not found")); + }; + + // fetch models + let client = Client::new(); + let request = client + .get(GEMINI_MODELS_API) + .query(&[("key", api_key)]) + .build() + .wrap_err("failed to build request")?; + + let response = client + .execute(request) + .await + .wrap_err("failed to send request")?; + + // parse response + if response.status().is_client_error() { + return Err(eyre!( + "Failed to fetch Gemini models:\n{}", + response.text().await.unwrap_or_default() + )); + } + let gemini_models = response.json::().await?; + + // check if models exist and select those that are available + let mut available_models = Vec::new(); + for requested_model in models { + if !gemini_models.models.iter().any(|gemini_model| { + // a gemini model name in API response is given as `models/{baseModelId}-{version}` + // the model name in Workflows can include the version as well, so best bet is to check prefix + // ignoring the `models/` part + gemini_model + .name + .trim_start_matches("models/") + .starts_with(&requested_model.to_string()) + }) { + log::warn!( + "Model {} not found in your Gemini account, ignoring it.", + requested_model + ); + } else { + available_models.push(requested_model); + } + } + + log::info!( + "Gemini checks are finished, using models: {:#?}", + available_models + ); + + Ok(available_models) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[ignore = "requires Gemini API key"] + async fn test_gemini_check() { + let _ = dotenvy::dotenv(); // read api key + assert!(env::var(ENV_VAR_NAME).is_ok(), "should have api key"); + + let models = vec![ + Model::Gemini10Pro, + Model::Gemini15ProExp0827, + Model::Gemini15Flash, + Model::Gemini15Pro, + ]; + let res = GeminiConfig::new().check(models.clone()).await; + assert_eq!(res.unwrap(), models); + + env::set_var(ENV_VAR_NAME, "i-dont-work"); + let res = GeminiConfig::new().check(vec![]).await; + assert!(res.is_err()); + + env::remove_var(ENV_VAR_NAME); + let res = GeminiConfig::new().check(vec![]).await; + assert!(res.is_err()); + } +} diff --git a/workflows/src/providers/mod.rs b/workflows/src/providers/mod.rs index a6ea768..46fa08a 100644 --- a/workflows/src/providers/mod.rs +++ b/workflows/src/providers/mod.rs @@ -3,3 +3,6 @@ pub use ollama::OllamaConfig; mod openai; pub use openai::OpenAIConfig; + +mod gemini; +pub use gemini::GeminiConfig; diff --git a/workflows/src/providers/openai.rs b/workflows/src/providers/openai.rs index c61390a..c72fe42 100644 --- a/workflows/src/providers/openai.rs +++ b/workflows/src/providers/openai.rs @@ -2,33 +2,23 @@ use eyre::{eyre, Context, Result}; use ollama_workflows::Model; use reqwest::Client; use serde::Deserialize; +use std::env; use crate::utils::safe_read_env; const OPENAI_MODELS_API: &str = "https://api.openai.com/v1/models"; const ENV_VAR_NAME: &str = "OPENAI_API_KEY"; -/// [Model](https://platform.openai.com/docs/api-reference/models/object) API object. +/// [Model](https://platform.openai.com/docs/api-reference/models/object) API object, fields omitted. #[derive(Debug, Clone, Deserialize)] struct OpenAIModel { /// The model identifier, which can be referenced in the API endpoints. id: String, - /// The Unix timestamp (in seconds) when the model was created. - #[allow(unused)] - created: u64, - /// The object type, which is always "model". - #[allow(unused)] - object: String, - /// The organization that owns the model. - #[allow(unused)] - owned_by: String, } #[derive(Debug, Clone, Deserialize)] struct OpenAIModelsResponse { data: Vec, - #[allow(unused)] - object: String, } /// OpenAI-specific configurations. @@ -42,7 +32,7 @@ impl OpenAIConfig { /// Looks at the environment variables for OpenAI API key. pub fn new() -> Self { Self { - api_key: safe_read_env(std::env::var(ENV_VAR_NAME)), + api_key: safe_read_env(env::var(ENV_VAR_NAME)), } } @@ -115,8 +105,22 @@ mod tests { #[tokio::test] #[ignore = "requires OpenAI API key"] async fn test_openai_check() { + let _ = dotenvy::dotenv(); // read api key + assert!(env::var(ENV_VAR_NAME).is_ok(), "should have api key"); + + let models = vec![Model::GPT4Turbo, Model::GPT4o, Model::GPT4oMini]; + let config = OpenAIConfig::new(); + let res = config.check(models.clone()).await; + assert_eq!(res.unwrap(), models); + + env::set_var(ENV_VAR_NAME, "i-dont-work"); + let config = OpenAIConfig::new(); + let res = config.check(vec![]).await; + assert!(res.is_err()); + + env::remove_var(ENV_VAR_NAME); let config = OpenAIConfig::new(); let res = config.check(vec![]).await; - println!("Result: {}", res.unwrap_err()); + assert!(res.is_err()); } }