chore: refactor OpenAI-compatible providers (#10003)
* chore: refactor OpenAI-compatible providers

* render provider

* support squashing text contents + limit tool call IDs

---------

Co-authored-by: Henry Fontanier <henry@dust.tt>
fontanierh and Henry Fontanier authored Jan 16, 2025
1 parent 2cd4182 commit e5b1b68
Showing 6 changed files with 1,776 additions and 1,889 deletions.
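Two of the commit-message bullets ("squashing text contents" and limiting tool call IDs) surface in this diff only as the trailing `false` argument of the new call in azure_openai.rs below. As a rough, hypothetical sketch of what those behaviors could mean in practice — the ContentPart type, both function names, and the caller-supplied length cap are assumptions, not the helper module's actual code:

// Illustrative sketch only; ContentPart and both helpers are invented
// names, not definitions from the repository.
#[derive(Debug, PartialEq)]
pub enum ContentPart {
    Text(String),
    ImageUrl(String),
}

// Merge runs of consecutive Text parts into a single Text part, for
// OpenAI-compatible providers that reject multi-part text content.
pub fn squash_text_contents(parts: Vec<ContentPart>) -> Vec<ContentPart> {
    let mut out: Vec<ContentPart> = Vec::new();
    for part in parts {
        match part {
            ContentPart::Text(t) => {
                if let Some(ContentPart::Text(acc)) = out.last_mut() {
                    acc.push('\n');
                    acc.push_str(&t);
                } else {
                    out.push(ContentPart::Text(t));
                }
            }
            other => out.push(other),
        }
    }
    out
}

// Truncate a tool call ID to a provider-imposed maximum length; the
// exact cap varies by provider, so the caller supplies it.
pub fn limit_tool_call_id(id: &str, max_len: usize) -> String {
    id.chars().take(max_len).collect()
}

In the AzureOpenAI call site below, squashing is explicitly opted out of via the final `false, // don't squash text contents` argument.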
1 change: 1 addition & 0 deletions core/src/lib.rs
@@ -56,6 +56,7 @@ pub mod providers {
pub mod anthropic;
pub mod deepseek;
pub mod google_ai_studio;
pub mod openai_compatible_helpers;
pub mod togetherai;
}
pub mod http {
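The new openai_compatible_helpers module itself is not rendered on this page; from the imports in the azure_openai.rs diff below, it exports at least `openai_compatible_chat_completion` and a `TransformSystemMessages` enum with a `Keep` variant. A hedged sketch of the enum — everything other than `Keep` is a guess; some OpenAI-compatible providers reject the system role, which is presumably what this knob addresses:

// Sketch only: Keep is confirmed by the AzureOpenAI call site below;
// the second variant is an assumed example, not the module's actual code.
pub enum TransformSystemMessages {
    // Pass system messages through unchanged (what AzureOpenAI uses below).
    Keep,
    // Assumed: rewrite system messages as user messages for providers
    // that do not accept a system role.
    ReplaceWithUser,
}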
166 changes: 32 additions & 134 deletions core/src/providers/azure_openai.rs
@@ -1,14 +1,11 @@
use crate::providers::chat_messages::AssistantChatMessage;
use crate::providers::chat_messages::ChatMessage;
use crate::providers::embedder::{Embedder, EmbedderVector};
use crate::providers::llm::ChatFunction;
use crate::providers::llm::Tokens;
use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM};
use crate::providers::openai::logprobs_from_choices;
use crate::providers::openai::{
chat_completion, completion, embed, streamed_chat_completion, streamed_completion,
to_openai_messages, OpenAILLM, OpenAITool, OpenAIToolChoice,
};
use crate::providers::openai::completion;
use crate::providers::openai::embed;
use crate::providers::openai::streamed_completion;
use crate::providers::provider::{Provider, ProviderID};
use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, decode_async, encode_async};
use crate::providers::tiktoken::tiktoken::{
@@ -25,10 +22,13 @@ use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::io::prelude::*;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;

use super::openai::OpenAILLM;
use super::openai_compatible_helpers::openai_compatible_chat_completion;
use super::openai_compatible_helpers::TransformSystemMessages;

#[derive(Serialize, Deserialize, Debug, Clone)]
struct AzureOpenAIScaleSettings {
scale_type: String,
@@ -436,140 +436,38 @@ impl LLM for AzureOpenAILLM {
top_p: Option<f32>,
n: usize,
stop: &Vec<String>,
mut max_tokens: Option<i32>,
max_tokens: Option<i32>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
logprobs: Option<bool>,
top_logprobs: Option<i32>,
extras: Option<Value>,
event_sender: Option<UnboundedSender<Value>>,
) -> Result<LLMChatGeneration> {
if let Some(m) = max_tokens {
if m == -1 {
max_tokens = None;
}
}

let (openai_user, response_format, reasoning_effort) = match &extras {
None => (None, None, None),
Some(v) => (
match v.get("openai_user") {
Some(Value::String(u)) => Some(u.to_string()),
_ => None,
},
match v.get("response_format") {
Some(Value::String(f)) => Some(f.to_string()),
_ => None,
},
match v.get("reasoning_effort") {
Some(Value::String(r)) => Some(r.to_string()),
_ => None,
},
),
};

let tool_choice = match function_call.as_ref() {
Some(fc) => Some(OpenAIToolChoice::from_str(fc)?),
None => None,
};

let tools = functions
.iter()
.map(OpenAITool::try_from)
.collect::<Result<Vec<OpenAITool>, _>>()?;

let openai_messages = to_openai_messages(messages, &self.model_id.clone().unwrap())?;

let (c, request_id) = match event_sender {
Some(_) => {
streamed_chat_completion(
self.chat_uri()?,
self.api_key.clone().unwrap(),
None,
None,
&openai_messages,
tools,
tool_choice,
temperature,
match top_p {
Some(t) => t,
None => 1.0,
},
n,
stop,
max_tokens,
match presence_penalty {
Some(p) => p,
None => 0.0,
},
match frequency_penalty {
Some(f) => f,
None => 0.0,
},
response_format,
reasoning_effort,
logprobs,
top_logprobs,
openai_user,
event_sender,
)
.await?
}
None => {
chat_completion(
self.chat_uri()?,
self.api_key.clone().unwrap(),
None,
None,
&openai_messages,
tools,
tool_choice,
temperature,
match top_p {
Some(t) => t,
None => 1.0,
},
n,
stop,
max_tokens,
match presence_penalty {
Some(p) => p,
None => 0.0,
},
match frequency_penalty {
Some(f) => f,
None => 0.0,
},
response_format,
reasoning_effort,
logprobs,
top_logprobs,
openai_user,
)
.await?
}
};

// println!("COMPLETION: {:?}", c);

assert!(c.choices.len() > 0);

Ok(LLMChatGeneration {
created: utils::now(),
provider: ProviderID::AzureOpenAI.to_string(),
model: self.model_id.clone().unwrap(),
completions: c
.choices
.iter()
.map(|c| AssistantChatMessage::try_from(&c.message))
.collect::<Result<Vec<_>>>()?,
usage: c.usage.map(|usage| LLMTokenUsage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens.unwrap_or(0),
}),
provider_request_id: request_id,
logprobs: logprobs_from_choices(&c.choices),
})
openai_compatible_chat_completion(
self.chat_uri()?,
self.model_id.clone().unwrap(),
self.api_key.clone().unwrap(),
messages,
functions,
function_call,
temperature,
top_p,
n,
stop,
max_tokens,
presence_penalty,
frequency_penalty,
logprobs,
top_logprobs,
extras,
event_sender,
false, // don't disable provider streaming
TransformSystemMessages::Keep,
"AzureOpenAI".to_string(),
false, // don't squash text contents
)
.await
}
}
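The helper's definition lives in the new openai_compatible_helpers module, which is not rendered here, but its parameter list can be read off the call above. A reconstruction of the signature under that assumption — argument order and types are taken from the call site; the parameter names, the `anyhow::Result` alias, and the `hyper::Uri` type are guesses:

use serde_json::Value;
use tokio::sync::mpsc::UnboundedSender;

use crate::providers::chat_messages::ChatMessage;
use crate::providers::llm::{ChatFunction, LLMChatGeneration};
// (TransformSystemMessages as sketched after the lib.rs diff above.)

// Reconstructed from the AzureOpenAI call site; names are assumptions,
// types are inferred from the arguments passed.
pub async fn openai_compatible_chat_completion(
    uri: hyper::Uri,    // self.chat_uri()?
    model_id: String,   // self.model_id.clone().unwrap()
    api_key: String,    // self.api_key.clone().unwrap()
    messages: &Vec<ChatMessage>,
    functions: &Vec<ChatFunction>,
    function_call: Option<String>,
    temperature: f32,
    top_p: Option<f32>,
    n: usize,
    stop: &Vec<String>,
    max_tokens: Option<i32>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    logprobs: Option<bool>,
    top_logprobs: Option<i32>,
    extras: Option<Value>,
    event_sender: Option<UnboundedSender<Value>>,
    disable_provider_streaming: bool,
    transform_system_messages: TransformSystemMessages,
    provider_name: String,
    squash_text_contents: bool,
) -> anyhow::Result<LLMChatGeneration> {
    unimplemented!() // body not shown on this page
}

The defaulting of `top_p`, `presence_penalty`, and `frequency_penalty` (to 1.0, 0.0, and 0.0) and the `max_tokens == -1` normalization, both deleted from azure_openai.rs above, presumably moved behind this signature, since the new call forwards the raw `Option` values.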

[Remaining portion of core/src/providers/azure_openai.rs and the diffs for the other 4 changed files are not rendered on this page.]
