Anthropic provider: add tool-use (function calling) support to chat completions.

Review notes on this revision of the patch:
  * `chat` now returns the ChatResponse -> ChatMessage conversion error via `?`.
    The previous `ChatMessage::try_from(c).into_iter().collect()` turned a failed
    conversion into an empty `completions` vector, hiding the error from callers.
  * Tool-use blocks are mapped with `ChatFunctionCall::try_from` passed by path,
    matching the `AnthropicTool::try_from` mapping used for outgoing tools.
  * Only the first tool call of a response is kept; streaming rejects tools.

diff --git a/core/src/providers/anthropic.rs b/core/src/providers/anthropic.rs
index 4796ed6c54ea..e90cd3ed0858 100644
--- a/core/src/providers/anthropic.rs
+++ b/core/src/providers/anthropic.rs
@@ -20,15 +20,16 @@ use std::io::prelude::*;
 use std::time::Duration;
 use tokio::sync::mpsc::UnboundedSender;
 
-use super::llm::ChatFunction;
+use super::llm::{ChatFunction, ChatFunctionCall};
 use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 #[serde(rename_all = "snake_case")]
 pub enum StopReason {
-    StopSequence,
-    MaxTokens,
     EndTurn,
+    MaxTokens,
+    StopSequence,
+    ToolUse,
 }
 
 #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
