From 8adcd7377be36db28841b85be6ebe7134a1c244d Mon Sep 17 00:00:00 2001
From: Stanislas Polu
Date: Wed, 22 Nov 2023 12:07:48 +0100
Subject: [PATCH] azure: Chat implementation (#2615)

* Azure Chat implementation

* Make azure work
---
 core/src/providers/azure_openai.rs            | 139 +++++++++++++++---
 front/lib/api/credentials.ts                  |   4 +
 front/lib/providers.ts                        |  18 +--
 front/pages/api/w/[wId]/data_sources/index.ts |   2 +-
 4 files changed, 121 insertions(+), 42 deletions(-)

diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs
index a4c158e029b7..0c85114e01fb 100644
--- a/core/src/providers/azure_openai.rs
+++ b/core/src/providers/azure_openai.rs
@@ -2,7 +2,9 @@ use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
 use crate::providers::embedder::{Embedder, EmbedderVector};
 use crate::providers::llm::Tokens;
 use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLM};
-use crate::providers::openai::{completion, embed, streamed_completion};
+use crate::providers::openai::{
+    chat_completion, completion, embed, streamed_chat_completion, streamed_completion,
+};
 use crate::providers::provider::{Provider, ProviderID};
 use crate::providers::tiktoken::tiktoken::{
     cl100k_base_singleton, p50k_base_singleton, r50k_base_singleton, CoreBPE,
@@ -160,7 +162,7 @@ impl AzureOpenAILLM {
         assert!(self.endpoint.is_some());
 
         Ok(format!(
-            "{}openai/deployments/{}/completions?api-version=2022-12-01",
+            "{}openai/deployments/{}/completions?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -170,7 +172,7 @@ impl AzureOpenAILLM {
     #[allow(dead_code)]
     fn chat_uri(&self) -> Result<Uri> {
         Ok(format!(
-            "{}openai/deployments/{}/chat/completions?api-version=2023-03-15-preview",
+            "{}openai/deployments/{}/chat/completions?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -430,7 +432,7 @@ impl LLM for AzureOpenAILLM {
 
         Ok(LLMGeneration {
             created: utils::now(),
-            provider: ProviderID::OpenAI.to_string(),
+            provider: ProviderID::AzureOpenAI.to_string(),
             model: self.model_id.clone().unwrap(),
             completions: c
                 .choices
@@ -462,22 +464,113 @@ impl LLM for AzureOpenAILLM {
 
     async fn chat(
         &self,
-        _messages: &Vec<ChatMessage>,
-        _functions: &Vec<ChatFunction>,
-        _function_call: Option<String>,
-        _temperature: f32,
-        _top_p: Option<f32>,
-        _n: usize,
-        _stop: &Vec<String>,
-        _max_tokens: Option<i32>,
-        _presence_penalty: Option<f32>,
-        _frequency_penalty: Option<f32>,
-        _extras: Option<Value>,
-        _event_sender: Option<UnboundedSender<Value>>,
+        messages: &Vec<ChatMessage>,
+        functions: &Vec<ChatFunction>,
+        function_call: Option<String>,
+        temperature: f32,
+        top_p: Option<f32>,
+        n: usize,
+        stop: &Vec<String>,
+        mut max_tokens: Option<i32>,
+        presence_penalty: Option<f32>,
+        frequency_penalty: Option<f32>,
+        extras: Option<Value>,
+        event_sender: Option<UnboundedSender<Value>>,
     ) -> Result<LLMChatGeneration> {
-        Err(anyhow!(
-            "Chat capabilties are not implemented for provider `azure_openai`"
-        ))
+        if let Some(m) = max_tokens {
+            if m == -1 {
+                max_tokens = None;
+            }
+        }
+
+        let c = match event_sender {
+            Some(_) => {
+                streamed_chat_completion(
+                    self.chat_uri()?,
+                    self.api_key.clone().unwrap(),
+                    None,
+                    None,
+                    messages,
+                    functions,
+                    function_call,
+                    temperature,
+                    match top_p {
+                        Some(t) => t,
+                        None => 1.0,
+                    },
+                    n,
+                    stop,
+                    max_tokens,
+                    match presence_penalty {
+                        Some(p) => p,
+                        None => 0.0,
+                    },
+                    match frequency_penalty {
+                        Some(f) => f,
+                        None => 0.0,
+                    },
+                    match &extras {
+                        Some(e) => match e.get("openai_user") {
+                            Some(Value::String(u)) => Some(u.to_string()),
+                            _ => None,
+                        },
+                        None => None,
+                    },
+                    event_sender,
+                )
+                .await?
+            }
+            None => {
+                chat_completion(
+                    self.chat_uri()?,
+                    self.api_key.clone().unwrap(),
+                    None,
+                    None,
+                    messages,
+                    functions,
+                    function_call,
+                    temperature,
+                    match top_p {
+                        Some(t) => t,
+                        None => 1.0,
+                    },
+                    n,
+                    stop,
+                    max_tokens,
+                    match presence_penalty {
+                        Some(p) => p,
+                        None => 0.0,
+                    },
+                    match frequency_penalty {
+                        Some(f) => f,
+                        None => 0.0,
+                    },
+                    match &extras {
+                        Some(e) => match e.get("openai_user") {
+                            Some(Value::String(u)) => Some(u.to_string()),
+                            _ => None,
+                        },
+                        None => None,
+                    },
+                )
+                .await?
+            }
+        };
+
+        // println!("COMPLETION: {:?}", c);
+
+        assert!(c.choices.len() > 0);
+
+        Ok(LLMChatGeneration {
+            created: utils::now(),
+            provider: ProviderID::AzureOpenAI.to_string(),
+            model: self.model_id.clone().unwrap(),
+            completions: c
+                .choices
+                .iter()
+                .map(|c| c.message.clone())
+                .collect::<Vec<_>>(),
+        })
     }
 }
 
@@ -502,7 +595,7 @@ impl AzureOpenAIEmbedder {
         assert!(self.endpoint.is_some());
 
         Ok(format!(
-            "{}openai/deployments/{}/embeddings?api-version=2022-12-01",
+            "{}openai/deployments/{}/embeddings?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -597,13 +690,11 @@ impl Embedder for AzureOpenAIEmbedder {
     }
 
     async fn encode(&self, text: &str) -> Result<Vec<usize>> {
-        let tokens = { self.tokenizer().lock().encode_with_special_tokens(text) };
-        Ok(tokens)
+        encode_async(self.tokenizer(), text).await
     }
 
     async fn decode(&self, tokens: Vec<usize>) -> Result<String> {
-        let str = { self.tokenizer().lock().decode(tokens)? };
-        Ok(str)
+        decode_async(self.tokenizer(), tokens).await
     }
 
     async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
diff --git a/front/lib/api/credentials.ts b/front/lib/api/credentials.ts
index f50f3cdc2632..ba1c56eba980 100644
--- a/front/lib/api/credentials.ts
+++ b/front/lib/api/credentials.ts
@@ -4,6 +4,8 @@ const {
   DUST_MANAGED_OPENAI_API_KEY = "",
   DUST_MANAGED_ANTHROPIC_API_KEY = "",
   DUST_MANAGED_TEXTSYNTH_API_KEY = "",
+  DUST_MANAGED_AZURE_OPENAI_API_KEY = "",
+  DUST_MANAGED_AZURE_OPENAI_ENDPOINT = "",
 } = process.env;
 
 export const credentialsFromProviders = (
@@ -56,5 +58,7 @@ export const dustManagedCredentials = (): CredentialsType => {
     OPENAI_API_KEY: DUST_MANAGED_OPENAI_API_KEY,
     ANTHROPIC_API_KEY: DUST_MANAGED_ANTHROPIC_API_KEY,
     TEXTSYNTH_API_KEY: DUST_MANAGED_TEXTSYNTH_API_KEY,
+    AZURE_OPENAI_API_KEY: DUST_MANAGED_AZURE_OPENAI_API_KEY,
+    AZURE_OPENAI_ENDPOINT: DUST_MANAGED_AZURE_OPENAI_ENDPOINT,
   };
 };
diff --git a/front/lib/providers.ts b/front/lib/providers.ts
index 2cd04279c2b0..8ce905618f62 100644
--- a/front/lib/providers.ts
+++ b/front/lib/providers.ts
@@ -42,7 +42,7 @@ export const modelProviders: ModelProvider[] = [
     name: "Azure OpenAI",
     built: true,
     enabled: false,
-    chat: false,
+    chat: true,
     embed: true,
   },
   {
@@ -61,22 +61,6 @@ export const modelProviders: ModelProvider[] = [
     chat: true,
     embed: false,
   },
-  {
-    providerId: "hugging_face",
-    name: "Hugging Face",
-    built: false,
-    enabled: false,
-    chat: false,
-    embed: false,
-  },
-  {
-    providerId: "replicate",
-    name: "Replicate",
-    built: false,
-    enabled: false,
-    chat: false,
-    embed: false,
-  },
 ];
 
 type ServiceProvider = {
diff --git a/front/pages/api/w/[wId]/data_sources/index.ts b/front/pages/api/w/[wId]/data_sources/index.ts
index 70aaa7e2fddd..00e627dd49b0 100644
--- a/front/pages/api/w/[wId]/data_sources/index.ts
+++ b/front/pages/api/w/[wId]/data_sources/index.ts
@@ -60,7 +60,7 @@ async function handler(
       api_error: {
         type: "data_source_auth_error",
         message:
-          "Only the users that are `admins` for the current workspace can create a managed data source.",
+          "Only the users that are `admins` for the current workspace can create a data source.",
       },
     });
   }
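
A note on the endpoint changes: the patch unifies all three Azure URIs (completions, chat/completions, embeddings) on a single api-version, 2023-08-01-preview, replacing the mix of 2022-12-01 and 2023-03-15-preview. As a minimal standalone sketch of the URI shape built by chat_uri() above (the endpoint and deployment values in the usage comment are placeholders, not from this patch):

    // Azure OpenAI routes by deployment id rather than model id, and the
    // api-version query parameter selects the API surface; this mirrors
    // chat_uri() in the patch.
    fn azure_chat_uri(endpoint: &str, deployment_id: &str) -> String {
        format!(
            "{}openai/deployments/{}/chat/completions?api-version=2023-08-01-preview",
            endpoint, deployment_id
        )
    }

    // azure_chat_uri("https://my-resource.openai.azure.com/", "gpt-4")
    // => "https://my-resource.openai.azure.com/openai/deployments/gpt-4/
    //     chat/completions?api-version=2023-08-01-preview"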
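
The new chat() body normalizes its optional sampling parameters before delegating to the shared chat_completion / streamed_chat_completion helpers: max_tokens == -1 is mapped to None (no explicit limit), top_p falls back to 1.0, and both penalties fall back to 0.0. A compact sketch equivalent to the match expressions in the patch (the function name is illustrative, not part of the commit):

    fn normalize_chat_params(
        max_tokens: Option<i32>,
        top_p: Option<f32>,
        presence_penalty: Option<f32>,
        frequency_penalty: Option<f32>,
    ) -> (Option<i32>, f32, f32, f32) {
        // -1 is the caller-facing sentinel for "no token limit".
        let max_tokens = match max_tokens {
            Some(-1) => None,
            m => m,
        };
        (
            max_tokens,
            top_p.unwrap_or(1.0),
            presence_penalty.unwrap_or(0.0),
            frequency_penalty.unwrap_or(0.0),
        )
    }

The only other branch point is event_sender: when a sender is supplied the streamed variant is called and the sender forwarded, otherwise the blocking variant runs; both feed the same LLMChatGeneration construction.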
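
Separately, the embedder's encode() and decode() drop their inline tokenizer locking in favor of the encode_async / decode_async helpers already imported at the top of the file. Assuming those helpers follow the usual tokio pattern of moving the CPU-bound, mutex-guarded BPE work onto a blocking thread (a sketch under that assumption, not the helpers' actual bodies; Tokenizer is a hypothetical stand-in for CoreBPE):

    use std::sync::{Arc, Mutex};

    use anyhow::Result;
    use tokio::task;

    struct Tokenizer;

    impl Tokenizer {
        fn encode(&self, text: &str) -> Vec<usize> {
            // Toy "tokenizer": one token per byte, for illustration only.
            text.bytes().map(|b| b as usize).collect()
        }
    }

    async fn encode_async(bpe: Arc<Mutex<Tokenizer>>, text: &str) -> Result<Vec<usize>> {
        let text = text.to_string();
        // spawn_blocking keeps the lock-holding, CPU-bound call from
        // stalling the async runtime's worker threads.
        Ok(task::spawn_blocking(move || bpe.lock().unwrap().encode(&text)).await?)
    }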