diff --git a/core/src/lib.rs b/core/src/lib.rs index 945312d0f1e8..f3cfaffbfe71 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -56,6 +56,7 @@ pub mod providers { pub mod anthropic; pub mod deepseek; pub mod google_ai_studio; + pub mod openai_compatible_helpers; pub mod togetherai; } pub mod http { diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs index 29d8145b061d..7265037d0dc9 100644 --- a/core/src/providers/azure_openai.rs +++ b/core/src/providers/azure_openai.rs @@ -1,14 +1,11 @@ -use crate::providers::chat_messages::AssistantChatMessage; use crate::providers::chat_messages::ChatMessage; use crate::providers::embedder::{Embedder, EmbedderVector}; use crate::providers::llm::ChatFunction; use crate::providers::llm::Tokens; use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; -use crate::providers::openai::logprobs_from_choices; -use crate::providers::openai::{ - chat_completion, completion, embed, streamed_chat_completion, streamed_completion, - to_openai_messages, OpenAILLM, OpenAITool, OpenAIToolChoice, -}; +use crate::providers::openai::completion; +use crate::providers::openai::embed; +use crate::providers::openai::streamed_completion; use crate::providers::provider::{Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, decode_async, encode_async}; use crate::providers::tiktoken::tiktoken::{ @@ -25,10 +22,13 @@ use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::io::prelude::*; -use std::str::FromStr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; +use super::openai::OpenAILLM; +use super::openai_compatible_helpers::openai_compatible_chat_completion; +use super::openai_compatible_helpers::TransformSystemMessages; + #[derive(Serialize, Deserialize, Debug, Clone)] struct AzureOpenAIScaleSettings { scale_type: String, @@ -436,7 +436,7 @@ impl LLM for AzureOpenAILLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, + max_tokens: Option, presence_penalty: Option, frequency_penalty: Option, logprobs: Option, @@ -444,132 +444,30 @@ impl LLM for AzureOpenAILLM { extras: Option, event_sender: Option>, ) -> Result { - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - let (openai_user, response_format, reasoning_effort) = match &extras { - None => (None, None, None), - Some(v) => ( - match v.get("openai_user") { - Some(Value::String(u)) => Some(u.to_string()), - _ => None, - }, - match v.get("response_format") { - Some(Value::String(f)) => Some(f.to_string()), - _ => None, - }, - match v.get("reasoning_effort") { - Some(Value::String(r)) => Some(r.to_string()), - _ => None, - }, - ), - }; - - let tool_choice = match function_call.as_ref() { - Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), - None => None, - }; - - let tools = functions - .iter() - .map(OpenAITool::try_from) - .collect::, _>>()?; - - let openai_messages = to_openai_messages(messages, &self.model_id.clone().unwrap())?; - - let (c, request_id) = match event_sender { - Some(_) => { - streamed_chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - None, - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - response_format, - reasoning_effort, - logprobs, - top_logprobs, - openai_user, - 
event_sender, - ) - .await? - } - None => { - chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - None, - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - response_format, - reasoning_effort, - logprobs, - top_logprobs, - openai_user, - ) - .await? - } - }; - - // println!("COMPLETION: {:?}", c); - - assert!(c.choices.len() > 0); - - Ok(LLMChatGeneration { - created: utils::now(), - provider: ProviderID::AzureOpenAI.to_string(), - model: self.model_id.clone().unwrap(), - completions: c - .choices - .iter() - .map(|c| AssistantChatMessage::try_from(&c.message)) - .collect::>>()?, - usage: c.usage.map(|usage| LLMTokenUsage { - prompt_tokens: usage.prompt_tokens, - completion_tokens: usage.completion_tokens.unwrap_or(0), - }), - provider_request_id: request_id, - logprobs: logprobs_from_choices(&c.choices), - }) + openai_compatible_chat_completion( + self.chat_uri()?, + self.model_id.clone().unwrap(), + self.api_key.clone().unwrap(), + messages, + functions, + function_call, + temperature, + top_p, + n, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + logprobs, + top_logprobs, + extras, + event_sender, + false, // don't disable provider streaming + TransformSystemMessages::Keep, + "AzureOpenAI".to_string(), + false, // don't squash text contents + ) + .await } } diff --git a/core/src/providers/deepseek.rs b/core/src/providers/deepseek.rs index d3c49e4ea7cd..63dc865857c1 100644 --- a/core/src/providers/deepseek.rs +++ b/core/src/providers/deepseek.rs @@ -1,12 +1,9 @@ -use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage}; +use std::sync::Arc; + +use crate::providers::chat_messages::ChatMessage; use crate::providers::embedder::Embedder; use crate::providers::llm::ChatFunction; -use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; -use crate::providers::openai::{ - chat_completion, logprobs_from_choices, streamed_chat_completion, to_openai_messages, - OpenAIChatMessage, OpenAIChatMessageContent, OpenAIContentBlock, OpenAITextContent, - OpenAITextContentType, OpenAITool, OpenAIToolChoice, -}; +use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLM}; use crate::providers::provider::{Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, o200k_base_singleton, CoreBPE}; use crate::providers::tiktoken::tiktoken::{decode_async, encode_async}; @@ -18,10 +15,12 @@ use async_trait::async_trait; use hyper::Uri; use parking_lot::RwLock; use serde_json::Value; -use std::str::FromStr; -use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; +use super::openai_compatible_helpers::{ + openai_compatible_chat_completion, TransformSystemMessages, +}; + pub struct DeepseekLLM { id: String, api_key: Option, @@ -115,7 +114,7 @@ impl LLM for DeepseekLLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, + max_tokens: Option, presence_penalty: Option, frequency_penalty: Option, logprobs: Option, @@ -123,138 +122,30 @@ impl LLM for DeepseekLLM { _extras: Option, event_sender: Option>, ) -> Result { - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - let tool_choice = match function_call.as_ref() { - Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), - None => None, - }; - - let tools = functions - .iter() - 
.map(OpenAITool::try_from) - .collect::, _>>()?; - - // Deepseek doesn't work with the new chat message content format. - // We have to modify the messages contents to use the "String" format. - let openai_messages = to_openai_messages(messages, &self.id)? - .into_iter() - .filter_map(|m| match m.content { - None => Some(m), - Some(OpenAIChatMessageContent::String(_)) => Some(m), - Some(OpenAIChatMessageContent::Structured(contents)) => { - // Find the first text content, and use it to make a string content. - let content = contents.into_iter().find_map(|c| match c { - OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text, - .. - }) => Some(OpenAIChatMessageContent::String(text)), - _ => None, - }); - - Some(OpenAIChatMessage { - role: m.role, - name: m.name, - tool_call_id: m.tool_call_id, - tool_calls: m.tool_calls, - content, - }) - } - }) - .collect::>(); - - let is_streaming = event_sender.is_some(); - - let (c, request_id) = if is_streaming { - streamed_chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - None, - None, - logprobs, - top_logprobs, - None, - event_sender.clone(), - ) - .await? - } else { - chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - None, - None, - logprobs, - top_logprobs, - None, - ) - .await? 
-        };
-
-        assert!(c.choices.len() > 0);
-
-        Ok(LLMChatGeneration {
-            created: utils::now(),
-            provider: ProviderID::Deepseek.to_string(),
-            model: self.id.clone(),
-            completions: c
-                .choices
-                .iter()
-                .map(|c| AssistantChatMessage::try_from(&c.message))
-                .collect::<Result<Vec<_>>>()?,
-            usage: c.usage.map(|usage| LLMTokenUsage {
-                prompt_tokens: usage.prompt_tokens,
-                completion_tokens: usage.completion_tokens.unwrap_or(0),
-            }),
-            provider_request_id: request_id,
-            logprobs: logprobs_from_choices(&c.choices),
-        })
+        openai_compatible_chat_completion(
+            self.chat_uri()?,
+            self.id.clone(),
+            self.api_key.clone().unwrap(),
+            messages,
+            functions,
+            function_call,
+            temperature,
+            top_p,
+            n,
+            stop,
+            max_tokens,
+            presence_penalty,
+            frequency_penalty,
+            logprobs,
+            top_logprobs,
+            None,
+            event_sender,
+            false, // don't disable provider streaming
+            TransformSystemMessages::Keep,
+            "DeepSeek".to_string(),
+            true, // squash text contents (DeepSeek does not support structured message content)
+        )
+        .await
     }
 }
diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs
index e839c727a77f..406c50cd3445 100644
--- a/core/src/providers/openai.rs
+++ b/core/src/providers/openai.rs
@@ -1,12 +1,8 @@
-use crate::providers::chat_messages::{
-    AssistantChatMessage, ChatMessage, ContentBlock, MixedContent,
-};
+use crate::providers::chat_messages::ChatMessage;
 use crate::providers::embedder::{Embedder, EmbedderVector};
-use crate::providers::llm::{ChatFunction, ChatFunctionCall};
-use crate::providers::llm::{
-    ChatMessageRole, LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM,
-};
-use crate::providers::llm::{LLMChatLogprob, Tokens};
+use crate::providers::llm::ChatFunction;
+use crate::providers::llm::Tokens;
+use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM};
 use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID};
 use crate::providers::tiktoken::tiktoken::{
     batch_tokenize_async, cl100k_base_singleton, o200k_base_singleton, p50k_base_singleton,
@@ -15,7 +11,6 @@
 use crate::providers::tiktoken::tiktoken::{decode_async, encode_async};
 use crate::run::Credentials;
 use crate::utils;
-use crate::utils::ParseError;
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use eventsource_client as es;
@@ -30,13 +25,14 @@
 use serde_json::json;
 use serde_json::Value;
 use std::collections::HashMap;
 use std::io::prelude::*;
-use std::str::FromStr;
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::sync::mpsc::UnboundedSender;
 use tokio::time::timeout;
 
-use super::llm::TopLogprob;
+use super::openai_compatible_helpers::{
+    openai_compatible_chat_completion, OpenAIError, TransformSystemMessages,
+};
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct Usage {
     pub prompt_tokens: u64,
     pub completion_tokens: Option<u64>,
     pub total_tokens: u64,
 }
@@ -82,526 +78,6 @@ pub struct Completion {
     pub usage: Option<Usage>,
 }
 
-///
-/// Tools implementation types.
-///
-
-// Input types.
- -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "lowercase")] -pub enum OpenAIToolType { - Function, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIFunction { - name: String, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIFunctionCall { - r#type: OpenAIToolType, - function: OpenAIFunction, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "lowercase")] -pub enum OpenAIToolControl { - Auto, - Required, - None, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(untagged)] -pub enum OpenAIToolChoice { - OpenAIToolControl(OpenAIToolControl), - OpenAIFunctionCall(OpenAIFunctionCall), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIToolFunction { - pub name: String, - pub description: Option, - pub parameters: Option, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAITool { - pub r#type: OpenAIToolType, - pub function: OpenAIToolFunction, -} - -impl FromStr for OpenAIToolChoice { - type Err = ParseError; - - fn from_str(s: &str) -> Result { - match s { - "auto" => Ok(OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::Auto)), - "any" => Ok(OpenAIToolChoice::OpenAIToolControl( - OpenAIToolControl::Required, - )), - "none" => Ok(OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::None)), - _ => { - let function = OpenAIFunctionCall { - r#type: OpenAIToolType::Function, - function: OpenAIFunction { - name: s.to_string(), - }, - }; - Ok(OpenAIToolChoice::OpenAIFunctionCall(function)) - } - } - } -} - -// Outputs types. - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIToolCallFunction { - name: String, - arguments: String, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIToolCall { - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - r#type: OpenAIToolType, - pub function: OpenAIToolCallFunction, -} - -impl TryFrom<&ChatFunctionCall> for OpenAIToolCall { - type Error = anyhow::Error; - - fn try_from(cf: &ChatFunctionCall) -> Result { - Ok(OpenAIToolCall { - id: Some(cf.id.clone()), - r#type: OpenAIToolType::Function, - function: OpenAIToolCallFunction { - name: cf.name.clone(), - arguments: cf.arguments.clone(), - }, - }) - } -} - -impl TryFrom<&OpenAIToolCall> for ChatFunctionCall { - type Error = anyhow::Error; - - fn try_from(tc: &OpenAIToolCall) -> Result { - let id = tc - .id - .as_ref() - .ok_or_else(|| anyhow!("Missing tool call id."))?; - - Ok(ChatFunctionCall { - id: id.clone(), - name: tc.function.name.clone(), - arguments: tc.function.arguments.clone(), - }) - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "lowercase")] -pub enum OpenAIChatMessageRole { - Assistant, - Function, - System, - Developer, - Tool, - User, -} - -impl From<&ChatMessageRole> for OpenAIChatMessageRole { - fn from(role: &ChatMessageRole) -> Self { - match role { - ChatMessageRole::Assistant => OpenAIChatMessageRole::Assistant, - ChatMessageRole::Function => OpenAIChatMessageRole::Function, - ChatMessageRole::System => OpenAIChatMessageRole::System, - ChatMessageRole::User => OpenAIChatMessageRole::User, - } - } -} - -impl FromStr for OpenAIChatMessageRole { - type Err = ParseError; - fn from_str(s: &str) -> Result { - match s { - "system" => Ok(OpenAIChatMessageRole::System), - "user" => Ok(OpenAIChatMessageRole::User), - "assistant" => 
Ok(OpenAIChatMessageRole::Assistant), - "function" => Ok(OpenAIChatMessageRole::Tool), - _ => Err(ParseError::with_message("Unknown OpenAIChatMessageRole"))?, - } - } -} - -impl From for ChatMessageRole { - fn from(value: OpenAIChatMessageRole) -> Self { - match value { - OpenAIChatMessageRole::Assistant => ChatMessageRole::Assistant, - OpenAIChatMessageRole::Function => ChatMessageRole::Function, - OpenAIChatMessageRole::System => ChatMessageRole::System, - OpenAIChatMessageRole::Developer => ChatMessageRole::System, - OpenAIChatMessageRole::Tool => ChatMessageRole::Function, - OpenAIChatMessageRole::User => ChatMessageRole::User, - } - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "snake_case")] -pub enum OpenAITextContentType { - Text, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAITextContent { - #[serde(rename = "type")] - pub r#type: OpenAITextContentType, - pub text: String, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIImageUrlContent { - pub url: String, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "snake_case")] -pub enum OpenAIImageContentType { - ImageUrl, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIImageContent { - pub r#type: OpenAIImageContentType, - pub image_url: OpenAIImageUrlContent, -} - -// Define an enum for mixed content -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -#[serde(untagged)] -pub enum OpenAIContentBlock { - TextContent(OpenAITextContent), - ImageContent(OpenAIImageContent), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(untagged)] -pub enum OpenAIChatMessageContent { - Structured(Vec), - String(String), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIChatMessage { - pub role: OpenAIChatMessageRole, - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAICompletionChatMessage { - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - pub role: OpenAIChatMessageRole, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAITopLogprob { - pub token: String, - pub logprob: f32, - pub bytes: Option>, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAIChatChoiceLogprob { - pub token: String, - pub logprob: f32, - pub bytes: Option>, - pub top_logprobs: Vec, -} - -impl From for TopLogprob { - fn from(top_logprob: OpenAITopLogprob) -> Self { - TopLogprob { - token: top_logprob.token, - logprob: top_logprob.logprob, - } - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAIChatChoiceLogprobs { - pub content: Vec, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAIChatChoice { - pub message: OpenAICompletionChatMessage, - pub index: usize, - pub finish_reason: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] 
-pub struct OpenAIChatCompletion { - pub id: String, - pub object: String, - pub created: u64, - pub choices: Vec, - pub usage: Option, -} - -// This code performs a type conversion with information loss when converting to ChatFunctionCall. -// It only supports one tool call, so it takes the first one from the vector of OpenAIToolCall, -// hence potentially discarding other tool calls. -impl TryFrom<&OpenAICompletionChatMessage> for AssistantChatMessage { - type Error = anyhow::Error; - - fn try_from(cm: &OpenAICompletionChatMessage) -> Result { - let role = ChatMessageRole::from(cm.role.clone()); - let content = match cm.content.as_ref() { - Some(c) => Some(c.clone()), - None => None, - }; - - let function_calls = if let Some(tool_calls) = cm.tool_calls.as_ref() { - let cfc = tool_calls - .into_iter() - .map(|tc| ChatFunctionCall::try_from(tc)) - .collect::, _>>()?; - - Some(cfc) - } else { - None - }; - - let function_call = if let Some(fcs) = function_calls.as_ref() { - match fcs.first() { - Some(fc) => Some(fc), - None => None, - } - .cloned() - } else { - None - }; - - let name = match cm.name.as_ref() { - Some(c) => Some(c.clone()), - None => None, - }; - - Ok(AssistantChatMessage { - content, - role, - name, - function_call, - function_calls, - }) - } -} - -impl TryFrom<&ContentBlock> for OpenAIChatMessageContent { - type Error = anyhow::Error; - - fn try_from(cm: &ContentBlock) -> Result { - match cm { - ContentBlock::Text(t) => Ok(OpenAIChatMessageContent::Structured(vec![ - OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text: t.clone(), - }), - ])), - ContentBlock::Mixed(m) => { - let content: Vec = m - .into_iter() - .map(|mb| match mb { - MixedContent::TextContent(tc) => { - Ok(OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text: tc.text.clone(), - })) - } - MixedContent::ImageContent(ic) => { - Ok(OpenAIContentBlock::ImageContent(OpenAIImageContent { - r#type: OpenAIImageContentType::ImageUrl, - image_url: OpenAIImageUrlContent { - url: ic.image_url.url.clone(), - }, - })) - } - }) - .collect::>>()?; - - Ok(OpenAIChatMessageContent::Structured(content)) - } - } - } -} - -impl TryFrom<&String> for OpenAIChatMessageContent { - type Error = anyhow::Error; - - fn try_from(t: &String) -> Result { - Ok(OpenAIChatMessageContent::Structured(vec![ - OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text: t.clone(), - }), - ])) - } -} - -impl TryFrom<&ChatMessage> for OpenAIChatMessage { - type Error = anyhow::Error; - - fn try_from(cm: &ChatMessage) -> Result { - match cm { - ChatMessage::Assistant(assistant_msg) => Ok(OpenAIChatMessage { - content: match &assistant_msg.content { - Some(c) => Some(OpenAIChatMessageContent::try_from(c)?), - None => None, - }, - name: assistant_msg.name.clone(), - role: OpenAIChatMessageRole::from(&assistant_msg.role), - tool_calls: match assistant_msg.function_calls.as_ref() { - Some(fc) => Some( - fc.into_iter() - .map(|f| OpenAIToolCall::try_from(f)) - .collect::, _>>()?, - ), - None => None, - }, - tool_call_id: None, - }), - ChatMessage::Function(function_msg) => Ok(OpenAIChatMessage { - content: Some(OpenAIChatMessageContent::try_from(&function_msg.content)?), - name: None, - role: OpenAIChatMessageRole::Tool, - tool_calls: None, - tool_call_id: Some(function_msg.function_call_id.clone()), - }), - ChatMessage::System(system_msg) => Ok(OpenAIChatMessage { - content: 
Some(OpenAIChatMessageContent::try_from(&system_msg.content)?), - name: None, - role: OpenAIChatMessageRole::from(&system_msg.role), - tool_calls: None, - tool_call_id: None, - }), - ChatMessage::User(user_msg) => Ok(OpenAIChatMessage { - content: Some(OpenAIChatMessageContent::try_from(&user_msg.content)?), - name: user_msg.name.clone(), - role: OpenAIChatMessageRole::from(&user_msg.role), - tool_calls: None, - tool_call_id: None, - }), - } - } -} - -impl TryFrom<&ChatFunction> for OpenAITool { - type Error = anyhow::Error; - - fn try_from(f: &ChatFunction) -> Result { - Ok(OpenAITool { - r#type: OpenAIToolType::Function, - function: OpenAIToolFunction { - name: f.name.clone(), - description: f.description.clone(), - parameters: f.parameters.clone(), - }, - }) - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ChatDelta { - pub delta: Value, - pub index: usize, - pub finish_reason: Option, - pub logprobs: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ChatChunk { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct InnerError { - pub message: String, - #[serde(alias = "type")] - pub _type: String, - pub param: Option, - pub internal_message: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAIError { - pub error: InnerError, -} - -impl OpenAIError { - pub fn message(&self) -> String { - match self.error.internal_message { - Some(ref msg) => format!( - "OpenAIError: [{}] {} internal_message={}", - self.error._type, self.error.message, msg, - ), - None => format!("OpenAIError: [{}] {}", self.error._type, self.error.message,), - } - } - - pub fn retryable(&self) -> bool { - match self.error._type.as_str() { - "requests" => true, - "server_error" => match &self.error.internal_message { - Some(message) if message.contains("retry") => true, - _ => false, - }, - _ => false, - } - } - - pub fn retryable_streamed(&self, status: StatusCode) -> bool { - if status == StatusCode::TOO_MANY_REQUESTS { - return true; - } - if status.is_server_error() { - return true; - } - match self.error._type.as_str() { - "server_error" => match self.error.internal_message { - Some(_) => true, - None => false, - }, - _ => false, - } - } -} - /// /// Shared streamed/non-streamed chat/completion handling code (used by both OpenAILLM and /// AzureOpenAILLM). @@ -677,8 +153,6 @@ pub async fn streamed_completion( body["stop"] = json!(stop); } - // println!("BODY: {}", body.to_string()); - let client = builder .body(body.to_string()) .reconnect( @@ -729,7 +203,7 @@ pub async fn streamed_completion( { true => Err(ModelError { request_id: request_id.clone(), - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 2, @@ -738,7 +212,7 @@ pub async fn streamed_completion( })?, false => Err(ModelError { request_id: request_id.clone(), - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: None, })?, } @@ -757,517 +231,57 @@ pub async fn streamed_completion( } }; - // UTF-8 length of the prompt (as used by the API for text_offset). - let prompt_len = prompt.chars().count(); - - // Only stream if choices is length 1 but should always be the case. 
- match event_sender.as_ref() { - Some(sender) => { - let mut text = completion.choices[0].text.clone(); - let mut tokens = match completion.choices[0].logprobs.as_ref() { - Some(l) => Some(l.tokens.clone()), - None => None, - }; - let mut logprobs = match completion.choices[0].logprobs.as_ref() { - Some(l) => Some(l.token_logprobs.clone()), - None => None, - }; - let text_offset = match completion.choices[0].logprobs.as_ref() { - Some(l) => Some(l.text_offset.clone()), - None => None, - }; - if index == 0 && text_offset.is_some() { - let mut token_offset: usize = 0; - for o in text_offset.as_ref().unwrap() { - if *o < prompt_len { - token_offset += 1; - } - } - text = text.chars().skip(prompt_len).collect::(); - tokens = match tokens { - Some(t) => Some(t[token_offset..].to_vec()), - None => None, - }; - logprobs = match logprobs { - Some(l) => Some(l[token_offset..].to_vec()), - None => None, - }; - } - - if text.len() > 0 { - let _ = sender.send(json!({ - "type": "tokens", - "content": { - "text": text, - "tokens": tokens, - "logprobs": logprobs, - }, - })); - } - } - None => (), - }; - completions.lock().push(completion); - } - }, - None => { - println!("UNEXPECTED NONE"); - break 'stream; - } - }, - Err(e) => { - match e { - es::Error::UnexpectedResponse(r) => { - let status = StatusCode::from_u16(r.status())?; - let headers = r.headers()?; - let request_id = match headers.get("x-request-id") { - Some(v) => Some(v.to_string()), - None => None, - }; - let b = r.body_bytes().await?; - - let error: Result = serde_json::from_slice(&b); - match error { - Ok(error) => { - match error.retryable_streamed(status) { - true => Err(ModelError { - request_id, - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 2, - retries: 3, - }), - }), - false => Err(ModelError { - request_id, - message: error.message(), - retryable: None, - }), - } - }?, - Err(_) => { - Err(anyhow!( - "Error streaming tokens from OpenAI: status={} data={}", - status, - String::from_utf8_lossy(&b) - ))?; - } - } - } - _ => { - Err(anyhow!("Error streaming tokens from OpenAI: {:?}", e))?; - } - } - break 'stream; - } - } - } - - let completion = { - let mut guard = completions.lock(); - let mut c = match guard.len() { - 0 => Err(anyhow!("No completions received from OpenAI")), - _ => Ok(guard[0].clone()), - }?; - guard.remove(0); - for i in 0..guard.len() { - let a = guard[i].clone(); - if a.choices.len() != c.choices.len() { - Err(anyhow!( - "Inconsistent number of choices in streamed completions" - ))?; - } - for j in 0..c.choices.len() { - c.choices[j].finish_reason = a.choices.get(j).unwrap().finish_reason.clone(); - // OpenAI does the bytes merging for us <3. 
- c.choices[j].text = format!("{}{}", c.choices[j].text, a.choices[j].text); - - match c.choices[j].logprobs.as_mut() { - Some(c_logprobs) => match a.choices[j].logprobs.as_ref() { - Some(a_logprobs) => { - c_logprobs.tokens.extend(a_logprobs.tokens.clone()); - c_logprobs - .token_logprobs - .extend(a_logprobs.token_logprobs.clone()); - c_logprobs - .text_offset - .extend(a_logprobs.text_offset.clone()); - match c_logprobs.top_logprobs.as_mut() { - Some(c_top_logprobs) => match a_logprobs.top_logprobs.as_ref() { - Some(a_top_logprobs) => { - c_top_logprobs.extend(a_top_logprobs.clone()); - } - None => (), - }, - None => (), - } - } - None => (), - }, - None => (), - } - } - } - c - }; - - Ok((completion, request_id)) -} - -pub async fn completion( - uri: Uri, - api_key: String, - organization_id: Option, - model_id: Option, - prompt: &str, - max_tokens: Option, - temperature: f32, - n: usize, - logprobs: Option, - echo: bool, - stop: &Vec, - frequency_penalty: f32, - presence_penalty: f32, - top_p: f32, - user: Option, -) -> Result<(Completion, Option)> { - // let https = HttpsConnector::new(); - // let cli = Client::builder().build::<_, hyper::Body>(https); - - let mut body = json!({ - "prompt": prompt, - "temperature": temperature, - "n": n, - "logprobs": logprobs, - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - "top_p": top_p, - }); - if user.is_some() { - body["user"] = json!(user); - } - if model_id.is_some() { - body["model"] = json!(model_id); - } - if let Some(mt) = max_tokens { - body["max_tokens"] = mt.into(); - } - if !stop.is_empty() { - body["stop"] = json!(stop); - } - - match model_id { - None => (), - Some(model_id) => { - body["model"] = json!(model_id); - // `gpt-3.5-turbo-instruct` does not support `echo` - if !model_id.starts_with("gpt-3.5-turbo-instruct") { - body["echo"] = json!(echo); - } - } - }; - - let mut req = reqwest::Client::new() - .post(uri.to_string()) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key.clone())) - .header("api-key", api_key.clone()); - - if let Some(organization_id) = organization_id { - req = req.header("OpenAI-Organization", organization_id); - } - - req = req.json(&body); - - let res = match timeout(Duration::new(180, 0), req.send()).await { - Ok(Ok(res)) => res, - Ok(Err(e)) => Err(e)?, - Err(_) => Err(anyhow!("Timeout sending request to OpenAI after 180s"))?, - }; - - let res_headers = res.headers(); - let request_id = match res_headers.get("x-request-id") { - Some(request_id) => Some(request_id.to_str()?.to_string()), - None => None, - }; - - let body = match timeout(Duration::new(180, 0), res.bytes()).await { - Ok(Ok(body)) => body, - Ok(Err(e)) => Err(e)?, - Err(_) => Err(anyhow!("Timeout reading response from OpenAI after 180s"))?, - }; - - let mut b: Vec = vec![]; - body.reader().read_to_end(&mut b)?; - let c: &[u8] = &b; - - let completion: Completion = match serde_json::from_slice(c) { - Ok(c) => Ok(c), - Err(_) => { - let error: OpenAIError = serde_json::from_slice(c)?; - match error.retryable() { - true => Err(ModelError { - request_id: request_id.clone(), - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 2, - retries: 3, - }), - }), - false => Err(ModelError { - request_id: request_id.clone(), - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 1, - retries: 1, - }), - }), - } - } - }?; - - Ok((completion, 
request_id)) -} - -pub async fn streamed_chat_completion( - uri: Uri, - api_key: String, - organization_id: Option, - model_id: Option, - messages: &Vec, - tools: Vec, - tool_choice: Option, - temperature: f32, - top_p: f32, - n: usize, - stop: &Vec, - max_tokens: Option, - presence_penalty: f32, - frequency_penalty: f32, - response_format: Option, - reasoning_effort: Option, - logprobs: Option, - top_logprobs: Option, - user: Option, - event_sender: Option>, -) -> Result<(OpenAIChatCompletion, Option)> { - let url = uri.to_string(); - - let mut builder = match es::ClientBuilder::for_url(url.as_str()) { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to OpenAI")), - }; - builder = match builder.method(String::from("POST")).header( - "Authorization", - format!("Bearer {}", api_key.clone()).as_str(), - ) { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to OpenAI")), - }; - builder = match builder.header("Content-Type", "application/json") { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to OpenAI")), - }; - builder = match builder.header("api-key", api_key.clone().as_str()) { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to OpenAI")), - }; - - if let Some(org_id) = organization_id { - builder = builder - .header("OpenAI-Organization", org_id.as_str()) - .map_err(|_| anyhow!("Error creating streamed client to OpenAI"))?; - } - - let mut body = json!({ - "messages": messages, - "temperature": temperature, - "top_p": top_p, - "n": n, - "presence_penalty": presence_penalty, - "frequency_penalty": frequency_penalty, - "stream": true, - "stream_options": HashMap::from([("include_usage", true)]), - }); - if user.is_some() { - body["user"] = json!(user); - } - if model_id.is_some() { - body["model"] = json!(model_id); - } - if let Some(mt) = max_tokens { - body["max_tokens"] = mt.into(); - } - if !stop.is_empty() { - body["stop"] = json!(stop); - } - - if tools.len() > 0 { - body["tools"] = json!(tools); - } - if let Some(tool_choice) = tool_choice { - body["tool_choice"] = json!(tool_choice); - } - if let Some(response_format) = response_format { - body["response_format"] = json!({ - "type": response_format, - }); - } - if let Some(reasoning_effort) = reasoning_effort { - body["reasoning_effort"] = json!(reasoning_effort); - } - if let Some(logprobs) = logprobs { - body["logprobs"] = json!(logprobs); - } - if let Some(top_logprobs) = top_logprobs { - body["top_logprobs"] = json!(top_logprobs); - } - - let client = builder - .body(body.to_string()) - .reconnect( - es::ReconnectOptions::reconnect(true) - .retry_initial(false) - .delay(Duration::from_secs(1)) - .backoff_factor(2) - .delay_max(Duration::from_secs(8)) - .build(), - ) - .build(); - - let mut stream = client.stream(); - - let chunks: Arc>> = Arc::new(Mutex::new(Vec::new())); - let mut usage = None; - let mut request_id: Option = None; - - 'stream: loop { - match stream.try_next().await { - Ok(e) => match e { - Some(es::SSE::Connected((_, headers))) => { - request_id = match headers.get("x-request-id") { - Some(v) => Some(v.to_string()), - None => None, - }; - } - Some(es::SSE::Comment(_)) => { - println!("UNEXPECTED COMMENT"); - } - Some(es::SSE::Event(e)) => match e.data.as_str() { - "[DONE]" => { - break 'stream; - } - _ => { - let index = { - let guard = chunks.lock(); - guard.len() - }; - - let chunk: ChatChunk = match serde_json::from_str(e.data.as_str()) { - Ok(c) => c, - Err(err) => { - let error: Result = - 
serde_json::from_str(e.data.as_str()); - match error { - Ok(error) => { - match error.retryable_streamed(StatusCode::OK) && index == 0 - { - true => Err(ModelError { - request_id: request_id.clone(), - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 2, - retries: 3, - }), - })?, - false => Err(ModelError { - request_id: request_id.clone(), - message: error.message(), - retryable: None, - })?, - } - break 'stream; - } - Err(_) => { - Err(anyhow!( - "OpenAIError: failed parsing streamed \ - completion from OpenAI err={} data={}", - err, - e.data.as_str(), - ))?; - break 'stream; - } - } - } - }; - - // Store usage - match &chunk.usage { - Some(received_usage) => { - usage = Some(received_usage.clone()); - } - None => (), - }; - + // UTF-8 length of the prompt (as used by the API for text_offset). + let prompt_len = prompt.chars().count(); + // Only stream if choices is length 1 but should always be the case. match event_sender.as_ref() { Some(sender) => { - if chunk.choices.len() == 1 { - // we ignore the role for generating events - - // If we get `content` in the delta object we stream "tokens". - match chunk.choices[0].delta.get("content") { - None => (), - Some(content) => match content.as_str() { - None => (), - Some(s) => { - if s.len() > 0 { - let _ = sender.send(json!({ - "type": "tokens", - "content": { - "text": s, - }, - })); - } - } - }, + let mut text = completion.choices[0].text.clone(); + let mut tokens = match completion.choices[0].logprobs.as_ref() { + Some(l) => Some(l.tokens.clone()), + None => None, + }; + let mut logprobs = match completion.choices[0].logprobs.as_ref() { + Some(l) => Some(l.token_logprobs.clone()), + None => None, + }; + let text_offset = match completion.choices[0].logprobs.as_ref() { + Some(l) => Some(l.text_offset.clone()), + None => None, + }; + if index == 0 && text_offset.is_some() { + let mut token_offset: usize = 0; + for o in text_offset.as_ref().unwrap() { + if *o < prompt_len { + token_offset += 1; + } + } + text = text.chars().skip(prompt_len).collect::(); + tokens = match tokens { + Some(t) => Some(t[token_offset..].to_vec()), + None => None, + }; + logprobs = match logprobs { + Some(l) => Some(l[token_offset..].to_vec()), + None => None, }; + } - // Emit a `function_call` event per tool_call. 
- if let Some(tool_calls) = chunk.choices[0] - .delta - .get("tool_calls") - .and_then(|v| v.as_array()) - { - tool_calls.iter().for_each(|tool_call| { - match tool_call.get("function") { - Some(f) => { - if let Some(Value::String(name)) = f.get("name") - { - let _ = sender.send(json!({ - "type": "function_call", - "content": { - "name": name, - }, - })); - } - } - _ => (), - } - }); - } + if text.len() > 0 { + let _ = sender.send(json!({ + "type": "tokens", + "content": { + "text": text, + "tokens": tokens, + "logprobs": logprobs, + }, + })); } } None => (), }; - - if !chunk.choices.is_empty() { - chunks.lock().push(chunk); - } + completions.lock().push(completion); } }, None => { @@ -1292,7 +306,7 @@ pub async fn streamed_chat_completion( match error.retryable_streamed(status) { true => Err(ModelError { request_id, - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 2, @@ -1301,7 +315,7 @@ pub async fn streamed_chat_completion( }), false => Err(ModelError { request_id, - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: None, }), } @@ -1324,200 +338,87 @@ pub async fn streamed_chat_completion( } } - let mut completion = { - let guard = chunks.lock(); - let f = match guard.len() { - 0 => Err(anyhow!("No chunks received from OpenAI")), + let completion = { + let mut guard = completions.lock(); + let mut c = match guard.len() { + 0 => Err(anyhow!("No completions received from OpenAI")), _ => Ok(guard[0].clone()), }?; - - // merge logprobs from all choices of all chunks - let logprobs: Vec = guard - .iter() - .flat_map(|chunk| { - chunk - .choices - .iter() - .filter_map(|choice| choice.logprobs.as_ref().map(|lp| lp.content.clone())) - }) - .flatten() - .collect(); - - let mut c = OpenAIChatCompletion { - id: f.id.clone(), - object: f.object.clone(), - created: f.created, - choices: f - .choices - .iter() - .map(|c| OpenAIChatChoice { - message: OpenAICompletionChatMessage { - content: Some("".to_string()), - name: None, - role: OpenAIChatMessageRole::System, - tool_calls: None, - tool_call_id: None, - }, - index: c.index, - finish_reason: None, - logprobs: match logprobs.len() { - 0 => None, - _ => Some(OpenAIChatChoiceLogprobs { - content: logprobs.clone(), - }), - }, - }) - .collect::>(), - usage, - }; - + guard.remove(0); for i in 0..guard.len() { let a = guard[i].clone(); - if a.choices.len() != f.choices.len() { - Err(anyhow!("Inconsistent number of choices in streamed chunks"))?; + if a.choices.len() != c.choices.len() { + Err(anyhow!( + "Inconsistent number of choices in streamed completions" + ))?; } - for j in 0..a.choices.len() { - match a.choices.get(j).unwrap().finish_reason.clone() { - None => (), - Some(f) => c.choices[j].finish_reason = Some(f), - }; - - match a.choices[j].delta.get("role") { - None => (), - Some(role) => match role.as_str() { - None => (), - Some(r) => { - c.choices[j].message.role = OpenAIChatMessageRole::from_str(r)?; - } - }, - }; + for j in 0..c.choices.len() { + c.choices[j].finish_reason = a.choices.get(j).unwrap().finish_reason.clone(); + // OpenAI does the bytes merging for us <3. 
+ c.choices[j].text = format!("{}{}", c.choices[j].text, a.choices[j].text); - match a.choices[j].delta.get("content") { - None => (), - Some(content) => match content.as_str() { - None => (), - Some(s) => { - c.choices[j].message.content = Some(format!( - "{}{}", - c.choices[j] - .message - .content - .as_ref() - .unwrap_or(&String::new()), - s - )); - } - }, - }; - - if let Some(tool_calls) = a.choices[j] - .delta - .get("tool_calls") - .and_then(|v| v.as_array()) - { - for tool_call in tool_calls { - match ( - tool_call.get("type").and_then(|v| v.as_str()), - tool_call.get("id").and_then(|v| v.as_str()), - tool_call.get("function"), - ) { - (Some("function"), Some(id), Some(f)) => { - if let Some(Value::String(name)) = f.get("name") { - c.choices[j] - .message - .tool_calls - .get_or_insert_with(Vec::new) - .push(OpenAIToolCall { - id: Some(id.to_string()), - r#type: OpenAIToolType::Function, - function: OpenAIToolCallFunction { - name: name.clone(), - arguments: String::new(), - }, - }); - } - } - (None, None, Some(f)) => { - if let (Some(Value::Number(idx)), Some(Value::String(a))) = - (tool_call.get("index"), f.get("arguments")) - { - let index: usize = idx - .as_u64() - .ok_or_else(|| anyhow!("Missing index for tools"))? - .try_into() - .map_err(|e| { - anyhow!("Invalid index value for tools: {:?}", e) - })?; - - let tool_calls = c.choices[j] - .message - .tool_calls - .as_mut() - .ok_or(anyhow!("Missing tool calls"))?; - - if index >= tool_calls.len() { - return Err(anyhow!( - "Index out-of-bound for tool_calls: {}", - index - )); + match c.choices[j].logprobs.as_mut() { + Some(c_logprobs) => match a.choices[j].logprobs.as_ref() { + Some(a_logprobs) => { + c_logprobs.tokens.extend(a_logprobs.tokens.clone()); + c_logprobs + .token_logprobs + .extend(a_logprobs.token_logprobs.clone()); + c_logprobs + .text_offset + .extend(a_logprobs.text_offset.clone()); + match c_logprobs.top_logprobs.as_mut() { + Some(c_top_logprobs) => match a_logprobs.top_logprobs.as_ref() { + Some(a_top_logprobs) => { + c_top_logprobs.extend(a_top_logprobs.clone()); } - - tool_calls[index].function.arguments += a; - } + None => (), + }, + None => (), } - _ => (), } - } + None => (), + }, + None => (), } } } c }; - // for all messages, edit the content and strip leading and trailing spaces and \n - for m in completion.choices.iter_mut() { - m.message.content = match m.message.content.as_ref() { - None => None, - Some(c) => Some(c.trim().to_string()), - }; - } - Ok((completion, request_id)) } -pub async fn chat_completion( +pub async fn completion( uri: Uri, api_key: String, organization_id: Option, model_id: Option, - messages: &Vec, - tools: Vec, - tool_choice: Option, + prompt: &str, + max_tokens: Option, temperature: f32, - top_p: f32, n: usize, + logprobs: Option, + echo: bool, stop: &Vec, - max_tokens: Option, - presence_penalty: f32, frequency_penalty: f32, - response_format: Option, - reasoning_effort: Option, - logprobs: Option, - top_logprobs: Option, + presence_penalty: f32, + top_p: f32, user: Option, -) -> Result<(OpenAIChatCompletion, Option)> { +) -> Result<(Completion, Option)> { let mut body = json!({ - "messages": messages, + "prompt": prompt, "temperature": temperature, - "top_p": top_p, "n": n, - "presence_penalty": presence_penalty, + "logprobs": logprobs, "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "top_p": top_p, }); if user.is_some() { body["user"] = json!(user); } - if let Some(model_id) = model_id { + if model_id.is_some() { body["model"] = 
json!(model_id); } if let Some(mt) = max_tokens { @@ -1527,43 +428,28 @@ pub async fn chat_completion( body["stop"] = json!(stop); } - if let Some(response_format) = response_format { - body["response_format"] = json!({ - "type": response_format, - }); - } - if tools.len() > 0 { - body["tools"] = json!(tools); - } - if let Some(tool_choice) = tool_choice { - body["tool_choice"] = json!(tool_choice); - } - if let Some(reasoning_effort) = reasoning_effort { - body["reasoning_effort"] = json!(reasoning_effort); - } - if let Some(logprobs) = logprobs { - body["logprobs"] = json!(logprobs); - } - if let Some(top_logprobs) = top_logprobs { - body["top_logprobs"] = json!(top_logprobs); - } + match model_id { + None => (), + Some(model_id) => { + body["model"] = json!(model_id); + // `gpt-3.5-turbo-instruct` does not support `echo` + if !model_id.starts_with("gpt-3.5-turbo-instruct") { + body["echo"] = json!(echo); + } + } + }; let mut req = reqwest::Client::new() .post(uri.to_string()) .header("Content-Type", "application/json") - // This one is for `openai`. .header("Authorization", format!("Bearer {}", api_key.clone())) - // This one is for `azure_openai`. .header("api-key", api_key.clone()); if let Some(organization_id) = organization_id { - req = req.header( - "OpenAI-Organization", - &format!("{}", organization_id.clone()), - ); + req = req.header("OpenAI-Organization", organization_id); } - let req = req.json(&body); + req = req.json(&body); let res = match timeout(Duration::new(180, 0), req.send()).await { Ok(Ok(res)) => res, @@ -1587,14 +473,14 @@ pub async fn chat_completion( body.reader().read_to_end(&mut b)?; let c: &[u8] = &b; - let mut completion: OpenAIChatCompletion = match serde_json::from_slice(c) { + let completion: Completion = match serde_json::from_slice(c) { Ok(c) => Ok(c), Err(_) => { let error: OpenAIError = serde_json::from_slice(c)?; match error.retryable() { true => Err(ModelError { request_id: request_id.clone(), - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 2, @@ -1603,7 +489,7 @@ pub async fn chat_completion( }), false => Err(ModelError { request_id: request_id.clone(), - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 1, @@ -1614,45 +500,9 @@ pub async fn chat_completion( } }?; - // for all messages, edit the content and strip leading and trailing spaces and \n - for m in completion.choices.iter_mut() { - m.message.content = match m.message.content.as_ref() { - None => None, - Some(c) => Some(c.trim().to_string()), - }; - } - Ok((completion, request_id)) } -pub fn logprobs_from_choices(choices: &Vec) -> Option> { - let lp: Vec = choices - .iter() - .filter_map(|choice| choice.logprobs.as_ref()) - .flat_map(|lp| { - lp.content.iter().map(|content_logprob| LLMChatLogprob { - token: content_logprob.token.clone(), - logprob: content_logprob.logprob, - top_logprobs: match content_logprob.top_logprobs.len() { - 0 => None, - _ => Some( - content_logprob - .top_logprobs - .iter() - .map(|top_logprob| top_logprob.clone().into()) - .collect(), - ), - }, - }) - }) - .collect(); - - match lp.len() { - 0 => None, - _ => Some(lp), - } -} - /// /// Shared streamed/non-streamed chat/completion handling code (used by both OpenAILLM and /// AzureOpenAILLM). 
@@ -1735,7 +585,7 @@ pub async fn embed( match error.retryable() { true => Err(ModelError { request_id, - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 2, @@ -1744,7 +594,7 @@ pub async fn embed( }), false => Err(ModelError { request_id, - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 1, @@ -1822,29 +672,6 @@ impl OpenAILLM { } } -pub fn to_openai_messages( - messages: &Vec, - model_id: &str, -) -> Result, anyhow::Error> { - let mut oai_messages = messages - .iter() - .map(|m| OpenAIChatMessage::try_from(m)) - .collect::>>()? - .into_iter() - // [o1-mini] O1 mini does not support system messages, so we filter them out. - .filter(|m| m.role != OpenAIChatMessageRole::System || !model_id.starts_with("o1-mini")) - .collect::>(); - - // [o1] O1 uses `developer` messages instead of `system` messages. - for m in oai_messages.iter_mut() { - if m.role == OpenAIChatMessageRole::System && model_id.starts_with("o1") { - m.role = OpenAIChatMessageRole::Developer; - } - } - - Ok(oai_messages) -} - #[async_trait] impl LLM for OpenAILLM { fn id(&self) -> String { @@ -2100,7 +927,7 @@ impl LLM for OpenAILLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, + max_tokens: Option, presence_penalty: Option, frequency_penalty: Option, logprobs: Option, @@ -2108,162 +935,40 @@ impl LLM for OpenAILLM { extras: Option, event_sender: Option>, ) -> Result { - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - let (openai_org_id, openai_user, response_format, reasoning_effort) = match &extras { - None => (None, None, None, None), - Some(v) => ( - match v.get("openai_organization_id") { - Some(Value::String(o)) => Some(o.to_string()), - _ => None, - }, - match v.get("openai_user") { - Some(Value::String(u)) => Some(u.to_string()), - _ => None, - }, - match v.get("response_format") { - Some(Value::String(f)) => Some(f.to_string()), - _ => None, - }, - match v.get("reasoning_effort") { - Some(Value::String(r)) => Some(r.to_string()), - _ => None, - }, - ), - }; - - let tool_choice = match function_call.as_ref() { - Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), - None => None, - }; - - let tools = functions - .iter() - .map(OpenAITool::try_from) - .collect::, _>>()?; - - let openai_messages = to_openai_messages(messages, &self.id)?; - - // [o1] Hack for OpenAI `o1*` models to simulate streaming. - let is_streaming = event_sender.is_some(); let model_is_o1 = self.id.as_str().starts_with("o1"); - - let (c, request_id) = if !model_is_o1 && is_streaming { - streamed_chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - openai_org_id, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - // [o1] O1 models do not support custom temperature. - if !model_is_o1 { temperature } else { 1.0 }, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - response_format, - reasoning_effort, - logprobs, - top_logprobs, - openai_user, - event_sender.clone(), - ) - .await? 
-        } else {
-            chat_completion(
-                self.chat_uri()?,
-                self.api_key.clone().unwrap(),
-                openai_org_id,
-                Some(self.id.clone()),
-                &openai_messages,
-                tools,
-                tool_choice,
-                // [o1] O1 models do not support custom temperature.
-                if !model_is_o1 { temperature } else { 1.0 },
-                match top_p {
-                    Some(t) => t,
-                    None => 1.0,
-                },
-                n,
-                stop,
-                max_tokens,
-                match presence_penalty {
-                    Some(p) => p,
-                    None => 0.0,
-                },
-                match frequency_penalty {
-                    Some(f) => f,
-                    None => 0.0,
-                },
-                response_format,
-                reasoning_effort,
-                logprobs,
-                top_logprobs,
-                openai_user,
-            )
-            .await?
-        };
-
-        // [o1] Hack for OpenAI `o1*` models to simulate streaming.
-        if model_is_o1 && is_streaming {
-            let sender = event_sender.as_ref().unwrap();
-            for choice in &c.choices {
-                if let Some(content) = &choice.message.content {
-                    // Split content into smaller chunks to simulate streaming.
-                    for chunk in content
-                        .chars()
-                        .collect::<Vec<char>>()
-                        .chunks(4)
-                        .map(|c| c.iter().collect::<String>())
-                    {
-                        let _ = sender.send(json!({
-                            "type": "tokens",
-                            "content": {
-                                "text": chunk,
-                            },
-                        }));
-                        // Add a small delay to simulate real-time streaming.
-                        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
-                    }
-                }
-            }
-        }
-
-        assert!(c.choices.len() > 0);
-
-        Ok(LLMChatGeneration {
-            created: utils::now(),
-            provider: ProviderID::OpenAI.to_string(),
-            model: self.id.clone(),
-            completions: c
-                .choices
-                .iter()
-                .map(|c| AssistantChatMessage::try_from(&c.message))
-                .collect::<Result<Vec<_>>>()?,
-            usage: c.usage.map(|usage| LLMTokenUsage {
-                prompt_tokens: usage.prompt_tokens,
-                completion_tokens: usage.completion_tokens.unwrap_or(0),
-            }),
-            provider_request_id: request_id,
-            logprobs: logprobs_from_choices(&c.choices),
-        })
+        let model_is_o1_mini = self.id.as_str().starts_with("o1-mini");
+        openai_compatible_chat_completion(
+            self.chat_uri()?,
+            self.id.clone(),
+            self.api_key.clone().unwrap(),
+            messages,
+            functions,
+            function_call,
+            temperature,
+            top_p,
+            n,
+            stop,
+            max_tokens,
+            presence_penalty,
+            frequency_penalty,
+            logprobs,
+            top_logprobs,
+            extras,
+            event_sender,
+            model_is_o1, // disable provider streaming if model is o1.
+            // If model is o1-mini, we remove system messages.
+            // If model is o1, we replace system messages with developer messages.
+            if model_is_o1_mini {
+                TransformSystemMessages::Remove
+            } else if model_is_o1 {
+                TransformSystemMessages::ReplaceWithDeveloper
+            } else {
+                TransformSystemMessages::Keep
+            },
+            "OpenAI".to_string(),
+            false, // Don't squash text contents.
+        )
+        .await
     }
 }
@@ -2462,33 +1167,6 @@ impl Provider for OpenAIProvider {
         )
         .await?;
 
-        // let mut embedder = self.embedder(String::from("text-embedding-ada-002"));
-        // embedder.initialize(Credentials::new()).await?;
-
-        // let _v = embedder.embed("Hello 😊", None).await?;
-        // println!("EMBEDDING SIZE: {}", v.vector.len());
-
-        // llm = self.llm(String::from("gpt-3.5-turbo"));
-        // llm.initialize(Credentials::new()).await?;
-
-        // let messages = vec![
-        //     // ChatMessage {
-        //     //     role: String::from("system"),
-        //     //     content: String::from(
-        //     //         "You're a an assistant. Answer as concisely and precisely as possible.",
-        //     //     ),
-        //     // },
-        //     ChatMessage {
-        //         role: String::from("user"),
-        //         content: String::from("How can I calculate the area of a circle?"),
-        //     },
-        // ];
-
-        // let c = llm
-        //     .chat(&messages, 0.7, None, 1, &vec![], None, None, None, None)
-        //     .await?;
-        // println!("CHAT COMPLETION SIZE: {:?}", c);
-
         utils::done("Test successfully completed! OpenAI is ready to use.");
OpenAI is ready to use."); Ok(()) diff --git a/core/src/providers/openai_compatible_helpers.rs b/core/src/providers/openai_compatible_helpers.rs new file mode 100644 index 000000000000..04c4f0ccdfa5 --- /dev/null +++ b/core/src/providers/openai_compatible_helpers.rs @@ -0,0 +1,1529 @@ +use std::{collections::HashMap, str::FromStr, time::Duration}; + +use crate::{ + providers::{llm::LLMTokenUsage, provider::ProviderID}, + utils::{self, ParseError}, +}; +use anyhow::{anyhow, Result}; +use eventsource_client as es; +use eventsource_client::Client as ESClient; +use futures::TryStreamExt; +use http::StatusCode; +use hyper::{body::Buf, Uri}; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::io::prelude::*; +use std::sync::Arc; +use tokio::sync::mpsc::UnboundedSender; +use tokio::time::timeout; + +use super::{ + chat_messages::{AssistantChatMessage, ChatMessage, ContentBlock, MixedContent}, + llm::{ + ChatFunction, ChatFunctionCall, ChatMessageRole, LLMChatGeneration, LLMChatLogprob, + TopLogprob, + }, + provider::{ModelError, ModelErrorRetryOptions}, +}; + +// Input types. + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum OpenAIToolType { + Function, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIFunction { + name: String, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIFunctionCall { + r#type: OpenAIToolType, + function: OpenAIFunction, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum OpenAIToolControl { + Auto, + Required, + None, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(untagged)] +pub enum OpenAIToolChoice { + OpenAIToolControl(OpenAIToolControl), + OpenAIFunctionCall(OpenAIFunctionCall), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIToolFunction { + pub name: String, + pub description: Option, + pub parameters: Option, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAITool { + pub r#type: OpenAIToolType, + pub function: OpenAIToolFunction, +} + +impl FromStr for OpenAIToolChoice { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + match s { + "auto" => Ok(OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::Auto)), + "any" => Ok(OpenAIToolChoice::OpenAIToolControl( + OpenAIToolControl::Required, + )), + "none" => Ok(OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::None)), + _ => { + let function = OpenAIFunctionCall { + r#type: OpenAIToolType::Function, + function: OpenAIFunction { + name: s.to_string(), + }, + }; + Ok(OpenAIToolChoice::OpenAIFunctionCall(function)) + } + } + } +} + +// Outputs types. 
+ +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Usage { + pub prompt_tokens: u64, + pub completion_tokens: Option, + pub total_tokens: u64, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIToolCallFunction { + name: String, + arguments: String, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + r#type: OpenAIToolType, + pub function: OpenAIToolCallFunction, +} + +impl TryFrom<&ChatFunctionCall> for OpenAIToolCall { + type Error = anyhow::Error; + + fn try_from(cf: &ChatFunctionCall) -> Result { + Ok(OpenAIToolCall { + id: Some(cf.id.clone()), + r#type: OpenAIToolType::Function, + function: OpenAIToolCallFunction { + name: cf.name.clone(), + arguments: cf.arguments.clone(), + }, + }) + } +} + +impl TryFrom<&OpenAIToolCall> for ChatFunctionCall { + type Error = anyhow::Error; + + fn try_from(tc: &OpenAIToolCall) -> Result { + let id = tc + .id + .as_ref() + .ok_or_else(|| anyhow!("Missing tool call id."))?; + + Ok(ChatFunctionCall { + id: id.clone(), + name: tc.function.name.clone(), + arguments: tc.function.arguments.clone(), + }) + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum OpenAIChatMessageRole { + Assistant, + Function, + System, + Developer, + Tool, + User, +} + +impl From<&ChatMessageRole> for OpenAIChatMessageRole { + fn from(role: &ChatMessageRole) -> Self { + match role { + ChatMessageRole::Assistant => OpenAIChatMessageRole::Assistant, + ChatMessageRole::Function => OpenAIChatMessageRole::Function, + ChatMessageRole::System => OpenAIChatMessageRole::System, + ChatMessageRole::User => OpenAIChatMessageRole::User, + } + } +} + +impl FromStr for OpenAIChatMessageRole { + type Err = ParseError; + fn from_str(s: &str) -> Result { + match s { + "system" => Ok(OpenAIChatMessageRole::System), + "user" => Ok(OpenAIChatMessageRole::User), + "assistant" => Ok(OpenAIChatMessageRole::Assistant), + "function" => Ok(OpenAIChatMessageRole::Tool), + _ => Err(ParseError::with_message("Unknown OpenAIChatMessageRole"))?, + } + } +} + +impl From for ChatMessageRole { + fn from(value: OpenAIChatMessageRole) -> Self { + match value { + OpenAIChatMessageRole::Assistant => ChatMessageRole::Assistant, + OpenAIChatMessageRole::Function => ChatMessageRole::Function, + OpenAIChatMessageRole::System => ChatMessageRole::System, + OpenAIChatMessageRole::Developer => ChatMessageRole::System, + OpenAIChatMessageRole::Tool => ChatMessageRole::Function, + OpenAIChatMessageRole::User => ChatMessageRole::User, + } + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "snake_case")] +pub enum OpenAITextContentType { + Text, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAITextContent { + #[serde(rename = "type")] + pub r#type: OpenAITextContentType, + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIImageUrlContent { + pub url: String, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "snake_case")] +pub enum OpenAIImageContentType { + ImageUrl, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIImageContent { + pub r#type: OpenAIImageContentType, + pub image_url: OpenAIImageUrlContent, +} + +// Define an enum for mixed content +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(untagged)] 
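+// Untagged: serde tries the variants in declaration order, so a block is
+// deserialized as text first and as an image otherwise.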
+pub enum OpenAIContentBlock { + TextContent(OpenAITextContent), + ImageContent(OpenAIImageContent), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(untagged)] +pub enum OpenAIChatMessageContent { + Structured(Vec), + String(String), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIChatMessage { + pub role: OpenAIChatMessageRole, + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAICompletionChatMessage { + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + pub role: OpenAIChatMessageRole, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct OpenAITopLogprob { + pub token: String, + pub logprob: f32, + pub bytes: Option>, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct OpenAIChatChoiceLogprob { + pub token: String, + pub logprob: f32, + pub bytes: Option>, + pub top_logprobs: Vec, +} + +impl From for TopLogprob { + fn from(top_logprob: OpenAITopLogprob) -> Self { + TopLogprob { + token: top_logprob.token, + logprob: top_logprob.logprob, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct OpenAIChatChoiceLogprobs { + pub content: Vec, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct OpenAIChatChoice { + pub message: OpenAICompletionChatMessage, + pub index: usize, + pub finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct OpenAIChatCompletion { + pub id: String, + pub object: String, + pub created: u64, + pub choices: Vec, + pub usage: Option, +} + +// This code performs a type conversion with information loss when converting to ChatFunctionCall. +// It only supports one tool call, so it takes the first one from the vector of OpenAIToolCall, +// hence potentially discarding other tool calls. 
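+// (The singular `function_call` is populated from the first entry of `function_calls`.)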
+impl TryFrom<&OpenAICompletionChatMessage> for AssistantChatMessage { + type Error = anyhow::Error; + + fn try_from(cm: &OpenAICompletionChatMessage) -> Result { + let role = ChatMessageRole::from(cm.role.clone()); + let content = match cm.content.as_ref() { + Some(c) => Some(c.clone()), + None => None, + }; + + let function_calls = if let Some(tool_calls) = cm.tool_calls.as_ref() { + let cfc = tool_calls + .into_iter() + .map(|tc| ChatFunctionCall::try_from(tc)) + .collect::, _>>()?; + + Some(cfc) + } else { + None + }; + + let function_call = if let Some(fcs) = function_calls.as_ref() { + match fcs.first() { + Some(fc) => Some(fc), + None => None, + } + .cloned() + } else { + None + }; + + let name = match cm.name.as_ref() { + Some(c) => Some(c.clone()), + None => None, + }; + + Ok(AssistantChatMessage { + content, + role, + name, + function_call, + function_calls, + }) + } +} + +impl TryFrom<&ContentBlock> for OpenAIChatMessageContent { + type Error = anyhow::Error; + + fn try_from(cm: &ContentBlock) -> Result { + match cm { + ContentBlock::Text(t) => Ok(OpenAIChatMessageContent::Structured(vec![ + OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: t.clone(), + }), + ])), + ContentBlock::Mixed(m) => { + let content: Vec = m + .into_iter() + .map(|mb| match mb { + MixedContent::TextContent(tc) => { + Ok(OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: tc.text.clone(), + })) + } + MixedContent::ImageContent(ic) => { + Ok(OpenAIContentBlock::ImageContent(OpenAIImageContent { + r#type: OpenAIImageContentType::ImageUrl, + image_url: OpenAIImageUrlContent { + url: ic.image_url.url.clone(), + }, + })) + } + }) + .collect::>>()?; + + Ok(OpenAIChatMessageContent::Structured(content)) + } + } + } +} + +impl TryFrom<&String> for OpenAIChatMessageContent { + type Error = anyhow::Error; + + fn try_from(t: &String) -> Result { + Ok(OpenAIChatMessageContent::Structured(vec![ + OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: t.clone(), + }), + ])) + } +} + +impl TryFrom<&ChatMessage> for OpenAIChatMessage { + type Error = anyhow::Error; + + fn try_from(cm: &ChatMessage) -> Result { + match cm { + ChatMessage::Assistant(assistant_msg) => Ok(OpenAIChatMessage { + content: match &assistant_msg.content { + Some(c) => Some(OpenAIChatMessageContent::try_from(c)?), + None => None, + }, + name: assistant_msg.name.clone(), + role: OpenAIChatMessageRole::from(&assistant_msg.role), + tool_calls: match assistant_msg.function_calls.as_ref() { + Some(fc) => Some( + fc.into_iter() + .map(|f| OpenAIToolCall::try_from(f)) + .collect::, _>>()?, + ), + None => None, + }, + tool_call_id: None, + }), + ChatMessage::Function(function_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIChatMessageContent::try_from(&function_msg.content)?), + name: None, + role: OpenAIChatMessageRole::Tool, + tool_calls: None, + tool_call_id: Some(function_msg.function_call_id.clone()), + }), + ChatMessage::System(system_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIChatMessageContent::try_from(&system_msg.content)?), + name: None, + role: OpenAIChatMessageRole::from(&system_msg.role), + tool_calls: None, + tool_call_id: None, + }), + ChatMessage::User(user_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIChatMessageContent::try_from(&user_msg.content)?), + name: user_msg.name.clone(), + role: OpenAIChatMessageRole::from(&user_msg.role), + tool_calls: None, + tool_call_id: None, + 
}), + } + } +} + +impl TryFrom<&ChatFunction> for OpenAITool { + type Error = anyhow::Error; + + fn try_from(f: &ChatFunction) -> Result { + Ok(OpenAITool { + r#type: OpenAIToolType::Function, + function: OpenAIToolFunction { + name: f.name.clone(), + description: f.description.clone(), + parameters: f.parameters.clone(), + }, + }) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatDelta { + pub delta: Value, + pub index: usize, + pub finish_reason: Option, + pub logprobs: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatChunk { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct InnerError { + pub message: String, + #[serde(alias = "type")] + pub _type: String, + pub param: Option, + pub internal_message: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct OpenAIError { + pub error: InnerError, +} + +pub struct OpenAICompatibleError { + pub provider: String, + pub error: OpenAIError, +} + +impl OpenAIError { + pub fn with_provider(self, provider: &str) -> OpenAICompatibleError { + OpenAICompatibleError { + provider: provider.to_string(), + error: self, + } + } + + pub fn retryable(&self) -> bool { + match self.error._type.as_str() { + "requests" => true, + "server_error" => match &self.error.internal_message { + Some(message) if message.contains("retry") => true, + _ => false, + }, + _ => false, + } + } + + pub fn retryable_streamed(&self, status: StatusCode) -> bool { + if status == StatusCode::TOO_MANY_REQUESTS { + return true; + } + if status.is_server_error() { + return true; + } + match self.error._type.as_str() { + "server_error" => match self.error.internal_message { + Some(_) => true, + None => false, + }, + _ => false, + } + } +} + +impl OpenAICompatibleError { + pub fn message(&self) -> String { + match &self.error.error.internal_message { + Some(ref msg) => format!( + "{}Error: [{}] {} internal_message={}", + self.provider, self.error.error._type, self.error.error.message, msg, + ), + None => format!( + "{}Error: [{}] {}", + self.provider, self.error.error._type, self.error.error.message, + ), + } + } + + pub fn retryable(&self) -> bool { + self.error.retryable() + } + + pub fn retryable_streamed(&self, status: StatusCode) -> bool { + self.error.retryable_streamed(status) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TransformSystemMessages { + Remove, + ReplaceWithDeveloper, + Keep, +} + +pub async fn openai_compatible_chat_completion( + uri: Uri, + model_id: String, + api_key: String, + messages: &Vec, + functions: &Vec, + function_call: Option, + temperature: f32, + top_p: Option, + n: usize, + stop: &Vec, + mut max_tokens: Option, + presence_penalty: Option, + frequency_penalty: Option, + logprobs: Option, + top_logprobs: Option, + openai_extras: Option, + event_sender: Option>, + disable_provider_streaming: bool, + transform_system_messages: TransformSystemMessages, + provider_name: String, + squash_text_contents: bool, +) -> Result { + if let Some(m) = max_tokens { + if m == -1 { + max_tokens = None; + } + } + + let (openai_org_id, openai_user, response_format, reasoning_effort) = match &openai_extras { + None => (None, None, None, None), + Some(v) => ( + match v.get("openai_organization_id") { + Some(Value::String(o)) => Some(o.to_string()), + _ => None, + }, + match v.get("openai_user") { + Some(Value::String(u)) => Some(u.to_string()), + _ => 
None, + }, + match v.get("response_format") { + Some(Value::String(f)) => Some(f.to_string()), + _ => None, + }, + match v.get("reasoning_effort") { + Some(Value::String(r)) => Some(r.to_string()), + _ => None, + }, + ), + }; + + let tool_choice = match function_call.as_ref() { + Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), + None => None, + }; + + let tools = functions + .iter() + .map(OpenAITool::try_from) + .collect::, _>>()?; + + let openai_messages = + to_openai_messages(messages, transform_system_messages, squash_text_contents)?; + + let stream_output = event_sender.is_some(); + + let (c, request_id) = if !disable_provider_streaming && stream_output { + streamed_chat_completion( + uri, + api_key, + openai_org_id, + Some(model_id.clone()), + &openai_messages, + tools, + tool_choice, + temperature, + match top_p { + Some(t) => t, + None => 1.0, + }, + n, + stop, + max_tokens, + match presence_penalty { + Some(p) => p, + None => 0.0, + }, + match frequency_penalty { + Some(f) => f, + None => 0.0, + }, + response_format, + reasoning_effort, + logprobs, + top_logprobs, + openai_user, + event_sender.clone(), + provider_name, + ) + .await? + } else { + chat_completion( + uri, + api_key, + openai_org_id, + Some(model_id.clone()), + &openai_messages, + tools, + tool_choice, + temperature, + match top_p { + Some(t) => t, + None => 1.0, + }, + n, + stop, + max_tokens, + match presence_penalty { + Some(p) => p, + None => 0.0, + }, + match frequency_penalty { + Some(f) => f, + None => 0.0, + }, + response_format, + reasoning_effort, + logprobs, + top_logprobs, + openai_user, + provider_name, + ) + .await? + }; + + // We support streaming the output in the event sender, even if we're not + // using streaming from the provider. + if stream_output && disable_provider_streaming { + let sender = event_sender.as_ref().unwrap(); + for choice in &c.choices { + if let Some(content) = &choice.message.content { + // Split content into smaller chunks to simulate streaming. + for chunk in content + .chars() + .collect::>() + .chunks(4) + .map(|c| c.iter().collect::()) + { + let _ = sender.send(json!({ + "type": "tokens", + "content": { + "text": chunk, + }, + })); + // Add a small delay to simulate real-time streaming. + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + } + } + } + + assert!(c.choices.len() > 0); + + Ok(LLMChatGeneration { + created: utils::now(), + provider: ProviderID::OpenAI.to_string(), + model: model_id.clone(), + completions: c + .choices + .iter() + .map(|c| AssistantChatMessage::try_from(&c.message)) + .collect::>>()?, + usage: c.usage.map(|usage| LLMTokenUsage { + prompt_tokens: usage.prompt_tokens, + completion_tokens: usage.completion_tokens.unwrap_or(0), + }), + provider_request_id: request_id, + logprobs: logprobs_from_choices(&c.choices), + }) +} + +fn to_openai_messages( + messages: &Vec, + transform_system_messages: TransformSystemMessages, + squash_text_contents: bool, +) -> Result, anyhow::Error> { + let mut oai_messages = messages + .into_iter() + // First convert to OpenAI chat messages. + .map(|m| OpenAIChatMessage::try_from(m)) + .collect::>>()? + .into_iter() + // Decide which content format to use for each message (structured or string). + // If there are images, we need to use structured format. + // If there is a single text content, we can always use string format (equivalent and compatible everywhere). 
+ // Otherwise, if there are multiple text contents, we either squash them or keep the structured format, + // depending on the `squash_text_contents` flag. + .map(|m| match m.content { + None => m, + Some(OpenAIChatMessageContent::String(_)) => m, + Some(OpenAIChatMessageContent::Structured(contents)) => { + let all_contents_are_text = contents + .iter() + .all(|c| matches!(c, OpenAIContentBlock::TextContent(_))); + + OpenAIChatMessage { + role: m.role, + name: m.name, + tool_call_id: m.tool_call_id, + tool_calls: m.tool_calls, + content: match (contents.len(), all_contents_are_text, squash_text_contents) { + // Case 0: there's no content => return None + (0, _, _) => None, + // Case 1: there's only a single text content => use string format + (1, true, _) => Some(OpenAIChatMessageContent::String( + match contents.into_iter().next().unwrap() { + OpenAIContentBlock::TextContent(tc) => tc.text.clone(), + _ => unreachable!(), + }, + )), + // Case 2: There's more than one content, all contents are text and we want to squash them => squash them + (_, true, true) => Some(OpenAIChatMessageContent::String( + contents + .into_iter() + .map(|c| match c { + OpenAIContentBlock::TextContent(tc) => tc.text.clone(), + _ => unreachable!(), + }) + .collect::>() + .join("\n"), + )), + // Case 3: there's more than one content, the content isn't text or we don't want to squash them => keep structured format + (_, _, _) => Some(OpenAIChatMessageContent::Structured(contents)), + }, + } + } + }) + // Truncate the tool_ids to 40 characters. + .map(|m| { + fn truncate_id(id: Option) -> Option { + id.map(|id| id.chars().take(40).collect::()) + } + OpenAIChatMessage { + role: m.role, + name: m.name, + tool_call_id: truncate_id(m.tool_call_id), + tool_calls: m.tool_calls.map(|tool_calls| { + tool_calls + .into_iter() + .map(|tc| OpenAIToolCall { + id: truncate_id(tc.id), + function: tc.function, + r#type: tc.r#type, + }) + .collect() + }), + content: m.content, + } + }) + // Remove system messages if requested. + // Some models don't support system messages, so we need this to be + // configurable. + .filter(|m| { + transform_system_messages != TransformSystemMessages::Remove + || m.role != OpenAIChatMessageRole::System + }) + .collect::>(); + + // Replace system messages with developer messages if requested. + // Some newer models no longer support "system" as a role and support a "developer" role. 
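+    // (e.g. OpenAI o1, which callers map to `TransformSystemMessages::ReplaceWithDeveloper`.)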
+ for m in oai_messages.iter_mut() { + if m.role == OpenAIChatMessageRole::System + && transform_system_messages == TransformSystemMessages::ReplaceWithDeveloper + { + m.role = OpenAIChatMessageRole::Developer; + } + } + + Ok(oai_messages) +} + +async fn streamed_chat_completion( + uri: Uri, + api_key: String, + organization_id: Option, + model_id: Option, + messages: &Vec, + tools: Vec, + tool_choice: Option, + temperature: f32, + top_p: f32, + n: usize, + stop: &Vec, + max_tokens: Option, + presence_penalty: f32, + frequency_penalty: f32, + response_format: Option, + reasoning_effort: Option, + logprobs: Option, + top_logprobs: Option, + user: Option, + event_sender: Option>, + provider_name: String, +) -> Result<(OpenAIChatCompletion, Option)> { + let url = uri.to_string(); + + let mut builder = match es::ClientBuilder::for_url(url.as_str()) { + Ok(b) => b, + Err(_) => { + return Err(anyhow!(format!( + "Error creating streamed client to {}", + provider_name + ))) + } + }; + builder = match builder.method(String::from("POST")).header( + "Authorization", + format!("Bearer {}", api_key.clone()).as_str(), + ) { + Ok(b) => b, + Err(_) => { + return Err(anyhow!(format!( + "Error creating streamed client to {}", + provider_name + ))) + } + }; + builder = match builder.header("Content-Type", "application/json") { + Ok(b) => b, + Err(_) => { + return Err(anyhow!(format!( + "Error creating streamed client to {}", + provider_name + ))) + } + }; + builder = match builder.header("api-key", api_key.clone().as_str()) { + Ok(b) => b, + Err(_) => { + return Err(anyhow!(format!( + "Error creating streamed client to {}", + provider_name + ))) + } + }; + + if let Some(org_id) = organization_id { + builder = builder + .header("OpenAI-Organization", org_id.as_str()) + .map_err(|_| { + anyhow!(format!( + "Error creating streamed client to {}", + provider_name + )) + })?; + } + + let mut body = json!({ + "messages": messages, + "temperature": temperature, + "top_p": top_p, + "n": n, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "stream": true, + "stream_options": HashMap::from([("include_usage", true)]), + }); + if user.is_some() { + body["user"] = json!(user); + } + if model_id.is_some() { + body["model"] = json!(model_id); + } + if let Some(mt) = max_tokens { + body["max_tokens"] = mt.into(); + } + if !stop.is_empty() { + body["stop"] = json!(stop); + } + + if tools.len() > 0 { + body["tools"] = json!(tools); + } + if let Some(tool_choice) = tool_choice { + body["tool_choice"] = json!(tool_choice); + } + if let Some(response_format) = response_format { + body["response_format"] = json!({ + "type": response_format, + }); + } + if let Some(reasoning_effort) = reasoning_effort { + body["reasoning_effort"] = json!(reasoning_effort); + } + if let Some(logprobs) = logprobs { + body["logprobs"] = json!(logprobs); + } + if let Some(top_logprobs) = top_logprobs { + body["top_logprobs"] = json!(top_logprobs); + } + + let client = builder + .body(body.to_string()) + .reconnect( + es::ReconnectOptions::reconnect(true) + .retry_initial(false) + .delay(Duration::from_secs(1)) + .backoff_factor(2) + .delay_max(Duration::from_secs(8)) + .build(), + ) + .build(); + + let mut stream = client.stream(); + + let chunks: Arc>> = Arc::new(Mutex::new(Vec::new())); + let mut usage = None; + let mut request_id: Option = None; + + 'stream: loop { + match stream.try_next().await { + Ok(e) => match e { + Some(es::SSE::Connected((_, headers))) => { + request_id = match headers.get("x-request-id") { + 
Some(v) => Some(v.to_string()), + None => None, + }; + } + Some(es::SSE::Comment(_)) => { + println!("UNEXPECTED COMMENT"); + } + Some(es::SSE::Event(e)) => match e.data.as_str() { + "[DONE]" => { + break 'stream; + } + _ => { + let index = { + let guard = chunks.lock(); + guard.len() + }; + + let chunk: ChatChunk = match serde_json::from_str(e.data.as_str()) { + Ok(c) => c, + Err(err) => { + let error: Result = + serde_json::from_str(e.data.as_str()); + match error { + Ok(error) => { + match error.retryable_streamed(StatusCode::OK) && index == 0 + { + true => Err(ModelError { + request_id: request_id.clone(), + message: error + .with_provider(&provider_name) + .message(), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(500), + factor: 2, + retries: 3, + }), + })?, + false => Err(ModelError { + request_id: request_id.clone(), + message: error + .with_provider(&provider_name) + .message(), + retryable: None, + })?, + } + break 'stream; + } + Err(_) => { + Err(anyhow!(format!( + "{}Error: failed parsing streamed \ + completion from {} err={} data={}", + provider_name, + provider_name, + err, + e.data.as_str(), + )))?; + break 'stream; + } + } + } + }; + + // Store usage + match &chunk.usage { + Some(received_usage) => { + usage = Some(received_usage.clone()); + } + None => (), + }; + + // Only stream if choices is length 1 but should always be the case. + match event_sender.as_ref() { + Some(sender) => { + if chunk.choices.len() == 1 { + // we ignore the role for generating events + + // If we get `content` in the delta object we stream "tokens". + match chunk.choices[0].delta.get("content") { + None => (), + Some(content) => match content.as_str() { + None => (), + Some(s) => { + if s.len() > 0 { + let _ = sender.send(json!({ + "type": "tokens", + "content": { + "text": s, + }, + })); + } + } + }, + }; + + // Emit a `function_call` event per tool_call. 
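+                            // Only the function name is emitted here; the arguments
+                            // are accumulated when the chunks are merged below.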
+ if let Some(tool_calls) = chunk.choices[0] + .delta + .get("tool_calls") + .and_then(|v| v.as_array()) + { + tool_calls.iter().for_each(|tool_call| { + match tool_call.get("function") { + Some(f) => { + if let Some(Value::String(name)) = f.get("name") + { + let _ = sender.send(json!({ + "type": "function_call", + "content": { + "name": name, + }, + })); + } + } + _ => (), + } + }); + } + } + } + None => (), + }; + + if !chunk.choices.is_empty() { + chunks.lock().push(chunk); + } + } + }, + None => { + println!("UNEXPECTED NONE"); + break 'stream; + } + }, + Err(e) => { + match e { + es::Error::UnexpectedResponse(r) => { + let status = StatusCode::from_u16(r.status())?; + let headers = r.headers()?; + let request_id = match headers.get("x-request-id") { + Some(v) => Some(v.to_string()), + None => None, + }; + let b = r.body_bytes().await?; + + let error: Result = serde_json::from_slice(&b); + match error { + Ok(error) => { + match error.retryable_streamed(status) { + true => Err(ModelError { + request_id, + message: error.with_provider(&provider_name).message(), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(500), + factor: 2, + retries: 3, + }), + }), + false => Err(ModelError { + request_id, + message: error.with_provider(&provider_name).message(), + retryable: None, + }), + } + }?, + Err(_) => { + Err(anyhow!( + "Error streaming tokens from {}: status={} data={}", + &provider_name, + status, + String::from_utf8_lossy(&b) + ))?; + } + } + } + _ => { + Err(anyhow!( + "Error streaming tokens from {}: {:?}", + &provider_name, + e + ))?; + } + } + break 'stream; + } + } + } + + let mut completion = { + let guard = chunks.lock(); + let f = match guard.len() { + 0 => Err(anyhow!("No chunks received from {}", provider_name)), + _ => Ok(guard[0].clone()), + }?; + + // merge logprobs from all choices of all chunks + let logprobs: Vec = guard + .iter() + .flat_map(|chunk| { + chunk + .choices + .iter() + .filter_map(|choice| choice.logprobs.as_ref().map(|lp| lp.content.clone())) + }) + .flatten() + .collect(); + + let mut c = OpenAIChatCompletion { + id: f.id.clone(), + object: f.object.clone(), + created: f.created, + choices: f + .choices + .iter() + .map(|c| OpenAIChatChoice { + message: OpenAICompletionChatMessage { + content: Some("".to_string()), + name: None, + role: OpenAIChatMessageRole::System, + tool_calls: None, + tool_call_id: None, + }, + index: c.index, + finish_reason: None, + logprobs: match logprobs.len() { + 0 => None, + _ => Some(OpenAIChatChoiceLogprobs { + content: logprobs.clone(), + }), + }, + }) + .collect::>(), + usage, + }; + + for i in 0..guard.len() { + let a = guard[i].clone(); + if a.choices.len() != f.choices.len() { + Err(anyhow!("Inconsistent number of choices in streamed chunks"))?; + } + for j in 0..a.choices.len() { + match a.choices.get(j).unwrap().finish_reason.clone() { + None => (), + Some(f) => c.choices[j].finish_reason = Some(f), + }; + + match a.choices[j].delta.get("role") { + None => (), + Some(role) => match role.as_str() { + None => (), + Some(r) => { + c.choices[j].message.role = OpenAIChatMessageRole::from_str(r)?; + } + }, + }; + + match a.choices[j].delta.get("content") { + None => (), + Some(content) => match content.as_str() { + None => (), + Some(s) => { + c.choices[j].message.content = Some(format!( + "{}{}", + c.choices[j] + .message + .content + .as_ref() + .unwrap_or(&String::new()), + s + )); + } + }, + }; + + if let Some(tool_calls) = a.choices[j] + .delta + .get("tool_calls") + .and_then(|v| v.as_array()) + 
{ + for tool_call in tool_calls { + match ( + tool_call.get("type").and_then(|v| v.as_str()), + tool_call.get("id").and_then(|v| v.as_str()), + tool_call.get("function"), + ) { + (Some("function"), Some(id), Some(f)) => { + if let Some(Value::String(name)) = f.get("name") { + c.choices[j] + .message + .tool_calls + .get_or_insert_with(Vec::new) + .push(OpenAIToolCall { + id: Some(id.to_string()), + r#type: OpenAIToolType::Function, + function: OpenAIToolCallFunction { + name: name.clone(), + arguments: String::new(), + }, + }); + } + } + (None, None, Some(f)) => { + if let (Some(Value::Number(idx)), Some(Value::String(a))) = + (tool_call.get("index"), f.get("arguments")) + { + let index: usize = idx + .as_u64() + .ok_or_else(|| anyhow!("Missing index for tools"))? + .try_into() + .map_err(|e| { + anyhow!("Invalid index value for tools: {:?}", e) + })?; + + let tool_calls = c.choices[j] + .message + .tool_calls + .as_mut() + .ok_or(anyhow!("Missing tool calls"))?; + + if index >= tool_calls.len() { + return Err(anyhow!( + "Index out-of-bound for tool_calls: {}", + index + )); + } + + tool_calls[index].function.arguments += a; + } + } + _ => (), + } + } + } + } + } + c + }; + + // for all messages, edit the content and strip leading and trailing spaces and \n + for m in completion.choices.iter_mut() { + m.message.content = match m.message.content.as_ref() { + None => None, + Some(c) => Some(c.trim().to_string()), + }; + } + + Ok((completion, request_id)) +} + +async fn chat_completion( + uri: Uri, + api_key: String, + organization_id: Option, + model_id: Option, + messages: &Vec, + tools: Vec, + tool_choice: Option, + temperature: f32, + top_p: f32, + n: usize, + stop: &Vec, + max_tokens: Option, + presence_penalty: f32, + frequency_penalty: f32, + response_format: Option, + reasoning_effort: Option, + logprobs: Option, + top_logprobs: Option, + user: Option, + provider_name: String, +) -> Result<(OpenAIChatCompletion, Option)> { + let mut body = json!({ + "messages": messages, + "temperature": temperature, + "top_p": top_p, + "n": n, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + }); + if user.is_some() { + body["user"] = json!(user); + } + if let Some(model_id) = model_id { + body["model"] = json!(model_id); + } + if let Some(mt) = max_tokens { + body["max_tokens"] = mt.into(); + } + if !stop.is_empty() { + body["stop"] = json!(stop); + } + + if let Some(response_format) = response_format { + body["response_format"] = json!({ + "type": response_format, + }); + } + if tools.len() > 0 { + body["tools"] = json!(tools); + } + if let Some(tool_choice) = tool_choice { + body["tool_choice"] = json!(tool_choice); + } + if let Some(reasoning_effort) = reasoning_effort { + body["reasoning_effort"] = json!(reasoning_effort); + } + if let Some(logprobs) = logprobs { + body["logprobs"] = json!(logprobs); + } + if let Some(top_logprobs) = top_logprobs { + body["top_logprobs"] = json!(top_logprobs); + } + + let mut req = reqwest::Client::new() + .post(uri.to_string()) + .header("Content-Type", "application/json") + // This one is for `openai`. + .header("Authorization", format!("Bearer {}", api_key.clone())) + // This one is for `azure_openai`. 
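+        // (Assumption: each backend reads only its own auth header, so sending both is harmless.)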
+ .header("api-key", api_key.clone()); + + if let Some(organization_id) = organization_id { + req = req.header( + "OpenAI-Organization", + &format!("{}", organization_id.clone()), + ); + } + + let req = req.json(&body); + + let res = match timeout(Duration::new(180, 0), req.send()).await { + Ok(Ok(res)) => res, + Ok(Err(e)) => Err(e)?, + Err(_) => Err(anyhow!(format!( + "Timeout sending request to {} after 180s", + provider_name + )))?, + }; + + let res_headers = res.headers(); + let request_id = match res_headers.get("x-request-id") { + Some(request_id) => Some(request_id.to_str()?.to_string()), + None => None, + }; + + let body = match timeout(Duration::new(180, 0), res.bytes()).await { + Ok(Ok(body)) => body, + Ok(Err(e)) => Err(e)?, + Err(_) => Err(anyhow!(format!( + "Timeout reading response from {} after 180s", + provider_name + )))?, + }; + + let mut b: Vec = vec![]; + body.reader().read_to_end(&mut b)?; + let c: &[u8] = &b; + + let mut completion: OpenAIChatCompletion = match serde_json::from_slice(c) { + Ok(c) => Ok(c), + Err(_) => { + let error: OpenAIError = serde_json::from_slice(c)?; + match error.retryable() { + true => Err(ModelError { + request_id: request_id.clone(), + message: error.with_provider(&provider_name).message(), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(500), + factor: 2, + retries: 3, + }), + }), + false => Err(ModelError { + request_id: request_id.clone(), + message: error.with_provider(&provider_name).message(), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(500), + factor: 1, + retries: 1, + }), + }), + } + } + }?; + + // for all messages, edit the content and strip leading and trailing spaces and \n + for m in completion.choices.iter_mut() { + m.message.content = match m.message.content.as_ref() { + None => None, + Some(c) => Some(c.trim().to_string()), + }; + } + + Ok((completion, request_id)) +} + +fn logprobs_from_choices(choices: &Vec) -> Option> { + let lp: Vec = choices + .iter() + .filter_map(|choice| choice.logprobs.as_ref()) + .flat_map(|lp| { + lp.content.iter().map(|content_logprob| LLMChatLogprob { + token: content_logprob.token.clone(), + logprob: content_logprob.logprob, + top_logprobs: match content_logprob.top_logprobs.len() { + 0 => None, + _ => Some( + content_logprob + .top_logprobs + .iter() + .map(|top_logprob| top_logprob.clone().into()) + .collect(), + ), + }, + }) + }) + .collect(); + + match lp.len() { + 0 => None, + _ => Some(lp), + } +} diff --git a/core/src/providers/togetherai.rs b/core/src/providers/togetherai.rs index 7a3911da98f7..8f77ed9e0186 100644 --- a/core/src/providers/togetherai.rs +++ b/core/src/providers/togetherai.rs @@ -1,12 +1,7 @@ -use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage}; +use crate::providers::chat_messages::ChatMessage; use crate::providers::embedder::Embedder; use crate::providers::llm::ChatFunction; -use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; -use crate::providers::openai::{ - chat_completion, logprobs_from_choices, streamed_chat_completion, to_openai_messages, - OpenAIChatMessage, OpenAIChatMessageContent, OpenAIContentBlock, OpenAITextContent, - OpenAITextContentType, OpenAITool, OpenAIToolChoice, -}; +use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLM}; use crate::providers::provider::{Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, o200k_base_singleton, CoreBPE}; use crate::providers::tiktoken::tiktoken::{decode_async, 
encode_async}; @@ -18,10 +13,13 @@ use async_trait::async_trait; use hyper::Uri; use parking_lot::RwLock; use serde_json::Value; -use std::str::FromStr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; +use super::openai_compatible_helpers::{ + openai_compatible_chat_completion, TransformSystemMessages, +}; + pub struct TogetherAILLM { id: String, api_key: Option, @@ -115,7 +113,7 @@ impl LLM for TogetherAILLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, + max_tokens: Option, presence_penalty: Option, frequency_penalty: Option, logprobs: Option, @@ -123,138 +121,30 @@ impl LLM for TogetherAILLM { _extras: Option, event_sender: Option>, ) -> Result { - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - let tool_choice = match function_call.as_ref() { - Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), - None => None, - }; - - let tools = functions - .iter() - .map(OpenAITool::try_from) - .collect::, _>>()?; - - // TogetherAI doesn't work with the new chat message content format. - // We have to modify the messages contents to use the "String" format. - let openai_messages = to_openai_messages(messages, &self.id)? - .into_iter() - .filter_map(|m| match m.content { - None => Some(m), - Some(OpenAIChatMessageContent::String(_)) => Some(m), - Some(OpenAIChatMessageContent::Structured(contents)) => { - // Find the first text content, and use it to make a string content. - let content = contents.into_iter().find_map(|c| match c { - OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text, - .. - }) => Some(OpenAIChatMessageContent::String(text)), - _ => None, - }); - - Some(OpenAIChatMessage { - role: m.role, - name: m.name, - tool_call_id: m.tool_call_id, - tool_calls: m.tool_calls, - content, - }) - } - }) - .collect::>(); - - let is_streaming = event_sender.is_some(); - - let (c, request_id) = if is_streaming { - streamed_chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - None, - None, - logprobs, - top_logprobs, - None, - event_sender.clone(), - ) - .await? - } else { - chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - None, - None, - logprobs, - top_logprobs, - None, - ) - .await? 
- }; - - assert!(c.choices.len() > 0); - - Ok(LLMChatGeneration { - created: utils::now(), - provider: ProviderID::OpenAI.to_string(), - model: self.id.clone(), - completions: c - .choices - .iter() - .map(|c| AssistantChatMessage::try_from(&c.message)) - .collect::>>()?, - usage: c.usage.map(|usage| LLMTokenUsage { - prompt_tokens: usage.prompt_tokens, - completion_tokens: usage.completion_tokens.unwrap_or(0), - }), - provider_request_id: request_id, - logprobs: logprobs_from_choices(&c.choices), - }) + openai_compatible_chat_completion( + self.chat_uri()?, + self.id.clone(), + self.api_key.clone().unwrap(), + messages, + functions, + function_call, + temperature, + top_p, + n, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + logprobs, + top_logprobs, + None, + event_sender, + false, // don't disable provider streaming + TransformSystemMessages::Keep, + "TogetherAI".to_string(), + true, // squash text contents (togetherai doesn't support structured messages) + ) + .await } }