diff --git a/core/src/providers/google_ai_studio.rs b/core/src/providers/google_ai_studio.rs
index fbbaffb027c5..1519215840ff 100644
--- a/core/src/providers/google_ai_studio.rs
+++ b/core/src/providers/google_ai_studio.rs
@@ -1,309 +1,26 @@
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
-use eventsource_client as es;
-use eventsource_client::Client as ESClient;
-use futures::TryStreamExt;
-use hyper::StatusCode;
-use parking_lot::{Mutex, RwLock};
-use serde::{Deserialize, Serialize};
-use serde_json::{json, Value};
+
+use http::Uri;
+use parking_lot::RwLock;
+use serde_json::Value;
 use std::sync::Arc;
-use std::time::Duration;
+
 use tokio::sync::mpsc::UnboundedSender;

-use crate::{
-    providers::{
-        chat_messages::AssistantChatMessage,
-        llm::Tokens,
-        provider::{ModelError, ModelErrorRetryOptions},
-    },
-    run::Credentials,
-    utils,
-};
+use crate::{run::Credentials, utils};

 use super::{
-    chat_messages::{ChatMessage, ContentBlock, FunctionChatMessage, MixedContent},
+    chat_messages::ChatMessage,
     embedder::Embedder,
-    llm::{
-        ChatFunction, ChatFunctionCall, ChatMessageRole, LLMChatGeneration, LLMGeneration,
-        LLMTokenUsage, LLM,
-    },
+    llm::{ChatFunction, LLMChatGeneration, LLMGeneration, LLM},
+    openai_compatible_helpers::{openai_compatible_chat_completion, TransformSystemMessages},
     provider::{Provider, ProviderID},
     tiktoken::tiktoken::{
         batch_tokenize_async, cl100k_base_singleton, decode_async, encode_async, CoreBPE,
     },
 };

-// Disabled for now as it requires using a "tools" API which we don't support yet.
-pub const USE_FUNCTION_CALLING: bool = false;
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-#[serde(rename_all = "camelCase")]
-pub struct UsageMetadata {
-    prompt_token_count: Option<usize>,
-    candidates_token_count: Option<usize>,
-    total_token_count: Option<usize>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct GoogleAiStudioFunctionResponseContent {
-    name: String,
-    content: String,
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct GoogleAIStudioFunctionResponse {
-    name: String,
-    response: GoogleAiStudioFunctionResponseContent,
-}
-
-impl TryFrom<&FunctionChatMessage> for GoogleAIStudioFunctionResponse {
-    type Error = anyhow::Error;
-
-    fn try_from(m: &FunctionChatMessage) -> Result<Self, Self::Error> {
-        let name = m.name.clone().unwrap_or_default();
-        Ok(GoogleAIStudioFunctionResponse {
-            name: name.clone(),
-            response: GoogleAiStudioFunctionResponseContent {
-                name,
-                content: m.content.clone(),
-            },
-        })
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct GoogleAIStudioFunctionCall {
-    name: String,
-    args: Option<Value>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct GoogleAIStudioFunctionDeclaration {
-    name: String,
-    description: String,
-    parameters: Option<Value>,
-}
-
-impl TryFrom<&ChatFunction> for GoogleAIStudioFunctionDeclaration {
-    type Error = anyhow::Error;
-
-    fn try_from(f: &ChatFunction) -> Result<Self, Self::Error> {
-        Ok(GoogleAIStudioFunctionDeclaration {
-            name: f.name.clone(),
-            description: f.description.clone().unwrap_or_else(|| String::from("")),
-            parameters: match f.parameters.clone() {
-                // The API rejects empty 'properties'. If 'properties' is empty, return None.
-                // Otherwise, return the object wrapped in Some.
-                Some(serde_json::Value::Object(obj)) => {
-                    if obj.get("properties").map_or(false, |props| {
-                        props.as_object().map_or(false, |p| p.is_empty())
-                    }) {
-                        None
-                    } else {
-                        Some(serde_json::Value::Object(obj))
-                    }
-                }
-                p => p,
-            },
-        })
-    }
-}
-
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
-#[serde(rename_all = "UPPERCASE")]
-pub enum GoogleAIStudioTooConfigMode {
-    Auto,
-    Any,
-    None,
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-#[serde(rename_all = "camelCase")]
-pub struct GoogleAIStudioFunctionCallingConfig {
-    mode: GoogleAIStudioTooConfigMode,
-    allowed_function_names: Option<Vec<String>>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-#[serde(rename_all = "camelCase")]
-pub struct Part {
-    text: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    function_call: Option<GoogleAIStudioFunctionCall>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    function_response: Option<GoogleAIStudioFunctionResponse>,
-}
-
-impl TryFrom<&ChatFunctionCall> for GoogleAIStudioFunctionCall {
-    type Error = anyhow::Error;
-
-    fn try_from(f: &ChatFunctionCall) -> Result<Self, Self::Error> {
-        let args = match serde_json::from_str(f.arguments.as_str()) {
-            Ok(v) => v,
-            Err(_) => Err(anyhow!(
-                "GoogleAISudio function call arguments must be valid JSON"
-            ))?,
-        };
-        Ok(GoogleAIStudioFunctionCall {
-            name: f.name.clone(),
-            args,
-        })
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct Content {
-    role: Option<String>,
-    parts: Option<Vec<Part>>,
-}
-
-impl TryFrom<&ChatMessage> for Content {
-    type Error = anyhow::Error;
-
-    fn try_from(cm: &ChatMessage) -> Result<Self, Self::Error> {
-        match cm {
-            ChatMessage::Assistant(assistant_msg) => {
-                let parts = match assistant_msg.function_calls {
-                    Some(ref fcs) => fcs
-                        .iter()
-                        .map(|fc| {
-                            Ok(Part {
-                                text: assistant_msg.content.clone(),
-                                function_call: Some(GoogleAIStudioFunctionCall::try_from(fc)?),
-                                function_response: None,
-                            })
-                        })
-                        .collect::<Result<Vec<Part>, anyhow::Error>>()?,
-                    None => {
-                        if let Some(ref fc) = assistant_msg.function_call {
-                            vec![Part {
-                                text: assistant_msg.content.clone(),
-                                function_call: Some(GoogleAIStudioFunctionCall::try_from(fc)?),
-                                function_response: None,
-                            }]
-                        } else {
-                            vec![Part {
-                                text: assistant_msg.content.clone(),
-                                function_call: None,
-                                function_response: None,
-                            }]
-                        }
-                    }
-                };
-
-                Ok(Content {
-                    role: Some(String::from("model")),
-                    parts: Some(parts),
-                })
-            }
-            ChatMessage::Function(function_msg) => Ok(Content {
-                role: Some(String::from("function")),
-                parts: Some(vec![Part {
-                    text: None,
-                    function_call: None,
-                    function_response: GoogleAIStudioFunctionResponse::try_from(function_msg).ok(),
-                }]),
-            }),
-            ChatMessage::User(user_msg) => {
-                let text = match &user_msg.content {
-                    ContentBlock::Mixed(m) => {
-                        let result = m.iter().enumerate().try_fold(
-                            String::new(),
-                            |mut acc, (i, content)| {
-                                match content {
-                                    MixedContent::ImageContent(_) => Err(anyhow!(
-                                        "Vision is not supported for Google AI Studio."
-                                    )),
-                                    MixedContent::TextContent(tc) => {
-                                        acc.push_str(&tc.text.trim());
-                                        if i != m.len() - 1 {
-                                            // Add newline if it's not the last item.
-                                            acc.push('\n');
-                                        }
-                                        Ok(acc)
-                                    }
-                                }
-                            },
-                        );
-
-                        match result {
-                            Ok(text) if !text.is_empty() => Ok(text),
-                            Ok(_) => Err(anyhow!("Text is required.")), // Empty string.
-                            Err(e) => Err(e),
-                        }
-                    }
-                    ContentBlock::Text(t) => Ok(t.clone()),
-                }?;
-
-                Ok(Content {
-                    role: Some(String::from("user")),
-                    parts: Some(vec![Part {
-                        text: Some(text),
-                        function_call: None,
-                        function_response: None,
-                    }]),
-                })
-            }
-            ChatMessage::System(system_msg) => Ok(Content {
-                role: Some(String::from("user")),
-                parts: Some(vec![Part {
-                    // System is passed as a Content. We transform it here but it will be removed
-                    // from the list of messages and passed as separate argument to the API.
-                    text: Some(system_msg.content.clone()),
-                    function_call: None,
-                    function_response: None,
-                }]),
-            }),
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-#[serde(rename_all = "camelCase")]
-pub struct Candidate {
-    content: Option<Content>,
-    finish_reason: Option<String>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-#[serde(rename_all = "camelCase")]
-pub struct Completion {
-    candidates: Option<Vec<Candidate>>,
-    usage_metadata: Option<UsageMetadata>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct InnerError {
-    pub message: String,
-    pub code: Option<u16>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct GoogleAIStudioError {
-    pub error: InnerError,
-}
-
-impl GoogleAIStudioError {
-    pub fn message(&self) -> String {
-        format!("GoogleAIStudio: {}", self.error.message)
-    }
-
-    pub fn retryable(&self) -> bool {
-        return false;
-    }
-
-    pub fn retryable_streamed(&self, status: StatusCode) -> bool {
-        if status == StatusCode::TOO_MANY_REQUESTS {
-            return true;
-        }
-        if status.is_server_error() {
-            return true;
-        }
-        return false;
-    }
-}
-
 pub struct GoogleAiStudioProvider {}

 impl GoogleAiStudioProvider {
@@ -349,11 +66,8 @@ impl GoogleAiStudioLLM {
         Self { id, api_key: None }
     }

-    fn model_endpoint(&self) -> String {
-        format!(
-            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse",
-            self.id
-        )
+    fn model_endpoint(&self) -> Uri {
+        Uri::from_static("https://generativelanguage.googleapis.com/v1beta/openai/chat/completions")
     }

     fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
@@ -397,104 +111,19 @@ impl LLM for GoogleAiStudioLLM {

     async fn generate(
         &self,
-        prompt: &str,
-        mut max_tokens: Option<i32>,
-        temperature: f32,
-        n: usize,
-        stop: &Vec<String>,
-        presence_penalty: Option<f32>,
-        frequency_penalty: Option<f32>,
-        top_p: Option<f32>,
-        top_logprobs: Option<i32>,
+        _prompt: &str,
+        _max_tokens: Option<i32>,
+        _temperature: f32,
+        _n: usize,
+        _stop: &Vec<String>,
+        _presence_penalty: Option<f32>,
+        _frequency_penalty: Option<f32>,
+        _top_p: Option<f32>,
+        _top_logprobs: Option<i32>,
         _extras: Option<Value>,
-        event_sender: Option<UnboundedSender<Value>>,
+        _event_sender: Option<UnboundedSender<Value>>,
     ) -> Result<LLMGeneration> {
-        assert!(n == 1);
-
-        let api_key = match &self.api_key {
-            Some(k) => k.to_string(),
-            None => Err(anyhow!("API key not found"))?,
-        };
-
-        if frequency_penalty.is_some() {
-            Err(anyhow!("Frequency penalty not supported by GoogleAIStudio"))?;
-        }
-        if presence_penalty.is_some() {
-            Err(anyhow!("Presence penalty not supported by GoogleAIStudio"))?;
-        }
-        if top_logprobs.is_some() {
-            Err(anyhow!("Top logprobs not supported by GoogleAIStudio"))?;
-        }
-
-        if let Some(m) = max_tokens {
-            if m == -1 {
-                let tokens = self.encode(prompt).await?;
-                max_tokens = Some((self.context_size() - tokens.len()) as i32);
-            }
-        }
-
-        let uri = self.model_endpoint();
-
-        let c = streamed_chat_completion(
-            uri,
-            api_key,
-            &vec![Content {
-                role: Some(String::from("user")),
-                parts: Some(vec![Part {
-                    text: Some(String::from(prompt)),
-                    function_call: None,
-                    function_response: None,
-                }]),
-            }],
-            vec![],
-            None,
-            None,
-            temperature,
-            stop,
-            max_tokens,
-            match top_p {
-                Some(t) => t,
-                None => 1.0,
-            },
-            None,
-            event_sender,
-            false,
-        )
-        .await?;
-
-        Ok(LLMGeneration {
-            created: utils::now(),
-            provider: ProviderID::GoogleAiStudio.to_string(),
-            model: self.id().clone(),
-            completions: vec![Tokens {
-                // Get candidates?.[0]?.content?.parts?.[0]?.text ?? "".
-                text: c
-                    .candidates
-                    .as_ref()
-                    .and_then(|c| c.first())
-                    .and_then(|c| c.content.as_ref())
-                    .and_then(|c| c.parts.as_ref())
-                    .and_then(|p| p.first())
-                    .and_then(|p| p.text.as_ref())
-                    .map(|t| t.to_string())
-                    .unwrap_or_else(|| String::from("")),
-
-                tokens: Some(vec![]),
-                logprobs: Some(vec![]),
-                top_logprobs: None,
-            }],
-            prompt: Tokens {
-                text: prompt.to_string(),
-                tokens: Some(vec![]),
-                logprobs: Some(vec![]),
-                top_logprobs: None,
-            },
-            usage: c.usage_metadata.map(|c| LLMTokenUsage {
-                prompt_tokens: c.prompt_token_count.unwrap_or(0) as u64,
-                completion_tokens: c.candidates_token_count.unwrap_or(0) as u64,
-            }),
-            provider_request_id: None,
-        })
+        unimplemented!()
     }

     async fn chat(
         &self,
@@ -506,489 +135,37 @@
         top_p: Option<f32>,
         n: usize,
         stop: &Vec<String>,
-        mut max_tokens: Option<i32>,
-        presence_penalty: Option<f32>,
-        frequency_penalty: Option<f32>,
-        _logprobs: Option<bool>,
-        _top_logprobs: Option<i32>,
+        max_tokens: Option<i32>,
+        _presence_penalty: Option<f32>,
+        _frequency_penalty: Option<f32>,
+        logprobs: Option<bool>,
+        top_logprobs: Option<i32>,
         _extras: Option<Value>,
         event_sender: Option<UnboundedSender<Value>>,
     ) -> Result<LLMChatGeneration> {
-        assert!(n == 1);
-
-        let api_key = match &self.api_key {
-            Some(k) => k.to_string(),
-            None => Err(anyhow!("API key not found"))?,
-        };
-
-        if frequency_penalty.is_some() {
-            Err(anyhow!("Frequency penalty not supported by GoogleAIStudio"))?;
-        }
-        if presence_penalty.is_some() {
-            Err(anyhow!("Presence penalty not supported by GoogleAIStudio"))?;
-        }
-
-        if let Some(m) = max_tokens {
-            if m == -1 {
-                max_tokens = None;
-            }
-        }
-
-        if frequency_penalty.is_some() {
-            Err(anyhow!("Frequency penalty not supported by GoogleAIStudio"))?;
-        }
-        if presence_penalty.is_some() {
-            Err(anyhow!("Presence penalty not supported by GoogleAIStudio"))?;
-        }
-
-        let uri = self.model_endpoint();
-
-        // Remove system message if first.
-        let system = match messages.get(0) {
-            Some(cm) => match cm {
-                ChatMessage::System(_) => Some(Content::try_from(cm)?),
-                _ => None,
-            },
-            None => None,
-        };
-
-        let messages = messages
-            .iter()
-            .skip(match system.as_ref() {
-                Some(_) => 1,
-                None => 0,
-            })
-            .map(|cm| Content::try_from(cm))
-            .collect::<Result<Vec<Content>>>()?;
-
-        // TODO: backward comp for non alternated messages
-
-        let tools = functions
-            .iter()
-            .map(GoogleAIStudioFunctionDeclaration::try_from)
-            .collect::<Result<Vec<GoogleAIStudioFunctionDeclaration>, _>>()?;
-
-        let tool_config = match function_call {
-            Some(fc) => Some(match fc.as_str() {
-                "auto" => GoogleAIStudioFunctionCallingConfig {
-                    mode: GoogleAIStudioTooConfigMode::Auto,
-                    allowed_function_names: None,
-                },
-                "none" => GoogleAIStudioFunctionCallingConfig {
-                    mode: GoogleAIStudioTooConfigMode::None,
-                    allowed_function_names: None,
-                },
-                "any" => GoogleAIStudioFunctionCallingConfig {
-                    mode: GoogleAIStudioTooConfigMode::Any,
-                    allowed_function_names: None,
-                },
-                _ => GoogleAIStudioFunctionCallingConfig {
-                    mode: GoogleAIStudioTooConfigMode::Any,
-                    allowed_function_names: Some(vec![fc.clone()]),
-                },
-            }),
-            None => None,
-        };
-
-        let c = streamed_chat_completion(
-            uri,
-            api_key,
-            &messages,
-            tools,
-            tool_config,
-            system,
+        openai_compatible_chat_completion(
+            self.model_endpoint(),
+            self.id.clone(),
+            self.api_key.clone().unwrap(),
+            messages,
+            functions,
+            function_call,
             temperature,
+            top_p,
+            n,
             stop,
             max_tokens,
-            match top_p {
-                Some(t) => t,
-                None => 1.0,
-            },
+            None,
+            None,
+            logprobs,
+            top_logprobs,
             None,
             event_sender,
-            false,
-        )
-        .await?;
-
-        let mut content: Option<String> = None;
-        let mut function_calls: Vec<ChatFunctionCall> = vec![];
-
-        // Get candidates?.[0]?.content?.parts.
-        if let Some(parts) = c
-            .candidates
-            .as_ref()
-            .and_then(|c| c.first())
-            .and_then(|c| c.content.as_ref())
-            .and_then(|c| c.parts.as_ref())
-        {
-            for p in parts.iter() {
-                // If the part has text, either append it to the content if we already have some
-                // or set it as the content.
-                if let Some(t) = p.text.as_ref() {
-                    content = content.map(|c| c + t).or_else(|| Some(t.clone()));
-                }
-
-                // If the part has a function call, add it to the list of function calls.
-                if let Some(fc) = p.function_call.as_ref() {
-                    function_calls.push(ChatFunctionCall {
-                        id: format!("fc_{}", utils::new_id()[0..9].to_string()),
-                        name: fc.name.clone(),
-                        arguments: match fc.args {
-                            Some(ref args) => serde_json::to_string(args)?,
-                            None => String::from("{}"),
-                        },
-                    });
-                }
-            }
-        }
-
-        Ok(LLMChatGeneration {
-            created: utils::now(),
-            provider: ProviderID::GoogleAiStudio.to_string(),
-            model: self.id().clone(),
-            completions: vec![AssistantChatMessage {
-                name: None,
-                function_call: match function_calls.first() {
-                    Some(fc) => Some(fc.clone()),
-                    None => None,
-                },
-                function_calls: match function_calls.len() {
-                    0 => None,
-                    _ => Some(function_calls),
-                },
-                role: ChatMessageRole::Assistant,
-                content,
-            }],
-            usage: c.usage_metadata.map(|c| LLMTokenUsage {
-                prompt_tokens: c.prompt_token_count.unwrap_or(0) as u64,
-                completion_tokens: c.candidates_token_count.unwrap_or(0) as u64,
-            }),
-            provider_request_id: None,
-            logprobs: None,
-        })
-    }
-}
-
-pub async fn streamed_chat_completion(
-    uri: String,
-    api_key: String,
-    messages: &Vec<Content>,
-    tools: Vec<GoogleAIStudioFunctionDeclaration>,
-    tool_config: Option<GoogleAIStudioFunctionCallingConfig>,
-    system_instruction: Option<Content>,
-    temperature: f32,
-    stop: &Vec<String>,
-    max_tokens: Option<i32>,
-    top_p: f32,
-    top_k: Option<usize>,
-    event_sender: Option<UnboundedSender<Value>>,
-    use_header_auth: bool,
-) -> Result<Completion> {
-    let url = match use_header_auth {
-        true => uri.to_string(),
-        false => format!("{}&key={}", uri, api_key),
-    };
-
-    let mut builder = match es::ClientBuilder::for_url(url.as_str()) {
-        Ok(builder) => builder,
-        Err(e) => {
-            return Err(anyhow!(
-                "Error creating GoogleAIStudio streaming client: {:?}",
-                e
-            ))
-        }
-    };
-
-    if use_header_auth {
-        builder = match builder.method(String::from("POST")).header(
-            "Authorization",
-            format!("Bearer {}", api_key.clone()).as_str(),
-        ) {
-            Ok(b) => b,
-            Err(_) => return Err(anyhow!("Error creating streamed client to GoogleAIStudio")),
-        };
-    }
-
-    builder = match builder.header("Content-Type", "application/json") {
-        Ok(b) => b,
-        Err(_) => return Err(anyhow!("Error creating streamed client to GoogleAIStudio")),
-    };
-
-    let mut body = json!({
-        "contents": json!(messages),
-        "generation_config": {
-            "temperature": temperature,
-            "topP": top_p,
-            "topK": top_k,
-            "maxOutputTokens": max_tokens,
-            "stopSequences": match stop.len() {
-                0 => None,
-                _ => Some(stop),
-            },
-        },
-        "safety_settings": [
-            { "category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH" },
-            { "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH" },
-            { "category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH" },
-            { "category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH" }
-        ]
-    });
-
-    if tools.len() > 0 {
-        body["tools"] = json!(vec![json!({
-            "functionDeclarations": tools
-        })]);
-    }
-
-    if tool_config.is_some() {
-        body["toolConfig"] = json!({
-            "functionCallingConfig": tool_config
-        });
-    }
-
-    if system_instruction.is_some() {
-        body["systemInstruction"] = json!(system_instruction);
-    }
-
-    let client = builder
-        .body(body.to_string())
-        .method("POST".to_string())
-        .reconnect(
-            es::ReconnectOptions::reconnect(true)
-                .retry_initial(false)
-                .delay(Duration::from_secs(1))
-                .backoff_factor(2)
-                .delay_max(Duration::from_secs(8))
-                .build(),
+            false, // don't disable provider streaming
+            TransformSystemMessages::Keep,
+            "GoogleAIStudio".to_string(),
+            false, // don't squash text contents
         )
-        .build();
-
-    let mut stream = client.stream();
-
-    let completions: Arc<Mutex<Vec<Completion>>> = Arc::new(Mutex::new(Vec::new()));
-
-    'stream: loop {
-        match stream.try_next().await {
-            Ok(e) => match e {
-                Some(es::SSE::Connected(_)) => {
-                    // GoogleAISudio does not return a request id in headers.
-                    // Nothing to do.
-                }
-                Some(es::SSE::Comment(_)) => {
-                    println!("UNEXPECTED COMMENT");
-                }
-                Some(es::SSE::Event(e)) => {
-                    let completion: Completion = serde_json::from_str(e.data.as_str())?;
-                    let completion_candidates = completion.candidates.clone().unwrap_or_default();
-
-                    match completion_candidates.len() {
-                        0 => {
-                            break 'stream;
-                        }
-                        1 => (),
-                        n => {
-                            Err(anyhow!("Unexpected number of candidates: {}", n))?;
-                        }
-                    };
-
-                    if let (Some(parts), Some(sender)) = (
-                        completion_candidates[0]
-                            .content
-                            .as_ref()
-                            .and_then(|c| c.parts.as_ref()),
-                        event_sender.as_ref(),
-                    ) {
-                        parts.iter().for_each(|p| {
-                            match p.text {
-                                Some(ref t) => {
-                                    if t.len() > 0 {
-                                        let _ = sender.send(json!({
-                                            "type": "tokens",
-                                            "content": {
-                                                "text": t,
-                                            }
-                                        }));
-                                    }
-                                }
-                                None => (),
-                            }
-
-                            match p.function_call {
-                                Some(ref f) => {
-                                    let _ = sender.send(json!({
-                                        "type": "function_call",
-                                        "content": {
-                                            "name": f.name,
-                                        }
-                                    }));
-                                }
-                                None => (),
-                            }
-                        });
-                    }
-
-                    completions.lock().push(completion);
-                }
-                None => {
-                    break 'stream;
-                }
-            },
-            Err(e) => {
-                match e {
-                    // Nothing to do, go direclty to break stream.
-                    es::Error::Eof => (),
-                    es::Error::UnexpectedResponse(r) => {
-                        let status = StatusCode::from_u16(r.status())?;
-                        // GogoleAIStudio currently has no request id in headers.
-                        // let headers = r.headers()?;
-                        // let request_id = match headers.get("request-id") {
-                        //     Some(v) => Some(v.to_string()),
-                        //     None => None,
-                        // };
-                        let b = r.body_bytes().await?;
-
-                        let error: Result<GoogleAIStudioError, _> = serde_json::from_slice(&b);
-                        match error {
-                            Ok(error) => {
-                                match error.retryable_streamed(status) {
-                                    true => Err(ModelError {
-                                        request_id: None,
-                                        message: error.message(),
-                                        retryable: Some(ModelErrorRetryOptions {
-                                            sleep: Duration::from_millis(500),
-                                            factor: 2,
-                                            retries: 3,
-                                        }),
-                                    }),
-                                    false => Err(ModelError {
-                                        request_id: None,
-                                        message: error.message(),
-                                        retryable: None,
-                                    }),
-                                }
-                            }?,
-                            Err(_) => Err(anyhow!(
-                                "Error streaming tokens from GoogleAIStudio: status={} data={}",
-                                status,
-                                String::from_utf8_lossy(&b)
-                            ))?,
-                        }
-                    }
-                    _ => {
-                        Err(anyhow!(
-                            "Error streaming tokens from GoogleAIStudio: {:?}",
-                            e
-                        ))?;
-                        break 'stream;
-                    }
-                }
-                break 'stream;
-            }
-        }
+        .await
     }
-
-    let completions_lock = completions.lock();
-
-    // Sometimes (usually when last message is Assistant), the AI decides not to respond.
-    if completions_lock.len() == 0 {
-        return Ok(Completion {
-            candidates: None,
-            usage_metadata: None,
-        });
-    }
-
-    let mut usage_metadata: Option<UsageMetadata> = None;
-
-    let mut full_candidate = Candidate {
-        content: Some(Content {
-            role: Some(String::from("MODEL")),
-            parts: Some(vec![]),
-        }),
-        finish_reason: None,
-    };
-
-    let mut text_parts: Option<Part> = None;
-    let mut function_call_parts: Vec<Part> = vec![];
-
-    for c in completions_lock.iter() {
-        match &c.usage_metadata {
-            None => (),
-            Some(usage) => {
-                usage_metadata = Some(usage.clone());
-            }
-        }
-
-        // Check that we don't have more than one candidate.
-        match c
-            .candidates
-            .as_ref()
-            .map(|candidates| candidates.len())
-            .unwrap_or_default()
-        {
-            0 => (),
-            1 => (),
-            n => Err(anyhow!("Unexpected number of candidates >1: {}", n))?,
-        }
-
-        if let Some(candidate) = c.candidates.as_ref().map(|c| c.first()).flatten() {
-            // Validate that the role (if any) is MODEL.
-
-            if let Some(c) = candidate.content.as_ref() {
-                match &c.role {
-                    Some(r) => match r.to_uppercase().as_str() {
-                        "MODEL" => (),
-                        r => Err(anyhow!("Unexpected role in completion: {}", r))?,
-                    },
-                    None => (),
-                }
-            }
-
-            if let Some(r) = candidate.finish_reason.as_ref() {
-                full_candidate.finish_reason = Some(r.clone());
-            }
-
-            if let Some(parts) = candidate.content.as_ref().and_then(|c| c.parts.as_ref()) {
-                for p in parts.iter() {
-                    if let Some(t) = p.text.as_ref() {
-                        match text_parts.as_mut() {
-                            Some(tp) => {
-                                tp.text = Some(tp.text.clone().unwrap_or_default() + t.as_str());
-                            }
-                            None => {
-                                text_parts = Some(p.clone());
-                            }
-                        }
-                    }
-
-                    if p.function_call.is_some() {
-                        function_call_parts.push(p.clone());
-                    }
-
-                    if p.function_response.is_some() {
-                        Err(anyhow!("Unexpected function response part in completion"))?;
-                    }
-                }
-            }
-        }
-    }
-
-    match full_candidate
-        .content
-        .as_mut()
-        .and_then(|c| c.parts.as_mut())
-    {
-        Some(parts) => {
-            if let Some(tp) = text_parts {
-                parts.push(tp);
-            }
-            parts.extend(function_call_parts);
-        }
-        // This should never happen since we define the `full_candidate` above.
-        None => unreachable!(),
-    }
-
-    Ok(Completion {
-        candidates: Some(vec![full_candidate]),
-        usage_metadata,
-    })
 }
diff --git a/core/src/providers/openai_compatible_helpers.rs b/core/src/providers/openai_compatible_helpers.rs
index 04c4f0ccdfa5..f80ff001394d 100644
--- a/core/src/providers/openai_compatible_helpers.rs
+++ b/core/src/providers/openai_compatible_helpers.rs
@@ -139,13 +139,14 @@ impl TryFrom<&OpenAIToolCall> for ChatFunctionCall {
     type Error = anyhow::Error;

     fn try_from(tc: &OpenAIToolCall) -> Result<Self, Self::Error> {
+        // Some providers don't provide a function call ID (e.g. google_ai_studio).
         let id = tc
             .id
-            .as_ref()
-            .ok_or_else(|| anyhow!("Missing tool call id."))?;
+            .clone()
+            .unwrap_or(format!("fc_{}", utils::new_id()[0..9].to_string()));

         Ok(ChatFunctionCall {
-            id: id.clone(),
+            id,
             name: tc.function.name.clone(),
             arguments: tc.function.arguments.clone(),
         })
@@ -248,6 +249,7 @@ pub enum OpenAIChatMessageContent {

 #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
 pub struct OpenAIChatMessage {
     pub role: OpenAIChatMessageRole,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub content: Option<OpenAIChatMessageContent>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub name: Option<String>,
@@ -310,7 +312,7 @@ pub struct OpenAIChatChoice {

 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct OpenAIChatCompletion {
-    pub id: String,
+    pub id: Option<String>,
     pub object: String,
     pub created: u64,
     pub choices: Vec<OpenAIChatChoice>,
@@ -489,7 +491,7 @@ pub struct ChatDelta {

 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct ChatChunk {
-    pub id: String,
+    pub id: Option<String>,
     pub object: String,
     pub created: u64,
     pub model: String,
@@ -665,14 +667,8 @@ pub async fn openai_compatible_chat_completion(
             n,
             stop,
             max_tokens,
-            match presence_penalty {
-                Some(p) => p,
-                None => 0.0,
-            },
-            match frequency_penalty {
-                Some(f) => f,
-                None => 0.0,
-            },
+            presence_penalty,
+            frequency_penalty,
             response_format,
             reasoning_effort,
             logprobs,
@@ -699,14 +695,8 @@ pub async fn openai_compatible_chat_completion(
             n,
             stop,
             max_tokens,
-            match presence_penalty {
-                Some(p) => p,
-                None => 0.0,
-            },
-            match frequency_penalty {
-                Some(f) => f,
-                None => 0.0,
-            },
+            presence_penalty,
+            frequency_penalty,
             response_format,
             reasoning_effort,
             logprobs,
@@ -876,8 +866,8 @@ async fn streamed_chat_completion(
     n: usize,
     stop: &Vec<String>,
     max_tokens: Option<i32>,
-    presence_penalty: f32,
-    frequency_penalty: f32,
+    presence_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     response_format: Option<String>,
     reasoning_effort: Option<String>,
     logprobs: Option<bool>,
@@ -944,11 +934,15 @@ async fn streamed_chat_completion(
         "temperature": temperature,
         "top_p": top_p,
         "n": n,
-        "presence_penalty": presence_penalty,
-        "frequency_penalty": frequency_penalty,
         "stream": true,
         "stream_options": HashMap::from([("include_usage", true)]),
     });
+    if let Some(presence_penalty) = presence_penalty {
+        body["presence_penalty"] = json!(presence_penalty);
+    }
+    if let Some(frequency_penalty) = frequency_penalty {
+        body["frequency_penalty"] = json!(frequency_penalty);
+    }
     if user.is_some() {
         body["user"] = json!(user);
     }
@@ -1289,18 +1283,22 @@ async fn streamed_chat_completion(
                                 tool_call.get("id").and_then(|v| v.as_str()),
                                 tool_call.get("function"),
                             ) {
-                                (Some("function"), Some(id), Some(f)) => {
+                                (Some("function"), id, Some(f)) => {
                                     if let Some(Value::String(name)) = f.get("name") {
                                         c.choices[j]
                                             .message
                                             .tool_calls
                                             .get_or_insert_with(Vec::new)
                                             .push(OpenAIToolCall {
-                                                id: Some(id.to_string()),
+                                                // Set the id to None if the provider returned an empty string.
+                                                id: id.filter(|s| !s.is_empty()).map(|s| s.to_string()),
                                                 r#type: OpenAIToolType::Function,
                                                 function: OpenAIToolCallFunction {
                                                     name: name.clone(),
-                                                    arguments: String::new(),
+                                                    arguments: match f.get("arguments") {
+                                                        Some(Value::String(a)) => a.clone(),
+                                                        _ => String::new(),
+                                                    },
                                                 },
                                             });
                                     }
@@ -1366,8 +1364,8 @@ async fn chat_completion(
     n: usize,
     stop: &Vec<String>,
     max_tokens: Option<i32>,
-    presence_penalty: f32,
-    frequency_penalty: f32,
+    presence_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     response_format: Option<String>,
     reasoning_effort: Option<String>,
     logprobs: Option<bool>,
@@ -1380,9 +1378,13 @@ async fn chat_completion(
         "temperature": temperature,
         "top_p": top_p,
         "n": n,
-        "presence_penalty": presence_penalty,
-        "frequency_penalty": frequency_penalty,
     });
+    if let Some(presence_penalty) = presence_penalty {
+        body["presence_penalty"] = json!(presence_penalty);
+    }
+    if let Some(frequency_penalty) = frequency_penalty {
+        body["frequency_penalty"] = json!(frequency_penalty);
+    }
     if user.is_some() {
         body["user"] = json!(user);
     }
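
Reviewer note on the helpers change: the diff bundles two behavioral fixes. First, presence_penalty and frequency_penalty are now omitted from the request body when unset, rather than serialized as 0.0. Second, tool-call IDs are synthesized when a provider, such as Google AI Studio's OpenAI-compatible endpoint, returns none or an empty string. Below is a minimal standalone sketch of both patterns, assuming nothing from the codebase; the model id and fresh_suffix values are illustrative only.

use serde_json::{json, Value};

// "Omit when None": optional penalties are only inserted into the body when
// the caller provided them, so the serialized request carries no
// `presence_penalty`/`frequency_penalty` keys at all otherwise.
fn build_body(presence_penalty: Option<f32>, frequency_penalty: Option<f32>) -> Value {
    let mut body = json!({
        "model": "gemini-1.5-flash", // illustrative model id
        "stream": true,
    });
    if let Some(p) = presence_penalty {
        body["presence_penalty"] = json!(p);
    }
    if let Some(f) = frequency_penalty {
        body["frequency_penalty"] = json!(f);
    }
    body
}

// ID defaulting: treat a missing or empty tool-call id as absent, then
// synthesize a local one so downstream code can rely on ids being present.
fn tool_call_id(provider_id: Option<&str>, fresh_suffix: &str) -> String {
    provider_id
        .filter(|s| !s.is_empty())
        .map(|s| s.to_string())
        .unwrap_or_else(|| format!("fc_{}", fresh_suffix))
}

fn main() {
    let body = build_body(None, Some(0.5));
    assert!(body.get("presence_penalty").is_none());
    assert_eq!(body["frequency_penalty"], json!(0.5));
    assert_eq!(tool_call_id(Some(""), "a1b2c3d4e"), "fc_a1b2c3d4e");
    assert_eq!(tool_call_id(Some("call_123"), "a1b2c3d4e"), "call_123");
}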