@@ -75,6 +76,86 @@ pub struct AnthropicContent {
     pub text: String,
 }
 
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+struct ToolUse {
+    id: String,
+    name: String,
+    input: Value,
+}
+
+impl TryFrom<&ToolUse> for ChatFunctionCall {
+    type Error = anyhow::Error;
+
+    fn try_from(tool_use: &ToolUse) -> Result<Self, Self::Error> {
+        let arguments = serde_json::to_string(&tool_use.input)?;
+        Ok(ChatFunctionCall {
+            name: tool_use.name.clone(),
+            arguments,
+        })
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "snake_case", tag = "type")]
+enum AnthropicResponseContent {
+    Text { text: String },
+    ToolUse(ToolUse),
+}
+
+impl TryFrom<AnthropicStreamContent> for AnthropicResponseContent {
+    type Error = anyhow::Error;
+
+    fn try_from(value: AnthropicStreamContent) -> Result<Self, Self::Error> {
+        match value.r#type.as_str() {
+            "text" => Ok(AnthropicResponseContent::Text {
+                text: value.text.to_string(),
+            }),
+            _ => Err(anyhow!("Type not supported in streaming.")),
+        }
+    }
+}
+
+impl TryFrom<&AnthropicResponseContent> for ChatMessage {
+    type Error = anyhow::Error;
+
+    fn try_from(cm: &AnthropicResponseContent) -> Result<Self, Self::Error> {
+        let role = ChatMessageRole::Assistant;
+
+        let content = match cm.get_text() {
+            Some(c) => Some(c.clone()),
+            None => None,
+        };
+
+        let function_call = match cm.get_tool_use() {
+            Some(tc) => Some(ChatFunctionCall::try_from(tc)?),
+            None => None,
+        };
+
+        Ok(ChatMessage {
+            content,
+            role,
+            name: None,
+            function_call,
+        })
+    }
+}
+
+impl AnthropicResponseContent {
+    fn get_text(&self) -> Option<&String> {
+        match self {
+            AnthropicResponseContent::Text { text } => Some(text),
+            AnthropicResponseContent::ToolUse { .. } => None,
+        }
+    }
+
+    fn get_tool_use(&self) -> Option<&ToolUse> {
+        match self {
+            AnthropicResponseContent::Text { .. } => None,
+            AnthropicResponseContent::ToolUse(tu) => Some(tu),
+        }
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
 pub struct AnthropicChatMessage {
     pub content: Vec<AnthropicContent>,
@@ -114,22 +195,118 @@ impl TryFrom<&ChatMessage> for AnthropicChatMessage {
     }
 }
 
+// Tools.
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct AnthropicTool {
+    pub name: String,
+    pub description: Option<String>,
+    pub input_schema: Option<Value>,
+}
+
+impl TryFrom<&ChatFunction> for AnthropicTool {
+    type Error = anyhow::Error;
+
+    fn try_from(f: &ChatFunction) -> Result<Self, Self::Error> {
+        Ok(AnthropicTool {
+            name: f.name.clone(),
+            description: f.description.clone(),
+            input_schema: f.parameters.clone(),
+        })
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct Usage {
     pub input_tokens: u64,
     pub output_tokens: u64,
 }
 
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct AnthropicStreamContent {
+    pub r#type: String,
+    pub text: String,
+}
+
 #[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct ChatResponse {
+pub struct StreamChatResponse {
     pub id: String,
     pub model: String,
     pub role: AnthropicChatMessageRole,
-    pub content: Vec<AnthropicContent>,
+    pub content: Vec<AnthropicStreamContent>,
     pub stop_reason: Option<StopReason>,
     pub usage: Usage,
 }
 
+#[derive(Serialize, Deserialize, Debug, Clone)]
+struct ChatResponse {
+    id: String,
+    model: String,
+    role: AnthropicChatMessageRole,
+    content: Vec<AnthropicResponseContent>,
+    stop_reason: Option<StopReason>,
+    usage: Usage,
+}
+
+impl TryFrom<StreamChatResponse> for ChatResponse {
+    type Error = anyhow::Error;
+
+    fn try_from(cr: StreamChatResponse) -> Result<Self, Self::Error> {
+        let content = cr
+            .content
+            .into_iter()
+            .map(AnthropicResponseContent::try_from)
+            .collect::<Result<Vec<_>, anyhow::Error>>()?;
+
+        Ok(ChatResponse {
+            id: cr.id,
+            model: cr.model,
+            role: cr.role,
+            content,
+            stop_reason: cr.stop_reason,
+            usage: cr.usage,
+        })
+    }
+}
+
+// This code converts a ChatResponse to a ChatMessage, but only supports one tool call.
+// It takes the first tool call from the vector of AnthropicResponseContent,
+// potentially discarding others. Anthropic often returns the CoT content as a first message,
+// which gets combined with the first tool call in the resulting ChatMessage.
+impl TryFrom<ChatResponse> for ChatMessage {
+    type Error = anyhow::Error;
+
+    fn try_from(cr: ChatResponse) -> Result<Self, Self::Error> {
+        let text_content = cr.content.iter().find_map(|item| match item.get_text() {
+            Some(text) => Some(text.clone()),
+            _ => None,
+        });
+
+        let tool_uses: Vec<&ToolUse> = cr
+            .content
+            .iter()
+            .filter_map(|item| match item.get_tool_use() {
+                Some(tool_use) => Some(tool_use),
+                _ => None,
+            })
+            .collect();
+
+        let function_call = tool_uses
+            .into_iter()
+            .map(ChatFunctionCall::try_from)
+            .collect::<Result<Vec<_>, _>>()?
+            .first()
+            .cloned();
+
+        Ok(ChatMessage {
+            role: ChatMessageRole::Assistant,
+            name: None,
+            content: text_content,
+            function_call,
+        })
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct CompletionResponse {
     pub completion: String,
@@ -155,21 +332,21 @@ pub struct Error {
 #[derive(Serialize, Deserialize, Debug, Clone)]
 struct StreamMessageStart {
     pub r#type: String,
-    pub message: ChatResponse,
+    pub message: StreamChatResponse,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 struct StreamContentBlockStart {
     pub r#type: String,
     pub index: u64,
-    pub content_block: AnthropicContent,
+    pub content_block: AnthropicStreamContent,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 struct StreamContentBlockDelta {
     pub r#type: String,
     pub index: u64,
-    pub delta: AnthropicContent,
+    pub delta: AnthropicStreamContent,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -233,6 +410,7 @@ impl AnthropicLLM {
         &self,
         system: Option<String>,
         messages: &Vec<AnthropicChatMessage>,
+        tools: Vec<AnthropicTool>,
         temperature: f32,
         top_p: f32,
         stop_sequences: &Vec<String>,
@@ -256,11 +434,22 @@ impl AnthropicLLM {
             body["system"] = json!(system);
         }
 
+        if !tools.is_empty() {
+            body["tools"] = json!(tools);
+        }
+
+        let mut headers = reqwest::header::HeaderMap::new();
+        headers.insert("Content-Type", "application/json".parse()?);
+        headers.insert("X-API-Key", self.api_key.clone().unwrap().parse()?);
+        headers.insert("anthropic-version", "2023-06-01".parse()?);
+
+        if !tools.is_empty() {
+            headers.insert("anthropic-beta", "tools-2024-04-04".parse()?);
+        }
+
         let res = reqwest::Client::new()
             .post(self.messages_uri()?.to_string())
-            .header("Content-Type", "application/json")
-            .header("X-API-Key", self.api_key.clone().unwrap())
-            .header("anthropic-version", "2023-06-01")
+            .headers(headers)
             .json(&body)
             .send()
             .await?;
@@ -272,10 +461,7 @@ impl AnthropicLLM {
         body.reader().read_to_end(&mut b)?;
         let c: &[u8] = &b;
         let response = match status {
-            reqwest::StatusCode::OK => {
-                let response: ChatResponse = serde_json::from_slice(c)?;
-                Ok(response)
-            }
+            reqwest::StatusCode::OK => Ok(serde_json::from_slice(c)?),
             _ => {
                 let error: Error = serde_json::from_slice(c)?;
                 Err(ModelError {
@@ -292,10 +478,11 @@ impl AnthropicLLM {
         Ok(response)
     }
 
-    pub async fn streamed_chat_completion(
+    async fn streamed_chat_completion(
         &self,
         system: Option<String>,
         messages: &Vec<AnthropicChatMessage>,
+        tools: Vec<AnthropicTool>,
         temperature: f32,
         top_p: f32,
         stop_sequences: &Vec<String>,
@@ -304,6 +491,13 @@ impl AnthropicLLM {
     ) -> Result<ChatResponse> {
         assert!(self.api_key.is_some());
 
+        // Streaming (stream=true) is not yet supported on tools.
+        if !tools.is_empty() {
+            return Err(anyhow!(
+                "Anthropic does not support chat functions in stream mode."
+            ));
+        }
+
         let mut body = json!({
             "model": self.id.clone(),
             "messages": messages,
@@ -362,7 +556,7 @@ impl AnthropicLLM {
 
         let mut stream = client.stream();
 
-        let mut final_response: Option<ChatResponse> = None;
+        let mut final_response: Option<StreamChatResponse> = None;
         'stream: loop {
             match stream.try_next().await {
                 Ok(stream_next) => match stream_next {
@@ -563,7 +757,11 @@ impl AnthropicLLM {
         }
 
         match final_response {
-            Some(response) => Ok(response),
+            Some(response) => {
+                let chat_response: ChatResponse = ChatResponse::try_from(response)?;
+
+                Ok(chat_response)
+            }
             None => Err(anyhow!("No response from Anthropic")),
         }
     }
@@ -934,7 +1132,7 @@ impl LLM for AnthropicLLM {
         &self,
         messages: &Vec<ChatMessage>,
         functions: &Vec<ChatFunction>,
-        function_call: Option<String>,
+        _function_call: Option<String>,
         temperature: f32,
         top_p: Option<f32>,
         n: usize,
@@ -952,9 +1150,6 @@ impl LLM for AnthropicLLM {
                 "Anthropic only supports generating one sample at a time."
             ))?;
         }
-        if functions.len() > 0 || function_call.is_some() {
-            return Err(anyhow!("Anthropic does not support chat functions."));
-        }
 
         if let Some(m) = max_tokens {
             if m == -1 {
@@ -1001,6 +1196,7 @@ impl LLM for AnthropicLLM {
             .iter()
             .map(|cm| AnthropicChatMessage {
                 content: vec![AnthropicContent {
+                    // TODO: Support `tool_result` here.
                     r#type: String::from("text"),
                     text: cm
                         .content
@@ -1015,11 +1211,17 @@ impl LLM for AnthropicLLM {
 
         // merge messages of the same role
 
+        let tools = functions
+            .iter()
+            .map(AnthropicTool::try_from)
+            .collect::<Result<Vec<_>, _>>()?;
+
         let c = match event_sender {
             Some(es) => {
                 self.streamed_chat_completion(
                     system,
                     &messages,
+                    tools,
                     temperature,
                     match top_p {
                         Some(p) => p,
@@ -1038,6 +1240,7 @@ impl LLM for AnthropicLLM {
                 self.chat_completion(
                     system,
                     &messages,
+                    tools,
                     temperature,
                     match top_p {
                         Some(p) => p,
@@ -1053,23 +1256,12 @@ impl LLM for AnthropicLLM {
             }
         };
 
-        match c.content.first() {
-            None => Err(anyhow!("No content in response from Anthropic.")),
-            Some(content) => match content.r#type.as_str() {
-                "text" => Ok(LLMChatGeneration {
-                    created: utils::now(),
-                    provider: ProviderID::Anthropic.to_string(),
-                    model: self.id.clone(),
-                    completions: vec![ChatMessage {
-                        role: ChatMessageRole::Assistant,
-                        content: Some(content.text.clone()),
-                        name: None,
-                        function_call: None,
-                    }],
-                }),
-                _ => Err(anyhow!("Anthropic returned an unexpected content type.")),
-            },
-        }
+        Ok(LLMChatGeneration {
+            created: utils::now(),
+            provider: ProviderID::Anthropic.to_string(),
+            model: self.id.clone(),
+            completions: vec![ChatMessage::try_from(c)?],
+        })
     }
 }
 