azure: Chat implementation (#2615)
* Azure Chat implementation

* Make azure work
spolu authored Nov 22, 2023
1 parent bb455e1 commit 8adcd73
Showing 4 changed files with 121 additions and 42 deletions.
139 changes: 115 additions & 24 deletions core/src/providers/azure_openai.rs
@@ -2,7 +2,9 @@ use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
 use crate::providers::embedder::{Embedder, EmbedderVector};
 use crate::providers::llm::Tokens;
 use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLM};
-use crate::providers::openai::{completion, embed, streamed_completion};
+use crate::providers::openai::{
+    chat_completion, completion, embed, streamed_chat_completion, streamed_completion,
+};
 use crate::providers::provider::{Provider, ProviderID};
 use crate::providers::tiktoken::tiktoken::{
     cl100k_base_singleton, p50k_base_singleton, r50k_base_singleton, CoreBPE,
@@ -160,7 +162,7 @@ impl AzureOpenAILLM {
         assert!(self.endpoint.is_some());
 
         Ok(format!(
-            "{}openai/deployments/{}/completions?api-version=2022-12-01",
+            "{}openai/deployments/{}/completions?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -170,7 +172,7 @@
     #[allow(dead_code)]
     fn chat_uri(&self) -> Result<Uri> {
         Ok(format!(
-            "{}openai/deployments/{}/chat/completions?api-version=2023-03-15-preview",
+            "{}openai/deployments/{}/chat/completions?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
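
For reference, a minimal sketch of the URL shape the updated `chat_uri` builds; the endpoint and deployment id below are placeholder values, not values from this commit:

```rust
// Minimal sketch of the Azure OpenAI chat completions URL built above.
// The endpoint and deployment id are illustrative placeholders.
fn chat_uri(endpoint: &str, deployment_id: &str) -> String {
    format!(
        "{}openai/deployments/{}/chat/completions?api-version=2023-08-01-preview",
        endpoint, deployment_id
    )
}

fn main() {
    // Prints e.g.:
    // https://my-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-08-01-preview
    println!(
        "{}",
        chat_uri("https://my-resource.openai.azure.com/", "gpt-35-turbo")
    );
}
```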
@@ -430,7 +432,7 @@ impl LLM for AzureOpenAILLM {
 
         Ok(LLMGeneration {
             created: utils::now(),
-            provider: ProviderID::OpenAI.to_string(),
+            provider: ProviderID::AzureOpenAI.to_string(),
             model: self.model_id.clone().unwrap(),
             completions: c
                 .choices
@@ -462,22 +464,113 @@
 
     async fn chat(
         &self,
-        _messages: &Vec<ChatMessage>,
-        _functions: &Vec<ChatFunction>,
-        _function_call: Option<String>,
-        _temperature: f32,
-        _top_p: Option<f32>,
-        _n: usize,
-        _stop: &Vec<String>,
-        _max_tokens: Option<i32>,
-        _presence_penalty: Option<f32>,
-        _frequency_penalty: Option<f32>,
-        _extras: Option<Value>,
-        _event_sender: Option<UnboundedSender<Value>>,
+        messages: &Vec<ChatMessage>,
+        functions: &Vec<ChatFunction>,
+        function_call: Option<String>,
+        temperature: f32,
+        top_p: Option<f32>,
+        n: usize,
+        stop: &Vec<String>,
+        mut max_tokens: Option<i32>,
+        presence_penalty: Option<f32>,
+        frequency_penalty: Option<f32>,
+        extras: Option<Value>,
+        event_sender: Option<UnboundedSender<Value>>,
     ) -> Result<LLMChatGeneration> {
-        Err(anyhow!(
-            "Chat capabilties are not implemented for provider `azure_openai`"
-        ))
+        if let Some(m) = max_tokens {
+            if m == -1 {
+                max_tokens = None;
+            }
+        }
+
+        let c = match event_sender {
+            Some(_) => {
+                streamed_chat_completion(
+                    self.chat_uri()?,
+                    self.api_key.clone().unwrap(),
+                    None,
+                    None,
+                    messages,
+                    functions,
+                    function_call,
+                    temperature,
+                    match top_p {
+                        Some(t) => t,
+                        None => 1.0,
+                    },
+                    n,
+                    stop,
+                    max_tokens,
+                    match presence_penalty {
+                        Some(p) => p,
+                        None => 0.0,
+                    },
+                    match frequency_penalty {
+                        Some(f) => f,
+                        None => 0.0,
+                    },
+                    match &extras {
+                        Some(e) => match e.get("openai_user") {
+                            Some(Value::String(u)) => Some(u.to_string()),
+                            _ => None,
+                        },
+                        None => None,
+                    },
+                    event_sender,
+                )
+                .await?
+            }
+            None => {
+                chat_completion(
+                    self.chat_uri()?,
+                    self.api_key.clone().unwrap(),
+                    None,
+                    None,
+                    messages,
+                    functions,
+                    function_call,
+                    temperature,
+                    match top_p {
+                        Some(t) => t,
+                        None => 1.0,
+                    },
+                    n,
+                    stop,
+                    max_tokens,
+                    match presence_penalty {
+                        Some(p) => p,
+                        None => 0.0,
+                    },
+                    match frequency_penalty {
+                        Some(f) => f,
+                        None => 0.0,
+                    },
+                    match &extras {
+                        Some(e) => match e.get("openai_user") {
+                            Some(Value::String(u)) => Some(u.to_string()),
+                            _ => None,
+                        },
+                        None => None,
+                    },
+                )
+                .await?
+            }
+        };
+
+        // println!("COMPLETION: {:?}", c);
+
+        assert!(c.choices.len() > 0);
+
+        Ok(LLMChatGeneration {
+            created: utils::now(),
+            provider: ProviderID::AzureOpenAI.to_string(),
+            model: self.model_id.clone().unwrap(),
+            completions: c
+                .choices
+                .iter()
+                .map(|c| c.message.clone())
+                .collect::<Vec<_>>(),
+        })
     }
 }
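
The dispatch on `event_sender` above follows the same pattern as the OpenAI provider: stream when a channel is supplied, block otherwise. A self-contained toy sketch of that pattern, with stand-in types rather than the crate's actual API:

```rust
// Toy sketch of the stream-vs-block dispatch used in `chat` above: when the
// caller supplies an UnboundedSender, incremental events are forwarded on it;
// otherwise a single blocking response is produced. All names here are
// illustrative stand-ins, not the crate's real functions or types.
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};

async fn generate(event_sender: Option<UnboundedSender<String>>) -> String {
    let chunks = ["Hello", ", ", "world"];
    if let Some(tx) = event_sender {
        // Streamed path: emit each chunk as it is "produced".
        for c in chunks {
            let _ = tx.send(c.to_string()); // ignore a dropped receiver
        }
    } // Blocking path: no incremental events.
    chunks.concat() // Both paths still return the full completion.
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = unbounded_channel();
    let full = generate(Some(tx)).await;
    // The sender was dropped inside `generate`, so this loop terminates.
    while let Some(chunk) = rx.recv().await {
        print!("{chunk}");
    }
    println!(" | full: {full}");
}
```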

@@ -502,7 +595,7 @@ impl AzureOpenAIEmbedder {
         assert!(self.endpoint.is_some());
 
         Ok(format!(
-            "{}openai/deployments/{}/embeddings?api-version=2022-12-01",
+            "{}openai/deployments/{}/embeddings?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -597,13 +690,11 @@ impl Embedder for AzureOpenAIEmbedder {
     }
 
     async fn encode(&self, text: &str) -> Result<Vec<usize>> {
-        let tokens = { self.tokenizer().lock().encode_with_special_tokens(text) };
-        Ok(tokens)
+        encode_async(self.tokenizer(), text).await
     }
 
     async fn decode(&self, tokens: Vec<usize>) -> Result<String> {
-        let str = { self.tokenizer().lock().decode(tokens)? };
-        Ok(str)
+        decode_async(self.tokenizer(), tokens).await
     }
 
     async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
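The `encode_async`/`decode_async` helpers replace the inline lock-and-encode calls. A plausible shape for such a helper, assuming it offloads the blocking BPE work with `spawn_blocking` (an assumption about the helper, not something this diff shows), with a toy tokenizer standing in for `CoreBPE`:

```rust
// Hedged sketch of an encode_async-style helper: move the blocking tokenizer
// call off the async executor via spawn_blocking. `Bpe` is a toy stand-in for
// CoreBPE, and the spawn_blocking strategy is an assumption, not confirmed
// by this commit.
use std::sync::{Arc, Mutex};
use tokio::task;

struct Bpe;
impl Bpe {
    fn encode_with_special_tokens(&self, text: &str) -> Vec<usize> {
        text.bytes().map(usize::from).collect() // toy byte-level "tokenizer"
    }
}

async fn encode_async(bpe: Arc<Mutex<Bpe>>, text: &str) -> anyhow::Result<Vec<usize>> {
    let text = text.to_string();
    Ok(task::spawn_blocking(move || {
        bpe.lock().unwrap().encode_with_special_tokens(&text)
    })
    .await?)
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let tokens = encode_async(Arc::new(Mutex::new(Bpe)), "hi").await?;
    println!("{tokens:?}"); // [104, 105]
    Ok(())
}
```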
4 changes: 4 additions & 0 deletions front/lib/api/credentials.ts
@@ -4,6 +4,8 @@ const {
   DUST_MANAGED_OPENAI_API_KEY = "",
   DUST_MANAGED_ANTHROPIC_API_KEY = "",
   DUST_MANAGED_TEXTSYNTH_API_KEY = "",
+  DUST_MANAGED_AZURE_OPENAI_API_KEY = "",
+  DUST_MANAGED_AZURE_OPENAI_ENDPOINT = "",
 } = process.env;
 
 export const credentialsFromProviders = (
@@ -56,5 +58,7 @@ export const dustManagedCredentials = (): CredentialsType => {
     OPENAI_API_KEY: DUST_MANAGED_OPENAI_API_KEY,
     ANTHROPIC_API_KEY: DUST_MANAGED_ANTHROPIC_API_KEY,
     TEXTSYNTH_API_KEY: DUST_MANAGED_TEXTSYNTH_API_KEY,
+    AZURE_OPENAI_API_KEY: DUST_MANAGED_AZURE_OPENAI_API_KEY,
+    AZURE_OPENAI_ENDPOINT: DUST_MANAGED_AZURE_OPENAI_ENDPOINT,
   };
 };
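
These managed credentials surface to consumers under the `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` keys. A hedged sketch of how a consumer might read that pair from the environment; the function itself is illustrative, not code from this commit:

```rust
// Illustrative only: read the Azure credential pair that
// dustManagedCredentials now forwards. Returns None if either is unset.
use std::env;

fn azure_openai_credentials() -> Option<(String, String)> {
    let api_key = env::var("AZURE_OPENAI_API_KEY").ok()?;
    let endpoint = env::var("AZURE_OPENAI_ENDPOINT").ok()?;
    Some((api_key, endpoint))
}

fn main() {
    match azure_openai_credentials() {
        Some((_, endpoint)) => println!("Azure OpenAI configured at {endpoint}"),
        None => println!("Azure OpenAI credentials not set"),
    }
}
```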
18 changes: 1 addition & 17 deletions front/lib/providers.ts
@@ -42,7 +42,7 @@ export const modelProviders: ModelProvider[] = [
     name: "Azure OpenAI",
     built: true,
     enabled: false,
-    chat: false,
+    chat: true,
     embed: true,
   },
   {
@@ -61,22 +61,6 @@
     chat: true,
     embed: false,
   },
-  {
-    providerId: "hugging_face",
-    name: "Hugging Face",
-    built: false,
-    enabled: false,
-    chat: false,
-    embed: false,
-  },
-  {
-    providerId: "replicate",
-    name: "Replicate",
-    built: false,
-    enabled: false,
-    chat: false,
-    embed: false,
-  },
 ];
 
 type ServiceProvider = {
2 changes: 1 addition & 1 deletion front/pages/api/w/[wId]/data_sources/index.ts
@@ -60,7 +60,7 @@ async function handler(
         api_error: {
           type: "data_source_auth_error",
           message:
-            "Only the users that are `admins` for the current workspace can create a managed data source.",
+            "Only the users that are `admins` for the current workspace can create a data source.",
         },
       });
     }
