diff --git a/connectors/src/api/admin.ts b/connectors/src/api/admin.ts
index 54d0f65ab8d3..b0e3895efa09 100644
--- a/connectors/src/api/admin.ts
+++ b/connectors/src/api/admin.ts
@@ -21,6 +21,10 @@ const whitelistedCommands = [
     command: "find-url",
   },
   { majorCommand: "slack", command: "whitelist-bot" },
+  {
+    majorCommand: "connectors",
+    command: "set-error",
+  },
 ];
 
 const _adminAPIHandler = async (
diff --git a/connectors/src/connectors/confluence/lib/permissions.ts b/connectors/src/connectors/confluence/lib/permissions.ts
index 48ac475a6873..73c7e698ea2e 100644
--- a/connectors/src/connectors/confluence/lib/permissions.ts
+++ b/connectors/src/connectors/confluence/lib/permissions.ts
@@ -34,15 +34,22 @@ function isConfluenceSpaceModel(
   );
 }
 
+export function getConfluenceSpaceUrl(
+  space: ConfluenceSpace | ConfluenceSpaceType,
+  baseUrl: string
+) {
+  const urlSuffix = isConfluenceSpaceModel(space)
+    ? space.urlSuffix
+    : space._links.webui;
+  return `${baseUrl}/wiki${urlSuffix}`;
+}
+
 export function createContentNodeFromSpace(
   space: ConfluenceSpace | ConfluenceSpaceType,
   baseUrl: string,
   permission: ConnectorPermission,
   { isExpandable }: { isExpandable: boolean }
 ): ContentNode {
-  const urlSuffix = isConfluenceSpaceModel(space)
-    ? space.urlSuffix
-    : space._links.webui;
   const spaceId = isConfluenceSpaceModel(space) ? space.spaceId : space.id;
 
   return {
@@ -50,7 +57,7 @@ export function createContentNodeFromSpace(
     parentInternalId: null,
     type: "folder",
     title: space.name || "Unnamed Space",
-    sourceUrl: `${baseUrl}/wiki${urlSuffix}`,
+    sourceUrl: getConfluenceSpaceUrl(space, baseUrl),
     expandable: isExpandable,
     permission,
     lastUpdatedAt: null,
diff --git a/connectors/src/connectors/confluence/temporal/activities.ts b/connectors/src/connectors/confluence/temporal/activities.ts
index 56dffa86281d..a30625f21bf2 100644
--- a/connectors/src/connectors/confluence/temporal/activities.ts
+++ b/connectors/src/connectors/confluence/temporal/activities.ts
@@ -23,6 +23,7 @@ import {
   makePageInternalId,
   makeSpaceInternalId,
 } from "@connectors/connectors/confluence/lib/internal_ids";
+import { getConfluenceSpaceUrl } from "@connectors/connectors/confluence/lib/permissions";
 import { makeConfluenceDocumentUrl } from "@connectors/connectors/confluence/temporal/workflow_ids";
 import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config";
 import { concurrentExecutor } from "@connectors/lib/async_utils";
@@ -209,13 +210,20 @@ export async function confluenceUpsertSpaceFolderActivity({
   connectorId,
   spaceId,
   spaceName,
+  baseUrl,
 }: {
   connectorId: ModelId;
   spaceId: string;
   spaceName: string;
+  baseUrl: string;
 }) {
   const connector = await fetchConfluenceConnector(connectorId);
 
+  const spaceInDb = await ConfluenceSpace.findOne({
+    attributes: ["urlSuffix"],
+    where: { connectorId, spaceId },
+  });
+
   await upsertDataSourceFolder({
     dataSourceConfig: dataSourceConfigFromConnector(connector),
     folderId: makeSpaceInternalId(spaceId),
@@ -223,6 +231,8 @@ export async function confluenceUpsertSpaceFolderActivity({
     parentId: null,
     title: spaceName,
     mimeType: MIME_TYPES.CONFLUENCE.SPACE,
+    sourceUrl:
+      spaceInDb?.urlSuffix && getConfluenceSpaceUrl(spaceInDb, baseUrl),
   });
 }
diff --git a/connectors/src/connectors/confluence/temporal/workflows.ts b/connectors/src/connectors/confluence/temporal/workflows.ts
index 3074cd740915..9c9cd4b13110 100644
--- a/connectors/src/connectors/confluence/temporal/workflows.ts
+++ b/connectors/src/connectors/confluence/temporal/workflows.ts
@@ -148,7 +148,7 @@
export async function confluenceSpaceSyncWorkflow( const confluenceConfig = await fetchConfluenceConfigurationActivity( params.connectorId ); - const { cloudId: confluenceCloudId } = confluenceConfig; + const { cloudId: confluenceCloudId, url: baseUrl } = confluenceConfig; const spaceName = await confluenceGetSpaceNameActivity({ ...params, @@ -163,6 +163,7 @@ export async function confluenceSpaceSyncWorkflow( connectorId, spaceId, spaceName, + baseUrl, }); const allowedRootPageIds = await fetchAndUpsertRootPagesActivity({ diff --git a/connectors/src/connectors/google_drive/index.ts b/connectors/src/connectors/google_drive/index.ts index 4987259ccb79..567b031cadc8 100644 --- a/connectors/src/connectors/google_drive/index.ts +++ b/connectors/src/connectors/google_drive/index.ts @@ -38,6 +38,7 @@ import { isGoogleDriveFolder, isGoogleDriveSpreadSheetFile, } from "@connectors/connectors/google_drive/temporal/mime_types"; +import type { Sheet } from "@connectors/connectors/google_drive/temporal/spreadsheets"; import { driveObjectToDustType, getAuthObject, @@ -67,6 +68,7 @@ import { terminateAllWorkflowsForConnectorId } from "@connectors/lib/temporal"; import logger from "@connectors/logger/logger"; import { ConnectorResource } from "@connectors/resources/connector_resource"; import type { DataSourceConfig } from "@connectors/types/data_source_config.js"; +import type { GoogleDriveObjectType } from "@connectors/types/google_drive"; import { FILE_ATTRIBUTES_TO_FETCH } from "@connectors/types/google_drive"; export class GoogleDriveConnectorManager extends BaseConnectorManager { @@ -706,7 +708,7 @@ export class GoogleDriveConnectorManager extends BaseConnectorManager { type: "database", title: s.name || "", lastUpdatedAt: s.updatedAt.getTime() || null, - sourceUrl: `https://docs.google.com/spreadsheets/d/${s.driveFileId}/edit#gid=${s.driveSheetId}`, + sourceUrl: getSourceUrlForGoogleDriveSheet(s), expandable: false, permission: "read", })); @@ -976,12 +978,25 @@ async function getFoldersAsContentNodes({ ); } -function getSourceUrlForGoogleDriveFiles(f: GoogleDriveFiles): string { +export function getSourceUrlForGoogleDriveFiles( + f: GoogleDriveFiles | GoogleDriveObjectType +): string { + const driveFileId = f instanceof GoogleDriveFiles ? f.driveFileId : f.id; + if (isGoogleDriveSpreadSheetFile(f)) { - return `https://docs.google.com/spreadsheets/d/${f.driveFileId}/edit`; + return `https://docs.google.com/spreadsheets/d/${driveFileId}/edit`; } else if (isGoogleDriveFolder(f)) { - return `https://drive.google.com/drive/folders/${f.driveFileId}`; + return `https://drive.google.com/drive/folders/${driveFileId}`; } - return `https://drive.google.com/file/d/${f.driveFileId}/view`; + return `https://drive.google.com/file/d/${driveFileId}/view`; +} + +export function getSourceUrlForGoogleDriveSheet( + s: GoogleDriveSheet | Sheet +): string { + const driveFileId = + s instanceof GoogleDriveSheet ? s.driveFileId : s.spreadsheet.id; + const driveSheetId = s instanceof GoogleDriveSheet ? 
s.driveSheetId : s.id; + return `https://docs.google.com/spreadsheets/d/${driveFileId}/edit#gid=${driveSheetId}`; } diff --git a/connectors/src/connectors/google_drive/temporal/activities.ts b/connectors/src/connectors/google_drive/temporal/activities.ts index ab15cd7c5729..d0ee0ba5ed52 100644 --- a/connectors/src/connectors/google_drive/temporal/activities.ts +++ b/connectors/src/connectors/google_drive/temporal/activities.ts @@ -8,6 +8,7 @@ import StatsD from "hot-shots"; import PQueue from "p-queue"; import { Op } from "sequelize"; +import { getSourceUrlForGoogleDriveFiles } from "@connectors/connectors/google_drive"; import { GOOGLE_DRIVE_SHARED_WITH_ME_VIRTUAL_ID, GOOGLE_DRIVE_USER_SPACE_VIRTUAL_DRIVE_ID, @@ -516,6 +517,7 @@ export async function incrementalSync( parentId: parents[1] || null, title: driveFile.name ?? "", mimeType: MIME_TYPES.GOOGLE_DRIVE.FOLDER, + sourceUrl: getSourceUrlForGoogleDriveFiles(driveFile), }); await GoogleDriveFiles.upsert({ @@ -861,6 +863,7 @@ export async function markFolderAsVisited( parentId: parents[1] || null, title: file.name ?? "", mimeType: MIME_TYPES.GOOGLE_DRIVE.FOLDER, + sourceUrl: getSourceUrlForGoogleDriveFiles(file), }); await GoogleDriveFiles.upsert({ diff --git a/connectors/src/connectors/google_drive/temporal/mime_types.ts b/connectors/src/connectors/google_drive/temporal/mime_types.ts index b857665d1ba5..d6279eebb87a 100644 --- a/connectors/src/connectors/google_drive/temporal/mime_types.ts +++ b/connectors/src/connectors/google_drive/temporal/mime_types.ts @@ -1,5 +1,3 @@ -import type { GoogleDriveFiles } from "@connectors/lib/models/google_drive"; - export const MIME_TYPES_TO_EXPORT: { [key: string]: string } = { "application/vnd.google-apps.document": "text/plain", "application/vnd.google-apps.presentation": "text/plain", @@ -48,7 +46,7 @@ export async function getMimeTypesToSync({ return mimeTypes; } -export function isGoogleDriveFolder(file: GoogleDriveFiles) { +export function isGoogleDriveFolder(file: { mimeType: string }) { return file.mimeType === "application/vnd.google-apps.folder"; } diff --git a/connectors/src/connectors/google_drive/temporal/spreadsheets.ts b/connectors/src/connectors/google_drive/temporal/spreadsheets.ts index 7b1e89cbcef9..d070bc9c0d2f 100644 --- a/connectors/src/connectors/google_drive/temporal/spreadsheets.ts +++ b/connectors/src/connectors/google_drive/temporal/spreadsheets.ts @@ -11,6 +11,7 @@ import type { sheets_v4 } from "googleapis"; import { google } from "googleapis"; import type { OAuth2Client } from "googleapis-common"; +import { getSourceUrlForGoogleDriveSheet } from "@connectors/connectors/google_drive"; import { getFileParentsMemoized } from "@connectors/connectors/google_drive/lib/hierarchy"; import { getInternalId } from "@connectors/connectors/google_drive/temporal/utils"; import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config"; @@ -30,7 +31,7 @@ import type { GoogleDriveObjectType } from "@connectors/types/google_drive"; const MAXIMUM_NUMBER_OF_GSHEET_ROWS = 50000; -type Sheet = sheets_v4.Schema$ValueRange & { +export type Sheet = sheets_v4.Schema$ValueRange & { id: number; spreadsheet: { id: string; @@ -87,6 +88,7 @@ async function upsertGdriveTable( useAppForHeaderDetection: true, title: `${spreadsheet.title} - ${title}`, mimeType: "application/vnd.google-apps.spreadsheet", + sourceUrl: getSourceUrlForGoogleDriveSheet(sheet), }); logger.info(loggerArgs, "[Spreadsheet] Table upserted."); diff --git 
a/connectors/src/connectors/intercom/temporal/sync_help_center.ts b/connectors/src/connectors/intercom/temporal/sync_help_center.ts index 9655b5fee61b..be3021fb8298 100644 --- a/connectors/src/connectors/intercom/temporal/sync_help_center.ts +++ b/connectors/src/connectors/intercom/temporal/sync_help_center.ts @@ -230,6 +230,7 @@ export async function upsertCollectionWithChildren({ parents: collectionParents, parentId: collectionParents[1], mimeType: MIME_TYPES.INTERCOM.COLLECTION, + sourceUrl: collection.url || fallbackCollectionUrl, }); // Then we call ourself recursively on the children collections diff --git a/connectors/src/connectors/notion/temporal/activities.ts b/connectors/src/connectors/notion/temporal/activities.ts index 719f3dd0829f..b58909a405b2 100644 --- a/connectors/src/connectors/notion/temporal/activities.ts +++ b/connectors/src/connectors/notion/temporal/activities.ts @@ -1830,6 +1830,9 @@ export async function renderAndUpsertPageFromCache({ parentId: parents[1] || null, title: parentDb.title ?? "Untitled Notion Database", mimeType: MIME_TYPES.NOTION.DATABASE, + sourceUrl: + parentDb.notionUrl ?? + `https://www.notion.so/${parentDb.notionDatabaseId.replace(/-/g, "")}`, }), localLogger ); @@ -2551,6 +2554,9 @@ export async function upsertDatabaseStructuredDataFromCache({ parentId: parentIds[1] || null, title: dbModel.title ?? "Untitled Notion Database", mimeType: MIME_TYPES.NOTION.DATABASE, + sourceUrl: + dbModel.notionUrl ?? + `https://www.notion.so/${dbModel.notionDatabaseId.replace(/-/g, "")}`, }), localLogger ); diff --git a/connectors/src/connectors/slack/lib/channels.ts b/connectors/src/connectors/slack/lib/channels.ts index cacb3ba0db1f..941bdbff8792 100644 --- a/connectors/src/connectors/slack/lib/channels.ts +++ b/connectors/src/connectors/slack/lib/channels.ts @@ -18,6 +18,7 @@ export type SlackChannelType = { slackId: string; permission: ConnectorPermission; agentConfigurationId: string | null; + private: boolean; }; export async function updateSlackChannelInConnectorsDb({ @@ -60,6 +61,7 @@ export async function updateSlackChannelInConnectorsDb({ slackId: channel.slackChannelId, permission: channel.permission, agentConfigurationId: channel.agentConfigurationId, + private: channel.private, }; } diff --git a/connectors/src/connectors/slack/temporal/activities.ts b/connectors/src/connectors/slack/temporal/activities.ts index 02428ac3b9fc..ca53eeb9eebd 100644 --- a/connectors/src/connectors/slack/temporal/activities.ts +++ b/connectors/src/connectors/slack/temporal/activities.ts @@ -257,6 +257,7 @@ export async function syncChannel( parents: [slackChannelInternalIdFromSlackChannelId(channelId)], mimeType: MIME_TYPES.SLACK.CHANNEL, sourceUrl: getSlackChannelSourceUrl(channelId, slackConfiguration), + providerVisibility: channel.private ? 
"private" : "public", }); } diff --git a/connectors/src/connectors/webcrawler/temporal/activities.ts b/connectors/src/connectors/webcrawler/temporal/activities.ts index 83feabe76167..764e8e30a61f 100644 --- a/connectors/src/connectors/webcrawler/temporal/activities.ts +++ b/connectors/src/connectors/webcrawler/temporal/activities.ts @@ -293,6 +293,7 @@ export async function crawlWebsiteByConnectorId(connectorId: ModelId) { parentId: parents[1] || null, title: folder, mimeType: MIME_TYPES.WEBCRAWLER.FOLDER, + sourceUrl: webCrawlerFolder.url, }); createdFolders.add(folder); diff --git a/connectors/src/lib/data_sources.ts b/connectors/src/lib/data_sources.ts index 0fe47cfe3a10..3891a0905a6e 100644 --- a/connectors/src/lib/data_sources.ts +++ b/connectors/src/lib/data_sources.ts @@ -9,6 +9,7 @@ import type { CoreAPIFolder, CoreAPITable, PostDataSourceDocumentRequestBody, + ProviderVisibility, } from "@dust-tt/types"; import { isValidDate, @@ -1244,6 +1245,7 @@ export async function _upsertDataSourceFolder({ title, mimeType, sourceUrl, + providerVisibility, }: { dataSourceConfig: DataSourceConfig; folderId: string; @@ -1253,6 +1255,7 @@ export async function _upsertDataSourceFolder({ title: string; mimeType: string; sourceUrl?: string; + providerVisibility?: ProviderVisibility; }) { const now = new Date(); @@ -1265,6 +1268,7 @@ export async function _upsertDataSourceFolder({ parents, mimeType, sourceUrl: sourceUrl ?? null, + providerVisibility: providerVisibility || null, }); if (r.isErr()) { diff --git a/connectors/src/types/google_drive.ts b/connectors/src/types/google_drive.ts index 8d174c5d2e44..f2b0ac58a855 100644 --- a/connectors/src/types/google_drive.ts +++ b/connectors/src/types/google_drive.ts @@ -17,6 +17,7 @@ export type GoogleDriveObjectType = { driveId: string; isInSharedDrive: boolean; }; + export type GoogleDriveFolderType = { id: string; name: string; diff --git a/core/src/lib.rs b/core/src/lib.rs index 945312d0f1e8..f3cfaffbfe71 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -56,6 +56,7 @@ pub mod providers { pub mod anthropic; pub mod deepseek; pub mod google_ai_studio; + pub mod openai_compatible_helpers; pub mod togetherai; } pub mod http { diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs index 29d8145b061d..7265037d0dc9 100644 --- a/core/src/providers/azure_openai.rs +++ b/core/src/providers/azure_openai.rs @@ -1,14 +1,11 @@ -use crate::providers::chat_messages::AssistantChatMessage; use crate::providers::chat_messages::ChatMessage; use crate::providers::embedder::{Embedder, EmbedderVector}; use crate::providers::llm::ChatFunction; use crate::providers::llm::Tokens; use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; -use crate::providers::openai::logprobs_from_choices; -use crate::providers::openai::{ - chat_completion, completion, embed, streamed_chat_completion, streamed_completion, - to_openai_messages, OpenAILLM, OpenAITool, OpenAIToolChoice, -}; +use crate::providers::openai::completion; +use crate::providers::openai::embed; +use crate::providers::openai::streamed_completion; use crate::providers::provider::{Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, decode_async, encode_async}; use crate::providers::tiktoken::tiktoken::{ @@ -25,10 +22,13 @@ use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::io::prelude::*; -use std::str::FromStr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; +use 
super::openai::OpenAILLM; +use super::openai_compatible_helpers::openai_compatible_chat_completion; +use super::openai_compatible_helpers::TransformSystemMessages; + #[derive(Serialize, Deserialize, Debug, Clone)] struct AzureOpenAIScaleSettings { scale_type: String, @@ -436,7 +436,7 @@ impl LLM for AzureOpenAILLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, + max_tokens: Option, presence_penalty: Option, frequency_penalty: Option, logprobs: Option, @@ -444,132 +444,30 @@ impl LLM for AzureOpenAILLM { extras: Option, event_sender: Option>, ) -> Result { - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - let (openai_user, response_format, reasoning_effort) = match &extras { - None => (None, None, None), - Some(v) => ( - match v.get("openai_user") { - Some(Value::String(u)) => Some(u.to_string()), - _ => None, - }, - match v.get("response_format") { - Some(Value::String(f)) => Some(f.to_string()), - _ => None, - }, - match v.get("reasoning_effort") { - Some(Value::String(r)) => Some(r.to_string()), - _ => None, - }, - ), - }; - - let tool_choice = match function_call.as_ref() { - Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), - None => None, - }; - - let tools = functions - .iter() - .map(OpenAITool::try_from) - .collect::, _>>()?; - - let openai_messages = to_openai_messages(messages, &self.model_id.clone().unwrap())?; - - let (c, request_id) = match event_sender { - Some(_) => { - streamed_chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - None, - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - response_format, - reasoning_effort, - logprobs, - top_logprobs, - openai_user, - event_sender, - ) - .await? - } - None => { - chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - None, - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - response_format, - reasoning_effort, - logprobs, - top_logprobs, - openai_user, - ) - .await? 
- } - }; - - // println!("COMPLETION: {:?}", c); - - assert!(c.choices.len() > 0); - - Ok(LLMChatGeneration { - created: utils::now(), - provider: ProviderID::AzureOpenAI.to_string(), - model: self.model_id.clone().unwrap(), - completions: c - .choices - .iter() - .map(|c| AssistantChatMessage::try_from(&c.message)) - .collect::>>()?, - usage: c.usage.map(|usage| LLMTokenUsage { - prompt_tokens: usage.prompt_tokens, - completion_tokens: usage.completion_tokens.unwrap_or(0), - }), - provider_request_id: request_id, - logprobs: logprobs_from_choices(&c.choices), - }) + openai_compatible_chat_completion( + self.chat_uri()?, + self.model_id.clone().unwrap(), + self.api_key.clone().unwrap(), + messages, + functions, + function_call, + temperature, + top_p, + n, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + logprobs, + top_logprobs, + extras, + event_sender, + false, // don't disable provider streaming + TransformSystemMessages::Keep, + "AzureOpenAI".to_string(), + false, // don't squash text contents + ) + .await } } diff --git a/core/src/providers/deepseek.rs b/core/src/providers/deepseek.rs index d3c49e4ea7cd..63dc865857c1 100644 --- a/core/src/providers/deepseek.rs +++ b/core/src/providers/deepseek.rs @@ -1,12 +1,9 @@ -use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage}; +use std::sync::Arc; + +use crate::providers::chat_messages::ChatMessage; use crate::providers::embedder::Embedder; use crate::providers::llm::ChatFunction; -use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; -use crate::providers::openai::{ - chat_completion, logprobs_from_choices, streamed_chat_completion, to_openai_messages, - OpenAIChatMessage, OpenAIChatMessageContent, OpenAIContentBlock, OpenAITextContent, - OpenAITextContentType, OpenAITool, OpenAIToolChoice, -}; +use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLM}; use crate::providers::provider::{Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, o200k_base_singleton, CoreBPE}; use crate::providers::tiktoken::tiktoken::{decode_async, encode_async}; @@ -18,10 +15,12 @@ use async_trait::async_trait; use hyper::Uri; use parking_lot::RwLock; use serde_json::Value; -use std::str::FromStr; -use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; +use super::openai_compatible_helpers::{ + openai_compatible_chat_completion, TransformSystemMessages, +}; + pub struct DeepseekLLM { id: String, api_key: Option, @@ -115,7 +114,7 @@ impl LLM for DeepseekLLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, + max_tokens: Option, presence_penalty: Option, frequency_penalty: Option, logprobs: Option, @@ -123,138 +122,30 @@ impl LLM for DeepseekLLM { _extras: Option, event_sender: Option>, ) -> Result { - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - let tool_choice = match function_call.as_ref() { - Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), - None => None, - }; - - let tools = functions - .iter() - .map(OpenAITool::try_from) - .collect::, _>>()?; - - // Deepseek doesn't work with the new chat message content format. - // We have to modify the messages contents to use the "String" format. - let openai_messages = to_openai_messages(messages, &self.id)? - .into_iter() - .filter_map(|m| match m.content { - None => Some(m), - Some(OpenAIChatMessageContent::String(_)) => Some(m), - Some(OpenAIChatMessageContent::Structured(contents)) => { - // Find the first text content, and use it to make a string content. 
- let content = contents.into_iter().find_map(|c| match c { - OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text, - .. - }) => Some(OpenAIChatMessageContent::String(text)), - _ => None, - }); - - Some(OpenAIChatMessage { - role: m.role, - name: m.name, - tool_call_id: m.tool_call_id, - tool_calls: m.tool_calls, - content, - }) - } - }) - .collect::>(); - - let is_streaming = event_sender.is_some(); - - let (c, request_id) = if is_streaming { - streamed_chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - None, - None, - logprobs, - top_logprobs, - None, - event_sender.clone(), - ) - .await? - } else { - chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - None, - None, - logprobs, - top_logprobs, - None, - ) - .await? - }; - - assert!(c.choices.len() > 0); - - Ok(LLMChatGeneration { - created: utils::now(), - provider: ProviderID::Deepseek.to_string(), - model: self.id.clone(), - completions: c - .choices - .iter() - .map(|c| AssistantChatMessage::try_from(&c.message)) - .collect::>>()?, - usage: c.usage.map(|usage| LLMTokenUsage { - prompt_tokens: usage.prompt_tokens, - completion_tokens: usage.completion_tokens.unwrap_or(0), - }), - provider_request_id: request_id, - logprobs: logprobs_from_choices(&c.choices), - }) + openai_compatible_chat_completion( + self.chat_uri()?, + self.id.clone(), + self.api_key.clone().unwrap(), + messages, + functions, + function_call, + temperature, + top_p, + n, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + logprobs, + top_logprobs, + None, + event_sender, + false, // don't disable provider streaming + TransformSystemMessages::Keep, + "DeepSeek".to_string(), + false, // don't squash text contents + ) + .await } } diff --git a/core/src/providers/google_ai_studio.rs b/core/src/providers/google_ai_studio.rs index fbbaffb027c5..1519215840ff 100644 --- a/core/src/providers/google_ai_studio.rs +++ b/core/src/providers/google_ai_studio.rs @@ -1,309 +1,26 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; -use eventsource_client as es; -use eventsource_client::Client as ESClient; -use futures::TryStreamExt; -use hyper::StatusCode; -use parking_lot::{Mutex, RwLock}; -use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; + +use http::Uri; +use parking_lot::RwLock; +use serde_json::Value; use std::sync::Arc; -use std::time::Duration; + use tokio::sync::mpsc::UnboundedSender; -use crate::{ - providers::{ - chat_messages::AssistantChatMessage, - llm::Tokens, - provider::{ModelError, ModelErrorRetryOptions}, - }, - run::Credentials, - utils, -}; +use crate::{run::Credentials, utils}; use super::{ - chat_messages::{ChatMessage, ContentBlock, FunctionChatMessage, MixedContent}, + chat_messages::ChatMessage, embedder::Embedder, - llm::{ - ChatFunction, ChatFunctionCall, ChatMessageRole, LLMChatGeneration, LLMGeneration, - LLMTokenUsage, LLM, - }, + 
llm::{ChatFunction, LLMChatGeneration, LLMGeneration, LLM}, + openai_compatible_helpers::{openai_compatible_chat_completion, TransformSystemMessages}, provider::{Provider, ProviderID}, tiktoken::tiktoken::{ batch_tokenize_async, cl100k_base_singleton, decode_async, encode_async, CoreBPE, }, }; -// Disabled for now as it requires using a "tools" API which we don't support yet. -pub const USE_FUNCTION_CALLING: bool = false; - -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct UsageMetadata { - prompt_token_count: Option, - candidates_token_count: Option, - total_token_count: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct GoogleAiStudioFunctionResponseContent { - name: String, - content: String, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct GoogleAIStudioFunctionResponse { - name: String, - response: GoogleAiStudioFunctionResponseContent, -} - -impl TryFrom<&FunctionChatMessage> for GoogleAIStudioFunctionResponse { - type Error = anyhow::Error; - - fn try_from(m: &FunctionChatMessage) -> Result { - let name = m.name.clone().unwrap_or_default(); - Ok(GoogleAIStudioFunctionResponse { - name: name.clone(), - response: GoogleAiStudioFunctionResponseContent { - name, - content: m.content.clone(), - }, - }) - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct GoogleAIStudioFunctionCall { - name: String, - args: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct GoogleAIStudioFunctionDeclaration { - name: String, - description: String, - parameters: Option, -} - -impl TryFrom<&ChatFunction> for GoogleAIStudioFunctionDeclaration { - type Error = anyhow::Error; - - fn try_from(f: &ChatFunction) -> Result { - Ok(GoogleAIStudioFunctionDeclaration { - name: f.name.clone(), - description: f.description.clone().unwrap_or_else(|| String::from("")), - parameters: match f.parameters.clone() { - // The API rejects empty 'properties'. If 'properties' is empty, return None. - // Otherwise, return the object wrapped in Some. 
- Some(serde_json::Value::Object(obj)) => { - if obj.get("properties").map_or(false, |props| { - props.as_object().map_or(false, |p| p.is_empty()) - }) { - None - } else { - Some(serde_json::Value::Object(obj)) - } - } - p => p, - }, - }) - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "UPPERCASE")] -pub enum GoogleAIStudioTooConfigMode { - Auto, - Any, - None, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct GoogleAIStudioFunctionCallingConfig { - mode: GoogleAIStudioTooConfigMode, - allowed_function_names: Option>, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct Part { - text: Option, - #[serde(skip_serializing_if = "Option::is_none")] - function_call: Option, - #[serde(skip_serializing_if = "Option::is_none")] - function_response: Option, -} - -impl TryFrom<&ChatFunctionCall> for GoogleAIStudioFunctionCall { - type Error = anyhow::Error; - - fn try_from(f: &ChatFunctionCall) -> Result { - let args = match serde_json::from_str(f.arguments.as_str()) { - Ok(v) => v, - Err(_) => Err(anyhow!( - "GoogleAISudio function call arguments must be valid JSON" - ))?, - }; - Ok(GoogleAIStudioFunctionCall { - name: f.name.clone(), - args, - }) - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Content { - role: Option, - parts: Option>, -} - -impl TryFrom<&ChatMessage> for Content { - type Error = anyhow::Error; - - fn try_from(cm: &ChatMessage) -> Result { - match cm { - ChatMessage::Assistant(assistant_msg) => { - let parts = match assistant_msg.function_calls { - Some(ref fcs) => fcs - .iter() - .map(|fc| { - Ok(Part { - text: assistant_msg.content.clone(), - function_call: Some(GoogleAIStudioFunctionCall::try_from(fc)?), - function_response: None, - }) - }) - .collect::, anyhow::Error>>()?, - None => { - if let Some(ref fc) = assistant_msg.function_call { - vec![Part { - text: assistant_msg.content.clone(), - function_call: Some(GoogleAIStudioFunctionCall::try_from(fc)?), - function_response: None, - }] - } else { - vec![Part { - text: assistant_msg.content.clone(), - function_call: None, - function_response: None, - }] - } - } - }; - - Ok(Content { - role: Some(String::from("model")), - parts: Some(parts), - }) - } - ChatMessage::Function(function_msg) => Ok(Content { - role: Some(String::from("function")), - parts: Some(vec![Part { - text: None, - function_call: None, - function_response: GoogleAIStudioFunctionResponse::try_from(function_msg).ok(), - }]), - }), - ChatMessage::User(user_msg) => { - let text = match &user_msg.content { - ContentBlock::Mixed(m) => { - let result = m.iter().enumerate().try_fold( - String::new(), - |mut acc, (i, content)| { - match content { - MixedContent::ImageContent(_) => Err(anyhow!( - "Vision is not supported for Google AI Studio." - )), - MixedContent::TextContent(tc) => { - acc.push_str(&tc.text.trim()); - if i != m.len() - 1 { - // Add newline if it's not the last item. - acc.push('\n'); - } - Ok(acc) - } - } - }, - ); - - match result { - Ok(text) if !text.is_empty() => Ok(text), - Ok(_) => Err(anyhow!("Text is required.")), // Empty string. 
- Err(e) => Err(e), - } - } - ContentBlock::Text(t) => Ok(t.clone()), - }?; - - Ok(Content { - role: Some(String::from("user")), - parts: Some(vec![Part { - text: Some(text), - function_call: None, - function_response: None, - }]), - }) - } - ChatMessage::System(system_msg) => Ok(Content { - role: Some(String::from("user")), - parts: Some(vec![Part { - // System is passed as a Content. We transform it here but it will be removed - // from the list of messages and passed as separate argument to the API. - text: Some(system_msg.content.clone()), - function_call: None, - function_response: None, - }]), - }), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct Candidate { - content: Option, - finish_reason: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct Completion { - candidates: Option>, - usage_metadata: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct InnerError { - pub message: String, - pub code: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct GoogleAIStudioError { - pub error: InnerError, -} - -impl GoogleAIStudioError { - pub fn message(&self) -> String { - format!("GoogleAIStudio: {}", self.error.message) - } - - pub fn retryable(&self) -> bool { - return false; - } - - pub fn retryable_streamed(&self, status: StatusCode) -> bool { - if status == StatusCode::TOO_MANY_REQUESTS { - return true; - } - if status.is_server_error() { - return true; - } - return false; - } -} - pub struct GoogleAiStudioProvider {} impl GoogleAiStudioProvider { @@ -349,11 +66,8 @@ impl GoogleAiStudioLLM { Self { id, api_key: None } } - fn model_endpoint(&self) -> String { - format!( - "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse", - self.id - ) + fn model_endpoint(&self) -> Uri { + Uri::from_static("https://generativelanguage.googleapis.com/v1beta/openai/chat/completions") } fn tokenizer(&self) -> Arc> { @@ -397,104 +111,19 @@ impl LLM for GoogleAiStudioLLM { async fn generate( &self, - prompt: &str, - mut max_tokens: Option, - temperature: f32, - n: usize, - stop: &Vec, - presence_penalty: Option, - frequency_penalty: Option, - top_p: Option, - top_logprobs: Option, + _prompt: &str, + _max_tokens: Option, + _temperature: f32, + _n: usize, + _stop: &Vec, + _presence_penalty: Option, + _frequency_penalty: Option, + _top_p: Option, + _top_logprobs: Option, _extras: Option, - event_sender: Option>, + _event_sender: Option>, ) -> Result { - assert!(n == 1); - - let api_key = match &self.api_key { - Some(k) => k.to_string(), - None => Err(anyhow!("API key not found"))?, - }; - - if frequency_penalty.is_some() { - Err(anyhow!("Frequency penalty not supported by GoogleAIStudio"))?; - } - if presence_penalty.is_some() { - Err(anyhow!("Presence penalty not supported by GoogleAIStudio"))?; - } - if top_logprobs.is_some() { - Err(anyhow!("Top logprobs not supported by GoogleAIStudio"))?; - } - - if let Some(m) = max_tokens { - if m == -1 { - let tokens = self.encode(prompt).await?; - max_tokens = Some((self.context_size() - tokens.len()) as i32); - } - } - - let uri = self.model_endpoint(); - - let c = streamed_chat_completion( - uri, - api_key, - &vec![Content { - role: Some(String::from("user")), - parts: Some(vec![Part { - text: Some(String::from(prompt)), - function_call: None, - function_response: None, - }]), - }], - vec![], - None, - None, - temperature, - stop, - max_tokens, - match top_p { - Some(t) 
=> t, - None => 1.0, - }, - None, - event_sender, - false, - ) - .await?; - - Ok(LLMGeneration { - created: utils::now(), - provider: ProviderID::GoogleAiStudio.to_string(), - model: self.id().clone(), - completions: vec![Tokens { - // Get candidates?.[0]?.content?.parts?.[0]?.text ?? "". - text: c - .candidates - .as_ref() - .and_then(|c| c.first()) - .and_then(|c| c.content.as_ref()) - .and_then(|c| c.parts.as_ref()) - .and_then(|p| p.first()) - .and_then(|p| p.text.as_ref()) - .map(|t| t.to_string()) - .unwrap_or_else(|| String::from("")), - - tokens: Some(vec![]), - logprobs: Some(vec![]), - top_logprobs: None, - }], - prompt: Tokens { - text: prompt.to_string(), - tokens: Some(vec![]), - logprobs: Some(vec![]), - top_logprobs: None, - }, - usage: c.usage_metadata.map(|c| LLMTokenUsage { - prompt_tokens: c.prompt_token_count.unwrap_or(0) as u64, - completion_tokens: c.candidates_token_count.unwrap_or(0) as u64, - }), - provider_request_id: None, - }) + unimplemented!() } async fn chat( @@ -506,489 +135,37 @@ impl LLM for GoogleAiStudioLLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, - presence_penalty: Option, - frequency_penalty: Option, - _logprobs: Option, - _top_logprobs: Option, + max_tokens: Option, + _presence_penalty: Option, + _frequency_penalty: Option, + logprobs: Option, + top_logprobs: Option, _extras: Option, event_sender: Option>, ) -> Result { - assert!(n == 1); - - let api_key = match &self.api_key { - Some(k) => k.to_string(), - None => Err(anyhow!("API key not found"))?, - }; - - if frequency_penalty.is_some() { - Err(anyhow!("Frequency penalty not supported by GoogleAIStudio"))?; - } - if presence_penalty.is_some() { - Err(anyhow!("Presence penalty not supported by GoogleAIStudio"))?; - } - - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - if frequency_penalty.is_some() { - Err(anyhow!("Frequency penalty not supported by GoogleAIStudio"))?; - } - if presence_penalty.is_some() { - Err(anyhow!("Presence penalty not supported by GoogleAIStudio"))?; - } - - let uri = self.model_endpoint(); - - // Remove system message if first. 
- let system = match messages.get(0) { - Some(cm) => match cm { - ChatMessage::System(_) => Some(Content::try_from(cm)?), - _ => None, - }, - None => None, - }; - - let messages = messages - .iter() - .skip(match system.as_ref() { - Some(_) => 1, - None => 0, - }) - .map(|cm| Content::try_from(cm)) - .collect::>>()?; - - // TODO: backward comp for non alternated messages - - let tools = functions - .iter() - .map(GoogleAIStudioFunctionDeclaration::try_from) - .collect::, _>>()?; - - let tool_config = match function_call { - Some(fc) => Some(match fc.as_str() { - "auto" => GoogleAIStudioFunctionCallingConfig { - mode: GoogleAIStudioTooConfigMode::Auto, - allowed_function_names: None, - }, - "none" => GoogleAIStudioFunctionCallingConfig { - mode: GoogleAIStudioTooConfigMode::None, - allowed_function_names: None, - }, - "any" => GoogleAIStudioFunctionCallingConfig { - mode: GoogleAIStudioTooConfigMode::Any, - allowed_function_names: None, - }, - _ => GoogleAIStudioFunctionCallingConfig { - mode: GoogleAIStudioTooConfigMode::Any, - allowed_function_names: Some(vec![fc.clone()]), - }, - }), - None => None, - }; - - let c = streamed_chat_completion( - uri, - api_key, - &messages, - tools, - tool_config, - system, + openai_compatible_chat_completion( + self.model_endpoint(), + self.id.clone(), + self.api_key.clone().unwrap(), + messages, + functions, + function_call, temperature, + top_p, + n, stop, max_tokens, - match top_p { - Some(t) => t, - None => 1.0, - }, + None, + None, + logprobs, + top_logprobs, None, event_sender, - false, - ) - .await?; - - let mut content: Option = None; - let mut function_calls: Vec = vec![]; - - // Get candidates?.[0]?.content?.parts. - if let Some(parts) = c - .candidates - .as_ref() - .and_then(|c| c.first()) - .and_then(|c| c.content.as_ref()) - .and_then(|c| c.parts.as_ref()) - { - for p in parts.iter() { - // If the part has text, either append it to the content if we already have some - // or set it as the content. - if let Some(t) = p.text.as_ref() { - content = content.map(|c| c + t).or_else(|| Some(t.clone())); - } - - // If the part has a function call, add it to the list of function calls. 
- if let Some(fc) = p.function_call.as_ref() { - function_calls.push(ChatFunctionCall { - id: format!("fc_{}", utils::new_id()[0..9].to_string()), - name: fc.name.clone(), - arguments: match fc.args { - Some(ref args) => serde_json::to_string(args)?, - None => String::from("{}"), - }, - }); - } - } - } - - Ok(LLMChatGeneration { - created: utils::now(), - provider: ProviderID::GoogleAiStudio.to_string(), - model: self.id().clone(), - completions: vec![AssistantChatMessage { - name: None, - function_call: match function_calls.first() { - Some(fc) => Some(fc.clone()), - None => None, - }, - function_calls: match function_calls.len() { - 0 => None, - _ => Some(function_calls), - }, - role: ChatMessageRole::Assistant, - content, - }], - usage: c.usage_metadata.map(|c| LLMTokenUsage { - prompt_tokens: c.prompt_token_count.unwrap_or(0) as u64, - completion_tokens: c.candidates_token_count.unwrap_or(0) as u64, - }), - provider_request_id: None, - logprobs: None, - }) - } -} - -pub async fn streamed_chat_completion( - uri: String, - api_key: String, - messages: &Vec, - tools: Vec, - tool_config: Option, - system_instruction: Option, - temperature: f32, - stop: &Vec, - max_tokens: Option, - top_p: f32, - top_k: Option, - event_sender: Option>, - use_header_auth: bool, -) -> Result { - let url = match use_header_auth { - true => uri.to_string(), - false => format!("{}&key={}", uri, api_key), - }; - - let mut builder = match es::ClientBuilder::for_url(url.as_str()) { - Ok(builder) => builder, - Err(e) => { - return Err(anyhow!( - "Error creating GoogleAIStudio streaming client: {:?}", - e - )) - } - }; - - if use_header_auth { - builder = match builder.method(String::from("POST")).header( - "Authorization", - format!("Bearer {}", api_key.clone()).as_str(), - ) { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to GoogleAIStudio")), - }; - } - - builder = match builder.header("Content-Type", "application/json") { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to GoogleAIStudio")), - }; - - let mut body = json!({ - "contents": json!(messages), - "generation_config": { - "temperature": temperature, - "topP": top_p, - "topK": top_k, - "maxOutputTokens": max_tokens, - "stopSequences": match stop.len() { - 0 => None, - _ => Some(stop), - }, - }, - "safety_settings": [ - { "category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH" }, - { "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH" }, - { "category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH" }, - { "category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH" } - ] - }); - - if tools.len() > 0 { - body["tools"] = json!(vec![json!({ - "functionDeclarations": tools - })]); - } - - if tool_config.is_some() { - body["toolConfig"] = json!({ - "functionCallingConfig": tool_config - }); - } - - if system_instruction.is_some() { - body["systemInstruction"] = json!(system_instruction); - } - - let client = builder - .body(body.to_string()) - .method("POST".to_string()) - .reconnect( - es::ReconnectOptions::reconnect(true) - .retry_initial(false) - .delay(Duration::from_secs(1)) - .backoff_factor(2) - .delay_max(Duration::from_secs(8)) - .build(), + false, // don't disable provider streaming + TransformSystemMessages::Keep, + "GoogleAIStudio".to_string(), + false, // don't squash text contents ) - .build(); - - let mut stream = client.stream(); - - let completions: Arc>> = Arc::new(Mutex::new(Vec::new())); - - 'stream: loop { - 
match stream.try_next().await { - Ok(e) => match e { - Some(es::SSE::Connected(_)) => { - // GoogleAISudio does not return a request id in headers. - // Nothing to do. - } - Some(es::SSE::Comment(_)) => { - println!("UNEXPECTED COMMENT"); - } - Some(es::SSE::Event(e)) => { - let completion: Completion = serde_json::from_str(e.data.as_str())?; - let completion_candidates = completion.candidates.clone().unwrap_or_default(); - - match completion_candidates.len() { - 0 => { - break 'stream; - } - 1 => (), - n => { - Err(anyhow!("Unexpected number of candidates: {}", n))?; - } - }; - - if let (Some(parts), Some(sender)) = ( - completion_candidates[0] - .content - .as_ref() - .and_then(|c| c.parts.as_ref()), - event_sender.as_ref(), - ) { - parts.iter().for_each(|p| { - match p.text { - Some(ref t) => { - if t.len() > 0 { - let _ = sender.send(json!({ - "type": "tokens", - "content": { - "text": t, - } - })); - } - } - None => (), - } - - match p.function_call { - Some(ref f) => { - let _ = sender.send(json!({ - "type": "function_call", - "content": { - "name": f.name, - } - })); - } - None => (), - } - }); - } - - completions.lock().push(completion); - } - None => { - break 'stream; - } - }, - Err(e) => { - match e { - // Nothing to do, go direclty to break stream. - es::Error::Eof => (), - es::Error::UnexpectedResponse(r) => { - let status = StatusCode::from_u16(r.status())?; - // GogoleAIStudio currently has no request id in headers. - // let headers = r.headers()?; - // let request_id = match headers.get("request-id") { - // Some(v) => Some(v.to_string()), - // None => None, - // }; - let b = r.body_bytes().await?; - - let error: Result = serde_json::from_slice(&b); - match error { - Ok(error) => { - match error.retryable_streamed(status) { - true => Err(ModelError { - request_id: None, - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 2, - retries: 3, - }), - }), - false => Err(ModelError { - request_id: None, - message: error.message(), - retryable: None, - }), - } - }?, - Err(_) => Err(anyhow!( - "Error streaming tokens from GoogleAIStudio: status={} data={}", - status, - String::from_utf8_lossy(&b) - ))?, - } - } - _ => { - Err(anyhow!( - "Error streaming tokens from GoogleAIStudio: {:?}", - e - ))?; - break 'stream; - } - } - break 'stream; - } - } + .await } - - let completions_lock = completions.lock(); - - // Sometimes (usually when last message is Assistant), the AI decides not to respond. - if completions_lock.len() == 0 { - return Ok(Completion { - candidates: None, - usage_metadata: None, - }); - } - - let mut usage_metadata: Option = None; - - let mut full_candidate = Candidate { - content: Some(Content { - role: Some(String::from("MODEL")), - parts: Some(vec![]), - }), - finish_reason: None, - }; - - let mut text_parts: Option = None; - let mut function_call_parts: Vec = vec![]; - - for c in completions_lock.iter() { - match &c.usage_metadata { - None => (), - Some(usage) => { - usage_metadata = Some(usage.clone()); - } - } - - // Check that we don't have more than one candidate. - match c - .candidates - .as_ref() - .map(|candidates| candidates.len()) - .unwrap_or_default() - { - 0 => (), - 1 => (), - n => Err(anyhow!("Unexpected number of candidates >1: {}", n))?, - } - - if let Some(candidate) = c.candidates.as_ref().map(|c| c.first()).flatten() { - // Validate that the role (if any) is MODEL. 
- - if let Some(c) = candidate.content.as_ref() { - match &c.role { - Some(r) => match r.to_uppercase().as_str() { - "MODEL" => (), - r => Err(anyhow!("Unexpected role in completion: {}", r))?, - }, - None => (), - } - } - - if let Some(r) = candidate.finish_reason.as_ref() { - full_candidate.finish_reason = Some(r.clone()); - } - - if let Some(parts) = candidate.content.as_ref().and_then(|c| c.parts.as_ref()) { - for p in parts.iter() { - if let Some(t) = p.text.as_ref() { - match text_parts.as_mut() { - Some(tp) => { - tp.text = Some(tp.text.clone().unwrap_or_default() + t.as_str()); - } - None => { - text_parts = Some(p.clone()); - } - } - } - - if p.function_call.is_some() { - function_call_parts.push(p.clone()); - } - - if p.function_response.is_some() { - Err(anyhow!("Unexpected function response part in completion"))?; - } - } - } - } - } - - match full_candidate - .content - .as_mut() - .and_then(|c| c.parts.as_mut()) - { - Some(parts) => { - if let Some(tp) = text_parts { - parts.push(tp); - } - parts.extend(function_call_parts); - } - // This should never happen since we define the `full_candidate` above. - None => unreachable!(), - } - - Ok(Completion { - candidates: Some(vec![full_candidate]), - usage_metadata, - }) } diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs index e839c727a77f..3737453bba9a 100644 --- a/core/src/providers/openai.rs +++ b/core/src/providers/openai.rs @@ -1,12 +1,8 @@ -use crate::providers::chat_messages::{ - AssistantChatMessage, ChatMessage, ContentBlock, MixedContent, -}; +use crate::providers::chat_messages::ChatMessage; use crate::providers::embedder::{Embedder, EmbedderVector}; -use crate::providers::llm::{ChatFunction, ChatFunctionCall}; -use crate::providers::llm::{ - ChatMessageRole, LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM, -}; -use crate::providers::llm::{LLMChatLogprob, Tokens}; +use crate::providers::llm::ChatFunction; +use crate::providers::llm::Tokens; +use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::{ batch_tokenize_async, cl100k_base_singleton, o200k_base_singleton, p50k_base_singleton, @@ -15,7 +11,6 @@ use crate::providers::tiktoken::tiktoken::{ use crate::providers::tiktoken::tiktoken::{decode_async, encode_async}; use crate::run::Credentials; use crate::utils; -use crate::utils::ParseError; use anyhow::{anyhow, Result}; use async_trait::async_trait; use eventsource_client as es; @@ -30,13 +25,14 @@ use serde_json::json; use serde_json::Value; use std::collections::HashMap; use std::io::prelude::*; -use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc::UnboundedSender; use tokio::time::timeout; -use super::llm::TopLogprob; +use super::openai_compatible_helpers::{ + openai_compatible_chat_completion, OpenAIError, TransformSystemMessages, +}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Usage { @@ -82,526 +78,6 @@ pub struct Completion { pub usage: Option, } -/// -/// Tools implementation types. -/// - -// Input types. 
- -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "lowercase")] -pub enum OpenAIToolType { - Function, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIFunction { - name: String, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIFunctionCall { - r#type: OpenAIToolType, - function: OpenAIFunction, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "lowercase")] -pub enum OpenAIToolControl { - Auto, - Required, - None, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(untagged)] -pub enum OpenAIToolChoice { - OpenAIToolControl(OpenAIToolControl), - OpenAIFunctionCall(OpenAIFunctionCall), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIToolFunction { - pub name: String, - pub description: Option, - pub parameters: Option, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAITool { - pub r#type: OpenAIToolType, - pub function: OpenAIToolFunction, -} - -impl FromStr for OpenAIToolChoice { - type Err = ParseError; - - fn from_str(s: &str) -> Result { - match s { - "auto" => Ok(OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::Auto)), - "any" => Ok(OpenAIToolChoice::OpenAIToolControl( - OpenAIToolControl::Required, - )), - "none" => Ok(OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::None)), - _ => { - let function = OpenAIFunctionCall { - r#type: OpenAIToolType::Function, - function: OpenAIFunction { - name: s.to_string(), - }, - }; - Ok(OpenAIToolChoice::OpenAIFunctionCall(function)) - } - } - } -} - -// Outputs types. - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIToolCallFunction { - name: String, - arguments: String, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIToolCall { - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - r#type: OpenAIToolType, - pub function: OpenAIToolCallFunction, -} - -impl TryFrom<&ChatFunctionCall> for OpenAIToolCall { - type Error = anyhow::Error; - - fn try_from(cf: &ChatFunctionCall) -> Result { - Ok(OpenAIToolCall { - id: Some(cf.id.clone()), - r#type: OpenAIToolType::Function, - function: OpenAIToolCallFunction { - name: cf.name.clone(), - arguments: cf.arguments.clone(), - }, - }) - } -} - -impl TryFrom<&OpenAIToolCall> for ChatFunctionCall { - type Error = anyhow::Error; - - fn try_from(tc: &OpenAIToolCall) -> Result { - let id = tc - .id - .as_ref() - .ok_or_else(|| anyhow!("Missing tool call id."))?; - - Ok(ChatFunctionCall { - id: id.clone(), - name: tc.function.name.clone(), - arguments: tc.function.arguments.clone(), - }) - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "lowercase")] -pub enum OpenAIChatMessageRole { - Assistant, - Function, - System, - Developer, - Tool, - User, -} - -impl From<&ChatMessageRole> for OpenAIChatMessageRole { - fn from(role: &ChatMessageRole) -> Self { - match role { - ChatMessageRole::Assistant => OpenAIChatMessageRole::Assistant, - ChatMessageRole::Function => OpenAIChatMessageRole::Function, - ChatMessageRole::System => OpenAIChatMessageRole::System, - ChatMessageRole::User => OpenAIChatMessageRole::User, - } - } -} - -impl FromStr for OpenAIChatMessageRole { - type Err = ParseError; - fn from_str(s: &str) -> Result { - match s { - "system" => Ok(OpenAIChatMessageRole::System), - "user" => Ok(OpenAIChatMessageRole::User), - "assistant" => 
Ok(OpenAIChatMessageRole::Assistant), - "function" => Ok(OpenAIChatMessageRole::Tool), - _ => Err(ParseError::with_message("Unknown OpenAIChatMessageRole"))?, - } - } -} - -impl From for ChatMessageRole { - fn from(value: OpenAIChatMessageRole) -> Self { - match value { - OpenAIChatMessageRole::Assistant => ChatMessageRole::Assistant, - OpenAIChatMessageRole::Function => ChatMessageRole::Function, - OpenAIChatMessageRole::System => ChatMessageRole::System, - OpenAIChatMessageRole::Developer => ChatMessageRole::System, - OpenAIChatMessageRole::Tool => ChatMessageRole::Function, - OpenAIChatMessageRole::User => ChatMessageRole::User, - } - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "snake_case")] -pub enum OpenAITextContentType { - Text, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAITextContent { - #[serde(rename = "type")] - pub r#type: OpenAITextContentType, - pub text: String, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIImageUrlContent { - pub url: String, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(rename_all = "snake_case")] -pub enum OpenAIImageContentType { - ImageUrl, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIImageContent { - pub r#type: OpenAIImageContentType, - pub image_url: OpenAIImageUrlContent, -} - -// Define an enum for mixed content -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -#[serde(untagged)] -pub enum OpenAIContentBlock { - TextContent(OpenAITextContent), - ImageContent(OpenAIImageContent), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(untagged)] -pub enum OpenAIChatMessageContent { - Structured(Vec), - String(String), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAIChatMessage { - pub role: OpenAIChatMessageRole, - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct OpenAICompletionChatMessage { - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - pub role: OpenAIChatMessageRole, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAITopLogprob { - pub token: String, - pub logprob: f32, - pub bytes: Option>, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAIChatChoiceLogprob { - pub token: String, - pub logprob: f32, - pub bytes: Option>, - pub top_logprobs: Vec, -} - -impl From for TopLogprob { - fn from(top_logprob: OpenAITopLogprob) -> Self { - TopLogprob { - token: top_logprob.token, - logprob: top_logprob.logprob, - } - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAIChatChoiceLogprobs { - pub content: Vec, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAIChatChoice { - pub message: OpenAICompletionChatMessage, - pub index: usize, - pub finish_reason: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] 
-pub struct OpenAIChatCompletion { - pub id: String, - pub object: String, - pub created: u64, - pub choices: Vec, - pub usage: Option, -} - -// This code performs a type conversion with information loss when converting to ChatFunctionCall. -// It only supports one tool call, so it takes the first one from the vector of OpenAIToolCall, -// hence potentially discarding other tool calls. -impl TryFrom<&OpenAICompletionChatMessage> for AssistantChatMessage { - type Error = anyhow::Error; - - fn try_from(cm: &OpenAICompletionChatMessage) -> Result { - let role = ChatMessageRole::from(cm.role.clone()); - let content = match cm.content.as_ref() { - Some(c) => Some(c.clone()), - None => None, - }; - - let function_calls = if let Some(tool_calls) = cm.tool_calls.as_ref() { - let cfc = tool_calls - .into_iter() - .map(|tc| ChatFunctionCall::try_from(tc)) - .collect::, _>>()?; - - Some(cfc) - } else { - None - }; - - let function_call = if let Some(fcs) = function_calls.as_ref() { - match fcs.first() { - Some(fc) => Some(fc), - None => None, - } - .cloned() - } else { - None - }; - - let name = match cm.name.as_ref() { - Some(c) => Some(c.clone()), - None => None, - }; - - Ok(AssistantChatMessage { - content, - role, - name, - function_call, - function_calls, - }) - } -} - -impl TryFrom<&ContentBlock> for OpenAIChatMessageContent { - type Error = anyhow::Error; - - fn try_from(cm: &ContentBlock) -> Result { - match cm { - ContentBlock::Text(t) => Ok(OpenAIChatMessageContent::Structured(vec![ - OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text: t.clone(), - }), - ])), - ContentBlock::Mixed(m) => { - let content: Vec = m - .into_iter() - .map(|mb| match mb { - MixedContent::TextContent(tc) => { - Ok(OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text: tc.text.clone(), - })) - } - MixedContent::ImageContent(ic) => { - Ok(OpenAIContentBlock::ImageContent(OpenAIImageContent { - r#type: OpenAIImageContentType::ImageUrl, - image_url: OpenAIImageUrlContent { - url: ic.image_url.url.clone(), - }, - })) - } - }) - .collect::>>()?; - - Ok(OpenAIChatMessageContent::Structured(content)) - } - } - } -} - -impl TryFrom<&String> for OpenAIChatMessageContent { - type Error = anyhow::Error; - - fn try_from(t: &String) -> Result { - Ok(OpenAIChatMessageContent::Structured(vec![ - OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text: t.clone(), - }), - ])) - } -} - -impl TryFrom<&ChatMessage> for OpenAIChatMessage { - type Error = anyhow::Error; - - fn try_from(cm: &ChatMessage) -> Result { - match cm { - ChatMessage::Assistant(assistant_msg) => Ok(OpenAIChatMessage { - content: match &assistant_msg.content { - Some(c) => Some(OpenAIChatMessageContent::try_from(c)?), - None => None, - }, - name: assistant_msg.name.clone(), - role: OpenAIChatMessageRole::from(&assistant_msg.role), - tool_calls: match assistant_msg.function_calls.as_ref() { - Some(fc) => Some( - fc.into_iter() - .map(|f| OpenAIToolCall::try_from(f)) - .collect::, _>>()?, - ), - None => None, - }, - tool_call_id: None, - }), - ChatMessage::Function(function_msg) => Ok(OpenAIChatMessage { - content: Some(OpenAIChatMessageContent::try_from(&function_msg.content)?), - name: None, - role: OpenAIChatMessageRole::Tool, - tool_calls: None, - tool_call_id: Some(function_msg.function_call_id.clone()), - }), - ChatMessage::System(system_msg) => Ok(OpenAIChatMessage { - content: 
Some(OpenAIChatMessageContent::try_from(&system_msg.content)?), - name: None, - role: OpenAIChatMessageRole::from(&system_msg.role), - tool_calls: None, - tool_call_id: None, - }), - ChatMessage::User(user_msg) => Ok(OpenAIChatMessage { - content: Some(OpenAIChatMessageContent::try_from(&user_msg.content)?), - name: user_msg.name.clone(), - role: OpenAIChatMessageRole::from(&user_msg.role), - tool_calls: None, - tool_call_id: None, - }), - } - } -} - -impl TryFrom<&ChatFunction> for OpenAITool { - type Error = anyhow::Error; - - fn try_from(f: &ChatFunction) -> Result { - Ok(OpenAITool { - r#type: OpenAIToolType::Function, - function: OpenAIToolFunction { - name: f.name.clone(), - description: f.description.clone(), - parameters: f.parameters.clone(), - }, - }) - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ChatDelta { - pub delta: Value, - pub index: usize, - pub finish_reason: Option, - pub logprobs: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ChatChunk { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct InnerError { - pub message: String, - #[serde(alias = "type")] - pub _type: String, - pub param: Option, - pub internal_message: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAIError { - pub error: InnerError, -} - -impl OpenAIError { - pub fn message(&self) -> String { - match self.error.internal_message { - Some(ref msg) => format!( - "OpenAIError: [{}] {} internal_message={}", - self.error._type, self.error.message, msg, - ), - None => format!("OpenAIError: [{}] {}", self.error._type, self.error.message,), - } - } - - pub fn retryable(&self) -> bool { - match self.error._type.as_str() { - "requests" => true, - "server_error" => match &self.error.internal_message { - Some(message) if message.contains("retry") => true, - _ => false, - }, - _ => false, - } - } - - pub fn retryable_streamed(&self, status: StatusCode) -> bool { - if status == StatusCode::TOO_MANY_REQUESTS { - return true; - } - if status.is_server_error() { - return true; - } - match self.error._type.as_str() { - "server_error" => match self.error.internal_message { - Some(_) => true, - None => false, - }, - _ => false, - } - } -} - /// /// Shared streamed/non-streamed chat/completion handling code (used by both OpenAILLM and /// AzureOpenAILLM). @@ -677,8 +153,6 @@ pub async fn streamed_completion( body["stop"] = json!(stop); } - // println!("BODY: {}", body.to_string()); - let client = builder .body(body.to_string()) .reconnect( @@ -729,7 +203,7 @@ pub async fn streamed_completion( { true => Err(ModelError { request_id: request_id.clone(), - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 2, @@ -738,7 +212,7 @@ pub async fn streamed_completion( })?, false => Err(ModelError { request_id: request_id.clone(), - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: None, })?, } @@ -757,517 +231,57 @@ pub async fn streamed_completion( } }; - // UTF-8 length of the prompt (as used by the API for text_offset). - let prompt_len = prompt.chars().count(); - - // Only stream if choices is length 1 but should always be the case. 
- match event_sender.as_ref() { - Some(sender) => { - let mut text = completion.choices[0].text.clone(); - let mut tokens = match completion.choices[0].logprobs.as_ref() { - Some(l) => Some(l.tokens.clone()), - None => None, - }; - let mut logprobs = match completion.choices[0].logprobs.as_ref() { - Some(l) => Some(l.token_logprobs.clone()), - None => None, - }; - let text_offset = match completion.choices[0].logprobs.as_ref() { - Some(l) => Some(l.text_offset.clone()), - None => None, - }; - if index == 0 && text_offset.is_some() { - let mut token_offset: usize = 0; - for o in text_offset.as_ref().unwrap() { - if *o < prompt_len { - token_offset += 1; - } - } - text = text.chars().skip(prompt_len).collect::(); - tokens = match tokens { - Some(t) => Some(t[token_offset..].to_vec()), - None => None, - }; - logprobs = match logprobs { - Some(l) => Some(l[token_offset..].to_vec()), - None => None, - }; - } - - if text.len() > 0 { - let _ = sender.send(json!({ - "type": "tokens", - "content": { - "text": text, - "tokens": tokens, - "logprobs": logprobs, - }, - })); - } - } - None => (), - }; - completions.lock().push(completion); - } - }, - None => { - println!("UNEXPECTED NONE"); - break 'stream; - } - }, - Err(e) => { - match e { - es::Error::UnexpectedResponse(r) => { - let status = StatusCode::from_u16(r.status())?; - let headers = r.headers()?; - let request_id = match headers.get("x-request-id") { - Some(v) => Some(v.to_string()), - None => None, - }; - let b = r.body_bytes().await?; - - let error: Result = serde_json::from_slice(&b); - match error { - Ok(error) => { - match error.retryable_streamed(status) { - true => Err(ModelError { - request_id, - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 2, - retries: 3, - }), - }), - false => Err(ModelError { - request_id, - message: error.message(), - retryable: None, - }), - } - }?, - Err(_) => { - Err(anyhow!( - "Error streaming tokens from OpenAI: status={} data={}", - status, - String::from_utf8_lossy(&b) - ))?; - } - } - } - _ => { - Err(anyhow!("Error streaming tokens from OpenAI: {:?}", e))?; - } - } - break 'stream; - } - } - } - - let completion = { - let mut guard = completions.lock(); - let mut c = match guard.len() { - 0 => Err(anyhow!("No completions received from OpenAI")), - _ => Ok(guard[0].clone()), - }?; - guard.remove(0); - for i in 0..guard.len() { - let a = guard[i].clone(); - if a.choices.len() != c.choices.len() { - Err(anyhow!( - "Inconsistent number of choices in streamed completions" - ))?; - } - for j in 0..c.choices.len() { - c.choices[j].finish_reason = a.choices.get(j).unwrap().finish_reason.clone(); - // OpenAI does the bytes merging for us <3. 
- c.choices[j].text = format!("{}{}", c.choices[j].text, a.choices[j].text); - - match c.choices[j].logprobs.as_mut() { - Some(c_logprobs) => match a.choices[j].logprobs.as_ref() { - Some(a_logprobs) => { - c_logprobs.tokens.extend(a_logprobs.tokens.clone()); - c_logprobs - .token_logprobs - .extend(a_logprobs.token_logprobs.clone()); - c_logprobs - .text_offset - .extend(a_logprobs.text_offset.clone()); - match c_logprobs.top_logprobs.as_mut() { - Some(c_top_logprobs) => match a_logprobs.top_logprobs.as_ref() { - Some(a_top_logprobs) => { - c_top_logprobs.extend(a_top_logprobs.clone()); - } - None => (), - }, - None => (), - } - } - None => (), - }, - None => (), - } - } - } - c - }; - - Ok((completion, request_id)) -} - -pub async fn completion( - uri: Uri, - api_key: String, - organization_id: Option, - model_id: Option, - prompt: &str, - max_tokens: Option, - temperature: f32, - n: usize, - logprobs: Option, - echo: bool, - stop: &Vec, - frequency_penalty: f32, - presence_penalty: f32, - top_p: f32, - user: Option, -) -> Result<(Completion, Option)> { - // let https = HttpsConnector::new(); - // let cli = Client::builder().build::<_, hyper::Body>(https); - - let mut body = json!({ - "prompt": prompt, - "temperature": temperature, - "n": n, - "logprobs": logprobs, - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - "top_p": top_p, - }); - if user.is_some() { - body["user"] = json!(user); - } - if model_id.is_some() { - body["model"] = json!(model_id); - } - if let Some(mt) = max_tokens { - body["max_tokens"] = mt.into(); - } - if !stop.is_empty() { - body["stop"] = json!(stop); - } - - match model_id { - None => (), - Some(model_id) => { - body["model"] = json!(model_id); - // `gpt-3.5-turbo-instruct` does not support `echo` - if !model_id.starts_with("gpt-3.5-turbo-instruct") { - body["echo"] = json!(echo); - } - } - }; - - let mut req = reqwest::Client::new() - .post(uri.to_string()) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key.clone())) - .header("api-key", api_key.clone()); - - if let Some(organization_id) = organization_id { - req = req.header("OpenAI-Organization", organization_id); - } - - req = req.json(&body); - - let res = match timeout(Duration::new(180, 0), req.send()).await { - Ok(Ok(res)) => res, - Ok(Err(e)) => Err(e)?, - Err(_) => Err(anyhow!("Timeout sending request to OpenAI after 180s"))?, - }; - - let res_headers = res.headers(); - let request_id = match res_headers.get("x-request-id") { - Some(request_id) => Some(request_id.to_str()?.to_string()), - None => None, - }; - - let body = match timeout(Duration::new(180, 0), res.bytes()).await { - Ok(Ok(body)) => body, - Ok(Err(e)) => Err(e)?, - Err(_) => Err(anyhow!("Timeout reading response from OpenAI after 180s"))?, - }; - - let mut b: Vec = vec![]; - body.reader().read_to_end(&mut b)?; - let c: &[u8] = &b; - - let completion: Completion = match serde_json::from_slice(c) { - Ok(c) => Ok(c), - Err(_) => { - let error: OpenAIError = serde_json::from_slice(c)?; - match error.retryable() { - true => Err(ModelError { - request_id: request_id.clone(), - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 2, - retries: 3, - }), - }), - false => Err(ModelError { - request_id: request_id.clone(), - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 1, - retries: 1, - }), - }), - } - } - }?; - - Ok((completion, 
request_id)) -} - -pub async fn streamed_chat_completion( - uri: Uri, - api_key: String, - organization_id: Option, - model_id: Option, - messages: &Vec, - tools: Vec, - tool_choice: Option, - temperature: f32, - top_p: f32, - n: usize, - stop: &Vec, - max_tokens: Option, - presence_penalty: f32, - frequency_penalty: f32, - response_format: Option, - reasoning_effort: Option, - logprobs: Option, - top_logprobs: Option, - user: Option, - event_sender: Option>, -) -> Result<(OpenAIChatCompletion, Option)> { - let url = uri.to_string(); - - let mut builder = match es::ClientBuilder::for_url(url.as_str()) { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to OpenAI")), - }; - builder = match builder.method(String::from("POST")).header( - "Authorization", - format!("Bearer {}", api_key.clone()).as_str(), - ) { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to OpenAI")), - }; - builder = match builder.header("Content-Type", "application/json") { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to OpenAI")), - }; - builder = match builder.header("api-key", api_key.clone().as_str()) { - Ok(b) => b, - Err(_) => return Err(anyhow!("Error creating streamed client to OpenAI")), - }; - - if let Some(org_id) = organization_id { - builder = builder - .header("OpenAI-Organization", org_id.as_str()) - .map_err(|_| anyhow!("Error creating streamed client to OpenAI"))?; - } - - let mut body = json!({ - "messages": messages, - "temperature": temperature, - "top_p": top_p, - "n": n, - "presence_penalty": presence_penalty, - "frequency_penalty": frequency_penalty, - "stream": true, - "stream_options": HashMap::from([("include_usage", true)]), - }); - if user.is_some() { - body["user"] = json!(user); - } - if model_id.is_some() { - body["model"] = json!(model_id); - } - if let Some(mt) = max_tokens { - body["max_tokens"] = mt.into(); - } - if !stop.is_empty() { - body["stop"] = json!(stop); - } - - if tools.len() > 0 { - body["tools"] = json!(tools); - } - if let Some(tool_choice) = tool_choice { - body["tool_choice"] = json!(tool_choice); - } - if let Some(response_format) = response_format { - body["response_format"] = json!({ - "type": response_format, - }); - } - if let Some(reasoning_effort) = reasoning_effort { - body["reasoning_effort"] = json!(reasoning_effort); - } - if let Some(logprobs) = logprobs { - body["logprobs"] = json!(logprobs); - } - if let Some(top_logprobs) = top_logprobs { - body["top_logprobs"] = json!(top_logprobs); - } - - let client = builder - .body(body.to_string()) - .reconnect( - es::ReconnectOptions::reconnect(true) - .retry_initial(false) - .delay(Duration::from_secs(1)) - .backoff_factor(2) - .delay_max(Duration::from_secs(8)) - .build(), - ) - .build(); - - let mut stream = client.stream(); - - let chunks: Arc>> = Arc::new(Mutex::new(Vec::new())); - let mut usage = None; - let mut request_id: Option = None; - - 'stream: loop { - match stream.try_next().await { - Ok(e) => match e { - Some(es::SSE::Connected((_, headers))) => { - request_id = match headers.get("x-request-id") { - Some(v) => Some(v.to_string()), - None => None, - }; - } - Some(es::SSE::Comment(_)) => { - println!("UNEXPECTED COMMENT"); - } - Some(es::SSE::Event(e)) => match e.data.as_str() { - "[DONE]" => { - break 'stream; - } - _ => { - let index = { - let guard = chunks.lock(); - guard.len() - }; - - let chunk: ChatChunk = match serde_json::from_str(e.data.as_str()) { - Ok(c) => c, - Err(err) => { - let error: Result = - 
serde_json::from_str(e.data.as_str()); - match error { - Ok(error) => { - match error.retryable_streamed(StatusCode::OK) && index == 0 - { - true => Err(ModelError { - request_id: request_id.clone(), - message: error.message(), - retryable: Some(ModelErrorRetryOptions { - sleep: Duration::from_millis(500), - factor: 2, - retries: 3, - }), - })?, - false => Err(ModelError { - request_id: request_id.clone(), - message: error.message(), - retryable: None, - })?, - } - break 'stream; - } - Err(_) => { - Err(anyhow!( - "OpenAIError: failed parsing streamed \ - completion from OpenAI err={} data={}", - err, - e.data.as_str(), - ))?; - break 'stream; - } - } - } - }; - - // Store usage - match &chunk.usage { - Some(received_usage) => { - usage = Some(received_usage.clone()); - } - None => (), - }; - + // UTF-8 length of the prompt (as used by the API for text_offset). + let prompt_len = prompt.chars().count(); + // Only stream if choices is length 1 but should always be the case. match event_sender.as_ref() { Some(sender) => { - if chunk.choices.len() == 1 { - // we ignore the role for generating events - - // If we get `content` in the delta object we stream "tokens". - match chunk.choices[0].delta.get("content") { - None => (), - Some(content) => match content.as_str() { - None => (), - Some(s) => { - if s.len() > 0 { - let _ = sender.send(json!({ - "type": "tokens", - "content": { - "text": s, - }, - })); - } - } - }, + let mut text = completion.choices[0].text.clone(); + let mut tokens = match completion.choices[0].logprobs.as_ref() { + Some(l) => Some(l.tokens.clone()), + None => None, + }; + let mut logprobs = match completion.choices[0].logprobs.as_ref() { + Some(l) => Some(l.token_logprobs.clone()), + None => None, + }; + let text_offset = match completion.choices[0].logprobs.as_ref() { + Some(l) => Some(l.text_offset.clone()), + None => None, + }; + if index == 0 && text_offset.is_some() { + let mut token_offset: usize = 0; + for o in text_offset.as_ref().unwrap() { + if *o < prompt_len { + token_offset += 1; + } + } + text = text.chars().skip(prompt_len).collect::(); + tokens = match tokens { + Some(t) => Some(t[token_offset..].to_vec()), + None => None, + }; + logprobs = match logprobs { + Some(l) => Some(l[token_offset..].to_vec()), + None => None, }; + } - // Emit a `function_call` event per tool_call. 
- if let Some(tool_calls) = chunk.choices[0] - .delta - .get("tool_calls") - .and_then(|v| v.as_array()) - { - tool_calls.iter().for_each(|tool_call| { - match tool_call.get("function") { - Some(f) => { - if let Some(Value::String(name)) = f.get("name") - { - let _ = sender.send(json!({ - "type": "function_call", - "content": { - "name": name, - }, - })); - } - } - _ => (), - } - }); - } + if text.len() > 0 { + let _ = sender.send(json!({ + "type": "tokens", + "content": { + "text": text, + "tokens": tokens, + "logprobs": logprobs, + }, + })); } } None => (), }; - - if !chunk.choices.is_empty() { - chunks.lock().push(chunk); - } + completions.lock().push(completion); } }, None => { @@ -1292,7 +306,7 @@ pub async fn streamed_chat_completion( match error.retryable_streamed(status) { true => Err(ModelError { request_id, - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 2, @@ -1301,7 +315,7 @@ pub async fn streamed_chat_completion( }), false => Err(ModelError { request_id, - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: None, }), } @@ -1324,200 +338,87 @@ pub async fn streamed_chat_completion( } } - let mut completion = { - let guard = chunks.lock(); - let f = match guard.len() { - 0 => Err(anyhow!("No chunks received from OpenAI")), + let completion = { + let mut guard = completions.lock(); + let mut c = match guard.len() { + 0 => Err(anyhow!("No completions received from OpenAI")), _ => Ok(guard[0].clone()), }?; - - // merge logprobs from all choices of all chunks - let logprobs: Vec = guard - .iter() - .flat_map(|chunk| { - chunk - .choices - .iter() - .filter_map(|choice| choice.logprobs.as_ref().map(|lp| lp.content.clone())) - }) - .flatten() - .collect(); - - let mut c = OpenAIChatCompletion { - id: f.id.clone(), - object: f.object.clone(), - created: f.created, - choices: f - .choices - .iter() - .map(|c| OpenAIChatChoice { - message: OpenAICompletionChatMessage { - content: Some("".to_string()), - name: None, - role: OpenAIChatMessageRole::System, - tool_calls: None, - tool_call_id: None, - }, - index: c.index, - finish_reason: None, - logprobs: match logprobs.len() { - 0 => None, - _ => Some(OpenAIChatChoiceLogprobs { - content: logprobs.clone(), - }), - }, - }) - .collect::>(), - usage, - }; - + guard.remove(0); for i in 0..guard.len() { let a = guard[i].clone(); - if a.choices.len() != f.choices.len() { - Err(anyhow!("Inconsistent number of choices in streamed chunks"))?; + if a.choices.len() != c.choices.len() { + Err(anyhow!( + "Inconsistent number of choices in streamed completions" + ))?; } - for j in 0..a.choices.len() { - match a.choices.get(j).unwrap().finish_reason.clone() { - None => (), - Some(f) => c.choices[j].finish_reason = Some(f), - }; - - match a.choices[j].delta.get("role") { - None => (), - Some(role) => match role.as_str() { - None => (), - Some(r) => { - c.choices[j].message.role = OpenAIChatMessageRole::from_str(r)?; - } - }, - }; + for j in 0..c.choices.len() { + c.choices[j].finish_reason = a.choices.get(j).unwrap().finish_reason.clone(); + // OpenAI does the bytes merging for us <3. 
+ c.choices[j].text = format!("{}{}", c.choices[j].text, a.choices[j].text); - match a.choices[j].delta.get("content") { - None => (), - Some(content) => match content.as_str() { - None => (), - Some(s) => { - c.choices[j].message.content = Some(format!( - "{}{}", - c.choices[j] - .message - .content - .as_ref() - .unwrap_or(&String::new()), - s - )); - } - }, - }; - - if let Some(tool_calls) = a.choices[j] - .delta - .get("tool_calls") - .and_then(|v| v.as_array()) - { - for tool_call in tool_calls { - match ( - tool_call.get("type").and_then(|v| v.as_str()), - tool_call.get("id").and_then(|v| v.as_str()), - tool_call.get("function"), - ) { - (Some("function"), Some(id), Some(f)) => { - if let Some(Value::String(name)) = f.get("name") { - c.choices[j] - .message - .tool_calls - .get_or_insert_with(Vec::new) - .push(OpenAIToolCall { - id: Some(id.to_string()), - r#type: OpenAIToolType::Function, - function: OpenAIToolCallFunction { - name: name.clone(), - arguments: String::new(), - }, - }); - } - } - (None, None, Some(f)) => { - if let (Some(Value::Number(idx)), Some(Value::String(a))) = - (tool_call.get("index"), f.get("arguments")) - { - let index: usize = idx - .as_u64() - .ok_or_else(|| anyhow!("Missing index for tools"))? - .try_into() - .map_err(|e| { - anyhow!("Invalid index value for tools: {:?}", e) - })?; - - let tool_calls = c.choices[j] - .message - .tool_calls - .as_mut() - .ok_or(anyhow!("Missing tool calls"))?; - - if index >= tool_calls.len() { - return Err(anyhow!( - "Index out-of-bound for tool_calls: {}", - index - )); + match c.choices[j].logprobs.as_mut() { + Some(c_logprobs) => match a.choices[j].logprobs.as_ref() { + Some(a_logprobs) => { + c_logprobs.tokens.extend(a_logprobs.tokens.clone()); + c_logprobs + .token_logprobs + .extend(a_logprobs.token_logprobs.clone()); + c_logprobs + .text_offset + .extend(a_logprobs.text_offset.clone()); + match c_logprobs.top_logprobs.as_mut() { + Some(c_top_logprobs) => match a_logprobs.top_logprobs.as_ref() { + Some(a_top_logprobs) => { + c_top_logprobs.extend(a_top_logprobs.clone()); } - - tool_calls[index].function.arguments += a; - } + None => (), + }, + None => (), } - _ => (), } - } + None => (), + }, + None => (), } } } c }; - // for all messages, edit the content and strip leading and trailing spaces and \n - for m in completion.choices.iter_mut() { - m.message.content = match m.message.content.as_ref() { - None => None, - Some(c) => Some(c.trim().to_string()), - }; - } - Ok((completion, request_id)) } -pub async fn chat_completion( +pub async fn completion( uri: Uri, api_key: String, organization_id: Option, model_id: Option, - messages: &Vec, - tools: Vec, - tool_choice: Option, + prompt: &str, + max_tokens: Option, temperature: f32, - top_p: f32, n: usize, + logprobs: Option, + echo: bool, stop: &Vec, - max_tokens: Option, - presence_penalty: f32, frequency_penalty: f32, - response_format: Option, - reasoning_effort: Option, - logprobs: Option, - top_logprobs: Option, + presence_penalty: f32, + top_p: f32, user: Option, -) -> Result<(OpenAIChatCompletion, Option)> { +) -> Result<(Completion, Option)> { let mut body = json!({ - "messages": messages, + "prompt": prompt, "temperature": temperature, - "top_p": top_p, "n": n, - "presence_penalty": presence_penalty, + "logprobs": logprobs, "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "top_p": top_p, }); if user.is_some() { body["user"] = json!(user); } - if let Some(model_id) = model_id { + if model_id.is_some() { body["model"] = 
json!(model_id); } if let Some(mt) = max_tokens { @@ -1527,43 +428,28 @@ pub async fn chat_completion( body["stop"] = json!(stop); } - if let Some(response_format) = response_format { - body["response_format"] = json!({ - "type": response_format, - }); - } - if tools.len() > 0 { - body["tools"] = json!(tools); - } - if let Some(tool_choice) = tool_choice { - body["tool_choice"] = json!(tool_choice); - } - if let Some(reasoning_effort) = reasoning_effort { - body["reasoning_effort"] = json!(reasoning_effort); - } - if let Some(logprobs) = logprobs { - body["logprobs"] = json!(logprobs); - } - if let Some(top_logprobs) = top_logprobs { - body["top_logprobs"] = json!(top_logprobs); - } + match model_id { + None => (), + Some(model_id) => { + body["model"] = json!(model_id); + // `gpt-3.5-turbo-instruct` does not support `echo` + if !model_id.starts_with("gpt-3.5-turbo-instruct") { + body["echo"] = json!(echo); + } + } + }; let mut req = reqwest::Client::new() .post(uri.to_string()) .header("Content-Type", "application/json") - // This one is for `openai`. .header("Authorization", format!("Bearer {}", api_key.clone())) - // This one is for `azure_openai`. .header("api-key", api_key.clone()); if let Some(organization_id) = organization_id { - req = req.header( - "OpenAI-Organization", - &format!("{}", organization_id.clone()), - ); + req = req.header("OpenAI-Organization", organization_id); } - let req = req.json(&body); + req = req.json(&body); let res = match timeout(Duration::new(180, 0), req.send()).await { Ok(Ok(res)) => res, @@ -1587,14 +473,14 @@ pub async fn chat_completion( body.reader().read_to_end(&mut b)?; let c: &[u8] = &b; - let mut completion: OpenAIChatCompletion = match serde_json::from_slice(c) { + let completion: Completion = match serde_json::from_slice(c) { Ok(c) => Ok(c), Err(_) => { let error: OpenAIError = serde_json::from_slice(c)?; match error.retryable() { true => Err(ModelError { request_id: request_id.clone(), - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 2, @@ -1603,7 +489,7 @@ pub async fn chat_completion( }), false => Err(ModelError { request_id: request_id.clone(), - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 1, @@ -1614,45 +500,9 @@ pub async fn chat_completion( } }?; - // for all messages, edit the content and strip leading and trailing spaces and \n - for m in completion.choices.iter_mut() { - m.message.content = match m.message.content.as_ref() { - None => None, - Some(c) => Some(c.trim().to_string()), - }; - } - Ok((completion, request_id)) } -pub fn logprobs_from_choices(choices: &Vec) -> Option> { - let lp: Vec = choices - .iter() - .filter_map(|choice| choice.logprobs.as_ref()) - .flat_map(|lp| { - lp.content.iter().map(|content_logprob| LLMChatLogprob { - token: content_logprob.token.clone(), - logprob: content_logprob.logprob, - top_logprobs: match content_logprob.top_logprobs.len() { - 0 => None, - _ => Some( - content_logprob - .top_logprobs - .iter() - .map(|top_logprob| top_logprob.clone().into()) - .collect(), - ), - }, - }) - }) - .collect(); - - match lp.len() { - 0 => None, - _ => Some(lp), - } -} - /// /// Shared streamed/non-streamed chat/completion handling code (used by both OpenAILLM and /// AzureOpenAILLM). 
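The hunks above and below route every OpenAI error through the new `with_provider` helper before formatting. Below is a minimal sketch (not part of the diff) of what that call chain produces, assuming `OpenAIError` and `InnerError` are importable from the new `openai_compatible_helpers` module introduced later in this diff (crate path assumed):

// Illustrative only: shows the provider-prefixed message produced by
// `with_provider(...).message()`, which replaces the old `error.message()`.
use dust::providers::openai_compatible_helpers::{InnerError, OpenAIError};

fn provider_prefixed_message() -> String {
    let error = OpenAIError {
        error: InnerError {
            message: "Rate limit reached for requests".to_string(),
            _type: "requests".to_string(),
            param: None,
            internal_message: None,
        },
    };
    // Wrapping with the provider name yields
    // "OpenAIError: [requests] Rate limit reached for requests", matching the
    // previous hard-coded "OpenAIError: ..." prefix while letting other
    // OpenAI-compatible providers reuse the same formatting with their own name.
    error.with_provider("OpenAI").message()
}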
@@ -1735,7 +585,7 @@ pub async fn embed( match error.retryable() { true => Err(ModelError { request_id, - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 2, @@ -1744,7 +594,7 @@ pub async fn embed( }), false => Err(ModelError { request_id, - message: error.message(), + message: error.with_provider("OpenAI").message(), retryable: Some(ModelErrorRetryOptions { sleep: Duration::from_millis(500), factor: 1, @@ -1822,29 +672,6 @@ impl OpenAILLM { } } -pub fn to_openai_messages( - messages: &Vec, - model_id: &str, -) -> Result, anyhow::Error> { - let mut oai_messages = messages - .iter() - .map(|m| OpenAIChatMessage::try_from(m)) - .collect::>>()? - .into_iter() - // [o1-mini] O1 mini does not support system messages, so we filter them out. - .filter(|m| m.role != OpenAIChatMessageRole::System || !model_id.starts_with("o1-mini")) - .collect::>(); - - // [o1] O1 uses `developer` messages instead of `system` messages. - for m in oai_messages.iter_mut() { - if m.role == OpenAIChatMessageRole::System && model_id.starts_with("o1") { - m.role = OpenAIChatMessageRole::Developer; - } - } - - Ok(oai_messages) -} - #[async_trait] impl LLM for OpenAILLM { fn id(&self) -> String { @@ -2100,7 +927,7 @@ impl LLM for OpenAILLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, + max_tokens: Option, presence_penalty: Option, frequency_penalty: Option, logprobs: Option, @@ -2108,162 +935,40 @@ impl LLM for OpenAILLM { extras: Option, event_sender: Option>, ) -> Result { - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - let (openai_org_id, openai_user, response_format, reasoning_effort) = match &extras { - None => (None, None, None, None), - Some(v) => ( - match v.get("openai_organization_id") { - Some(Value::String(o)) => Some(o.to_string()), - _ => None, - }, - match v.get("openai_user") { - Some(Value::String(u)) => Some(u.to_string()), - _ => None, - }, - match v.get("response_format") { - Some(Value::String(f)) => Some(f.to_string()), - _ => None, - }, - match v.get("reasoning_effort") { - Some(Value::String(r)) => Some(r.to_string()), - _ => None, - }, - ), - }; - - let tool_choice = match function_call.as_ref() { - Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), - None => None, - }; - - let tools = functions - .iter() - .map(OpenAITool::try_from) - .collect::, _>>()?; - - let openai_messages = to_openai_messages(messages, &self.id)?; - - // [o1] Hack for OpenAI `o1*` models to simulate streaming. - let is_streaming = event_sender.is_some(); let model_is_o1 = self.id.as_str().starts_with("o1"); - - let (c, request_id) = if !model_is_o1 && is_streaming { - streamed_chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - openai_org_id, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - // [o1] O1 models do not support custom temperature. - if !model_is_o1 { temperature } else { 1.0 }, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - response_format, - reasoning_effort, - logprobs, - top_logprobs, - openai_user, - event_sender.clone(), - ) - .await? 
- } else { - chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - openai_org_id, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - // [o1] O1 models do not support custom temperature. - if !model_is_o1 { temperature } else { 1.0 }, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - response_format, - reasoning_effort, - logprobs, - top_logprobs, - openai_user, - ) - .await? - }; - - // [o1] Hack for OpenAI `o1*` models to simulate streaming. - if model_is_o1 && is_streaming { - let sender = event_sender.as_ref().unwrap(); - for choice in &c.choices { - if let Some(content) = &choice.message.content { - // Split content into smaller chunks to simulate streaming. - for chunk in content - .chars() - .collect::>() - .chunks(4) - .map(|c| c.iter().collect::()) - { - let _ = sender.send(json!({ - "type": "tokens", - "content": { - "text": chunk, - }, - })); - // Add a small delay to simulate real-time streaming. - tokio::time::sleep(std::time::Duration::from_millis(20)).await; - } - } - } - } - - assert!(c.choices.len() > 0); - - Ok(LLMChatGeneration { - created: utils::now(), - provider: ProviderID::OpenAI.to_string(), - model: self.id.clone(), - completions: c - .choices - .iter() - .map(|c| AssistantChatMessage::try_from(&c.message)) - .collect::>>()?, - usage: c.usage.map(|usage| LLMTokenUsage { - prompt_tokens: usage.prompt_tokens, - completion_tokens: usage.completion_tokens.unwrap_or(0), - }), - provider_request_id: request_id, - logprobs: logprobs_from_choices(&c.choices), - }) + let model_is_o1_mini = self.id.as_str().starts_with("o1-mini"); + openai_compatible_chat_completion( + self.chat_uri()?, + self.id.clone(), + self.api_key.clone().unwrap(), + &messages, + functions, + function_call, + if model_is_o1 { 1.0 } else { temperature }, + top_p, + n, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + logprobs, + top_logprobs, + extras, + event_sender, + model_is_o1, // disable provider streaming if model is o1. + // If model is o1-mini, we remove system messages. + // If model is o1, we replace system messages with developer messages. + if model_is_o1_mini { + TransformSystemMessages::Remove + } else if model_is_o1 { + TransformSystemMessages::ReplaceWithDeveloper + } else { + TransformSystemMessages::Keep + }, + "OpenAI".to_string(), + false, // Don't squash text contents. + ) + .await } } @@ -2462,33 +1167,6 @@ impl Provider for OpenAIProvider { ) .await?; - // let mut embedder = self.embedder(String::from("text-embedding-ada-002")); - // embedder.initialize(Credentials::new()).await?; - - // let _v = embedder.embed("Hello 😊", None).await?; - // println!("EMBEDDING SIZE: {}", v.vector.len()); - - // llm = self.llm(String::from("gpt-3.5-turbo")); - // llm.initialize(Credentials::new()).await?; - - // let messages = vec![ - // // ChatMessage { - // // role: String::from("system"), - // // content: String::from( - // // "You're a an assistant. Answer as concisely and precisely as possible.", - // // ), - // // }, - // ChatMessage { - // role: String::from("user"), - // content: String::from("How can I calculate the area of a circle?"), - // }, - // ]; - - // let c = llm - // .chat(&messages, 0.7, None, 1, &vec![], None, None, None, None) - // .await?; - // println!("CHAT COMPLETION SIZE: {:?}", c); - utils::done("Test successfully completed! 
OpenAI is ready to use."); Ok(()) diff --git a/core/src/providers/openai_compatible_helpers.rs b/core/src/providers/openai_compatible_helpers.rs new file mode 100644 index 000000000000..f80ff001394d --- /dev/null +++ b/core/src/providers/openai_compatible_helpers.rs @@ -0,0 +1,1531 @@ +use std::{collections::HashMap, str::FromStr, time::Duration}; + +use crate::{ + providers::{llm::LLMTokenUsage, provider::ProviderID}, + utils::{self, ParseError}, +}; +use anyhow::{anyhow, Result}; +use eventsource_client as es; +use eventsource_client::Client as ESClient; +use futures::TryStreamExt; +use http::StatusCode; +use hyper::{body::Buf, Uri}; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::io::prelude::*; +use std::sync::Arc; +use tokio::sync::mpsc::UnboundedSender; +use tokio::time::timeout; + +use super::{ + chat_messages::{AssistantChatMessage, ChatMessage, ContentBlock, MixedContent}, + llm::{ + ChatFunction, ChatFunctionCall, ChatMessageRole, LLMChatGeneration, LLMChatLogprob, + TopLogprob, + }, + provider::{ModelError, ModelErrorRetryOptions}, +}; + +// Input types. + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum OpenAIToolType { + Function, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIFunction { + name: String, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIFunctionCall { + r#type: OpenAIToolType, + function: OpenAIFunction, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum OpenAIToolControl { + Auto, + Required, + None, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(untagged)] +pub enum OpenAIToolChoice { + OpenAIToolControl(OpenAIToolControl), + OpenAIFunctionCall(OpenAIFunctionCall), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIToolFunction { + pub name: String, + pub description: Option, + pub parameters: Option, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAITool { + pub r#type: OpenAIToolType, + pub function: OpenAIToolFunction, +} + +impl FromStr for OpenAIToolChoice { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + match s { + "auto" => Ok(OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::Auto)), + "any" => Ok(OpenAIToolChoice::OpenAIToolControl( + OpenAIToolControl::Required, + )), + "none" => Ok(OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::None)), + _ => { + let function = OpenAIFunctionCall { + r#type: OpenAIToolType::Function, + function: OpenAIFunction { + name: s.to_string(), + }, + }; + Ok(OpenAIToolChoice::OpenAIFunctionCall(function)) + } + } + } +} + +// Outputs types. 
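For context, a small sketch (not part of the diff) of how the `FromStr` impl just above maps the user-facing `function_call` string onto `tool_choice` values; import paths are assumed:

use std::str::FromStr;

use dust::providers::openai_compatible_helpers::{OpenAIToolChoice, OpenAIToolControl};

fn tool_choice_examples() -> anyhow::Result<()> {
    // "auto", "any" and "none" map to the control variants ("any" becomes Required).
    assert_eq!(
        OpenAIToolChoice::from_str("any")?,
        OpenAIToolChoice::OpenAIToolControl(OpenAIToolControl::Required)
    );
    // Any other string selects that specific function by name; serializing it
    // shows the JSON shape sent as `tool_choice` in the request body.
    let named = OpenAIToolChoice::from_str("my_function")?;
    assert_eq!(
        serde_json::to_value(&named)?,
        serde_json::json!({ "type": "function", "function": { "name": "my_function" } })
    );
    Ok(())
}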
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct Usage {
+    pub prompt_tokens: u64,
+    pub completion_tokens: Option<u64>,
+    pub total_tokens: u64,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct OpenAIToolCallFunction {
+    name: String,
+    arguments: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct OpenAIToolCall {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    id: Option<String>,
+    r#type: OpenAIToolType,
+    pub function: OpenAIToolCallFunction,
+}
+
+impl TryFrom<&ChatFunctionCall> for OpenAIToolCall {
+    type Error = anyhow::Error;
+
+    fn try_from(cf: &ChatFunctionCall) -> Result<Self, Self::Error> {
+        Ok(OpenAIToolCall {
+            id: Some(cf.id.clone()),
+            r#type: OpenAIToolType::Function,
+            function: OpenAIToolCallFunction {
+                name: cf.name.clone(),
+                arguments: cf.arguments.clone(),
+            },
+        })
+    }
+}
+
+impl TryFrom<&OpenAIToolCall> for ChatFunctionCall {
+    type Error = anyhow::Error;
+
+    fn try_from(tc: &OpenAIToolCall) -> Result<Self, Self::Error> {
+        // Some providers don't provide a function call ID (eg google_ai_studio)
+        let id = tc
+            .id
+            .clone()
+            .unwrap_or(format!("fc_{}", utils::new_id()[0..9].to_string()));
+
+        Ok(ChatFunctionCall {
+            id,
+            name: tc.function.name.clone(),
+            arguments: tc.function.arguments.clone(),
+        })
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum OpenAIChatMessageRole {
+    Assistant,
+    Function,
+    System,
+    Developer,
+    Tool,
+    User,
+}
+
+impl From<&ChatMessageRole> for OpenAIChatMessageRole {
+    fn from(role: &ChatMessageRole) -> Self {
+        match role {
+            ChatMessageRole::Assistant => OpenAIChatMessageRole::Assistant,
+            ChatMessageRole::Function => OpenAIChatMessageRole::Function,
+            ChatMessageRole::System => OpenAIChatMessageRole::System,
+            ChatMessageRole::User => OpenAIChatMessageRole::User,
+        }
+    }
+}
+
+impl FromStr for OpenAIChatMessageRole {
+    type Err = ParseError;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "system" => Ok(OpenAIChatMessageRole::System),
+            "user" => Ok(OpenAIChatMessageRole::User),
+            "assistant" => Ok(OpenAIChatMessageRole::Assistant),
+            "function" => Ok(OpenAIChatMessageRole::Tool),
+            _ => Err(ParseError::with_message("Unknown OpenAIChatMessageRole"))?,
+        }
+    }
+}
+
+impl From<OpenAIChatMessageRole> for ChatMessageRole {
+    fn from(value: OpenAIChatMessageRole) -> Self {
+        match value {
+            OpenAIChatMessageRole::Assistant => ChatMessageRole::Assistant,
+            OpenAIChatMessageRole::Function => ChatMessageRole::Function,
+            OpenAIChatMessageRole::System => ChatMessageRole::System,
+            OpenAIChatMessageRole::Developer => ChatMessageRole::System,
+            OpenAIChatMessageRole::Tool => ChatMessageRole::Function,
+            OpenAIChatMessageRole::User => ChatMessageRole::User,
+        }
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum OpenAITextContentType {
+    Text,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct OpenAITextContent {
+    #[serde(rename = "type")]
+    pub r#type: OpenAITextContentType,
+    pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct OpenAIImageUrlContent {
+    pub url: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum OpenAIImageContentType {
+    ImageUrl,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct OpenAIImageContent {
+    pub r#type: OpenAIImageContentType,
+    pub image_url: OpenAIImageUrlContent,
+}
+
+// Define an enum for mixed content
+#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
+#[serde(untagged)]
+pub enum OpenAIContentBlock {
+    TextContent(OpenAITextContent),
+    ImageContent(OpenAIImageContent),
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(untagged)]
+pub enum OpenAIChatMessageContent {
+    Structured(Vec<OpenAIContentBlock>),
+    String(String),
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct OpenAIChatMessage {
+    pub role: OpenAIChatMessageRole,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<OpenAIChatMessageContent>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<OpenAIToolCall>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct OpenAICompletionChatMessage {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    pub role: OpenAIChatMessageRole,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<OpenAIToolCall>>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct OpenAITopLogprob {
+    pub token: String,
+    pub logprob: f32,
+    pub bytes: Option<Vec<u8>>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct OpenAIChatChoiceLogprob {
+    pub token: String,
+    pub logprob: f32,
+    pub bytes: Option<Vec<u8>>,
+    pub top_logprobs: Vec<OpenAITopLogprob>,
+}
+
+impl From<OpenAITopLogprob> for TopLogprob {
+    fn from(top_logprob: OpenAITopLogprob) -> Self {
+        TopLogprob {
+            token: top_logprob.token,
+            logprob: top_logprob.logprob,
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct OpenAIChatChoiceLogprobs {
+    pub content: Vec<OpenAIChatChoiceLogprob>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct OpenAIChatChoice {
+    pub message: OpenAICompletionChatMessage,
+    pub index: usize,
+    pub finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<OpenAIChatChoiceLogprobs>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct OpenAIChatCompletion {
+    pub id: Option<String>,
+    pub object: String,
+    pub created: u64,
+    pub choices: Vec<OpenAIChatChoice>,
+    pub usage: Option<Usage>,
+}
+
+// This code performs a type conversion with information loss when converting to ChatFunctionCall.
+// It only supports one tool call, so it takes the first one from the vector of OpenAIToolCall,
+// hence potentially discarding other tool calls.
+impl TryFrom<&OpenAICompletionChatMessage> for AssistantChatMessage { + type Error = anyhow::Error; + + fn try_from(cm: &OpenAICompletionChatMessage) -> Result { + let role = ChatMessageRole::from(cm.role.clone()); + let content = match cm.content.as_ref() { + Some(c) => Some(c.clone()), + None => None, + }; + + let function_calls = if let Some(tool_calls) = cm.tool_calls.as_ref() { + let cfc = tool_calls + .into_iter() + .map(|tc| ChatFunctionCall::try_from(tc)) + .collect::, _>>()?; + + Some(cfc) + } else { + None + }; + + let function_call = if let Some(fcs) = function_calls.as_ref() { + match fcs.first() { + Some(fc) => Some(fc), + None => None, + } + .cloned() + } else { + None + }; + + let name = match cm.name.as_ref() { + Some(c) => Some(c.clone()), + None => None, + }; + + Ok(AssistantChatMessage { + content, + role, + name, + function_call, + function_calls, + }) + } +} + +impl TryFrom<&ContentBlock> for OpenAIChatMessageContent { + type Error = anyhow::Error; + + fn try_from(cm: &ContentBlock) -> Result { + match cm { + ContentBlock::Text(t) => Ok(OpenAIChatMessageContent::Structured(vec![ + OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: t.clone(), + }), + ])), + ContentBlock::Mixed(m) => { + let content: Vec = m + .into_iter() + .map(|mb| match mb { + MixedContent::TextContent(tc) => { + Ok(OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: tc.text.clone(), + })) + } + MixedContent::ImageContent(ic) => { + Ok(OpenAIContentBlock::ImageContent(OpenAIImageContent { + r#type: OpenAIImageContentType::ImageUrl, + image_url: OpenAIImageUrlContent { + url: ic.image_url.url.clone(), + }, + })) + } + }) + .collect::>>()?; + + Ok(OpenAIChatMessageContent::Structured(content)) + } + } + } +} + +impl TryFrom<&String> for OpenAIChatMessageContent { + type Error = anyhow::Error; + + fn try_from(t: &String) -> Result { + Ok(OpenAIChatMessageContent::Structured(vec![ + OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: t.clone(), + }), + ])) + } +} + +impl TryFrom<&ChatMessage> for OpenAIChatMessage { + type Error = anyhow::Error; + + fn try_from(cm: &ChatMessage) -> Result { + match cm { + ChatMessage::Assistant(assistant_msg) => Ok(OpenAIChatMessage { + content: match &assistant_msg.content { + Some(c) => Some(OpenAIChatMessageContent::try_from(c)?), + None => None, + }, + name: assistant_msg.name.clone(), + role: OpenAIChatMessageRole::from(&assistant_msg.role), + tool_calls: match assistant_msg.function_calls.as_ref() { + Some(fc) => Some( + fc.into_iter() + .map(|f| OpenAIToolCall::try_from(f)) + .collect::, _>>()?, + ), + None => None, + }, + tool_call_id: None, + }), + ChatMessage::Function(function_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIChatMessageContent::try_from(&function_msg.content)?), + name: None, + role: OpenAIChatMessageRole::Tool, + tool_calls: None, + tool_call_id: Some(function_msg.function_call_id.clone()), + }), + ChatMessage::System(system_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIChatMessageContent::try_from(&system_msg.content)?), + name: None, + role: OpenAIChatMessageRole::from(&system_msg.role), + tool_calls: None, + tool_call_id: None, + }), + ChatMessage::User(user_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIChatMessageContent::try_from(&user_msg.content)?), + name: user_msg.name.clone(), + role: OpenAIChatMessageRole::from(&user_msg.role), + tool_calls: None, + tool_call_id: None, + 
}), + } + } +} + +impl TryFrom<&ChatFunction> for OpenAITool { + type Error = anyhow::Error; + + fn try_from(f: &ChatFunction) -> Result { + Ok(OpenAITool { + r#type: OpenAIToolType::Function, + function: OpenAIToolFunction { + name: f.name.clone(), + description: f.description.clone(), + parameters: f.parameters.clone(), + }, + }) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatDelta { + pub delta: Value, + pub index: usize, + pub finish_reason: Option, + pub logprobs: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatChunk { + pub id: Option, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct InnerError { + pub message: String, + #[serde(alias = "type")] + pub _type: String, + pub param: Option, + pub internal_message: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct OpenAIError { + pub error: InnerError, +} + +pub struct OpenAICompatibleError { + pub provider: String, + pub error: OpenAIError, +} + +impl OpenAIError { + pub fn with_provider(self, provider: &str) -> OpenAICompatibleError { + OpenAICompatibleError { + provider: provider.to_string(), + error: self, + } + } + + pub fn retryable(&self) -> bool { + match self.error._type.as_str() { + "requests" => true, + "server_error" => match &self.error.internal_message { + Some(message) if message.contains("retry") => true, + _ => false, + }, + _ => false, + } + } + + pub fn retryable_streamed(&self, status: StatusCode) -> bool { + if status == StatusCode::TOO_MANY_REQUESTS { + return true; + } + if status.is_server_error() { + return true; + } + match self.error._type.as_str() { + "server_error" => match self.error.internal_message { + Some(_) => true, + None => false, + }, + _ => false, + } + } +} + +impl OpenAICompatibleError { + pub fn message(&self) -> String { + match &self.error.error.internal_message { + Some(ref msg) => format!( + "{}Error: [{}] {} internal_message={}", + self.provider, self.error.error._type, self.error.error.message, msg, + ), + None => format!( + "{}Error: [{}] {}", + self.provider, self.error.error._type, self.error.error.message, + ), + } + } + + pub fn retryable(&self) -> bool { + self.error.retryable() + } + + pub fn retryable_streamed(&self, status: StatusCode) -> bool { + self.error.retryable_streamed(status) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TransformSystemMessages { + Remove, + ReplaceWithDeveloper, + Keep, +} + +pub async fn openai_compatible_chat_completion( + uri: Uri, + model_id: String, + api_key: String, + messages: &Vec, + functions: &Vec, + function_call: Option, + temperature: f32, + top_p: Option, + n: usize, + stop: &Vec, + mut max_tokens: Option, + presence_penalty: Option, + frequency_penalty: Option, + logprobs: Option, + top_logprobs: Option, + openai_extras: Option, + event_sender: Option>, + disable_provider_streaming: bool, + transform_system_messages: TransformSystemMessages, + provider_name: String, + squash_text_contents: bool, +) -> Result { + if let Some(m) = max_tokens { + if m == -1 { + max_tokens = None; + } + } + + let (openai_org_id, openai_user, response_format, reasoning_effort) = match &openai_extras { + None => (None, None, None, None), + Some(v) => ( + match v.get("openai_organization_id") { + Some(Value::String(o)) => Some(o.to_string()), + _ => None, + }, + match v.get("openai_user") { + Some(Value::String(u)) => Some(u.to_string()), + _ => 
None, + }, + match v.get("response_format") { + Some(Value::String(f)) => Some(f.to_string()), + _ => None, + }, + match v.get("reasoning_effort") { + Some(Value::String(r)) => Some(r.to_string()), + _ => None, + }, + ), + }; + + let tool_choice = match function_call.as_ref() { + Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), + None => None, + }; + + let tools = functions + .iter() + .map(OpenAITool::try_from) + .collect::, _>>()?; + + let openai_messages = + to_openai_messages(messages, transform_system_messages, squash_text_contents)?; + + let stream_output = event_sender.is_some(); + + let (c, request_id) = if !disable_provider_streaming && stream_output { + streamed_chat_completion( + uri, + api_key, + openai_org_id, + Some(model_id.clone()), + &openai_messages, + tools, + tool_choice, + temperature, + match top_p { + Some(t) => t, + None => 1.0, + }, + n, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + response_format, + reasoning_effort, + logprobs, + top_logprobs, + openai_user, + event_sender.clone(), + provider_name, + ) + .await? + } else { + chat_completion( + uri, + api_key, + openai_org_id, + Some(model_id.clone()), + &openai_messages, + tools, + tool_choice, + temperature, + match top_p { + Some(t) => t, + None => 1.0, + }, + n, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + response_format, + reasoning_effort, + logprobs, + top_logprobs, + openai_user, + provider_name, + ) + .await? + }; + + // We support streaming the output in the event sender, even if we're not + // using streaming from the provider. + if stream_output && disable_provider_streaming { + let sender = event_sender.as_ref().unwrap(); + for choice in &c.choices { + if let Some(content) = &choice.message.content { + // Split content into smaller chunks to simulate streaming. + for chunk in content + .chars() + .collect::>() + .chunks(4) + .map(|c| c.iter().collect::()) + { + let _ = sender.send(json!({ + "type": "tokens", + "content": { + "text": chunk, + }, + })); + // Add a small delay to simulate real-time streaming. + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + } + } + } + + assert!(c.choices.len() > 0); + + Ok(LLMChatGeneration { + created: utils::now(), + provider: ProviderID::OpenAI.to_string(), + model: model_id.clone(), + completions: c + .choices + .iter() + .map(|c| AssistantChatMessage::try_from(&c.message)) + .collect::>>()?, + usage: c.usage.map(|usage| LLMTokenUsage { + prompt_tokens: usage.prompt_tokens, + completion_tokens: usage.completion_tokens.unwrap_or(0), + }), + provider_request_id: request_id, + logprobs: logprobs_from_choices(&c.choices), + }) +} + +fn to_openai_messages( + messages: &Vec, + transform_system_messages: TransformSystemMessages, + squash_text_contents: bool, +) -> Result, anyhow::Error> { + let mut oai_messages = messages + .into_iter() + // First convert to OpenAI chat messages. + .map(|m| OpenAIChatMessage::try_from(m)) + .collect::>>()? + .into_iter() + // Decide which content format to use for each message (structured or string). + // If there are images, we need to use structured format. + // If there is a single text content, we can always use string format (equivalent and compatible everywhere). + // Otherwise, if there are multiple text contents, we either squash them or keep the structured format, + // depending on the `squash_text_contents` flag. 
+ .map(|m| match m.content { + None => m, + Some(OpenAIChatMessageContent::String(_)) => m, + Some(OpenAIChatMessageContent::Structured(contents)) => { + let all_contents_are_text = contents + .iter() + .all(|c| matches!(c, OpenAIContentBlock::TextContent(_))); + + OpenAIChatMessage { + role: m.role, + name: m.name, + tool_call_id: m.tool_call_id, + tool_calls: m.tool_calls, + content: match (contents.len(), all_contents_are_text, squash_text_contents) { + // Case 0: there's no content => return None + (0, _, _) => None, + // Case 1: there's only a single text content => use string format + (1, true, _) => Some(OpenAIChatMessageContent::String( + match contents.into_iter().next().unwrap() { + OpenAIContentBlock::TextContent(tc) => tc.text.clone(), + _ => unreachable!(), + }, + )), + // Case 2: There's more than one content, all contents are text and we want to squash them => squash them + (_, true, true) => Some(OpenAIChatMessageContent::String( + contents + .into_iter() + .map(|c| match c { + OpenAIContentBlock::TextContent(tc) => tc.text.clone(), + _ => unreachable!(), + }) + .collect::>() + .join("\n"), + )), + // Case 3: there's more than one content, the content isn't text or we don't want to squash them => keep structured format + (_, _, _) => Some(OpenAIChatMessageContent::Structured(contents)), + }, + } + } + }) + // Truncate the tool_ids to 40 characters. + .map(|m| { + fn truncate_id(id: Option) -> Option { + id.map(|id| id.chars().take(40).collect::()) + } + OpenAIChatMessage { + role: m.role, + name: m.name, + tool_call_id: truncate_id(m.tool_call_id), + tool_calls: m.tool_calls.map(|tool_calls| { + tool_calls + .into_iter() + .map(|tc| OpenAIToolCall { + id: truncate_id(tc.id), + function: tc.function, + r#type: tc.r#type, + }) + .collect() + }), + content: m.content, + } + }) + // Remove system messages if requested. + // Some models don't support system messages, so we need this to be + // configurable. + .filter(|m| { + transform_system_messages != TransformSystemMessages::Remove + || m.role != OpenAIChatMessageRole::System + }) + .collect::>(); + + // Replace system messages with developer messages if requested. + // Some newer models no longer support "system" as a role and support a "developer" role. 
+ for m in oai_messages.iter_mut() { + if m.role == OpenAIChatMessageRole::System + && transform_system_messages == TransformSystemMessages::ReplaceWithDeveloper + { + m.role = OpenAIChatMessageRole::Developer; + } + } + + Ok(oai_messages) +} + +async fn streamed_chat_completion( + uri: Uri, + api_key: String, + organization_id: Option, + model_id: Option, + messages: &Vec, + tools: Vec, + tool_choice: Option, + temperature: f32, + top_p: f32, + n: usize, + stop: &Vec, + max_tokens: Option, + presence_penalty: Option, + frequency_penalty: Option, + response_format: Option, + reasoning_effort: Option, + logprobs: Option, + top_logprobs: Option, + user: Option, + event_sender: Option>, + provider_name: String, +) -> Result<(OpenAIChatCompletion, Option)> { + let url = uri.to_string(); + + let mut builder = match es::ClientBuilder::for_url(url.as_str()) { + Ok(b) => b, + Err(_) => { + return Err(anyhow!(format!( + "Error creating streamed client to {}", + provider_name + ))) + } + }; + builder = match builder.method(String::from("POST")).header( + "Authorization", + format!("Bearer {}", api_key.clone()).as_str(), + ) { + Ok(b) => b, + Err(_) => { + return Err(anyhow!(format!( + "Error creating streamed client to {}", + provider_name + ))) + } + }; + builder = match builder.header("Content-Type", "application/json") { + Ok(b) => b, + Err(_) => { + return Err(anyhow!(format!( + "Error creating streamed client to {}", + provider_name + ))) + } + }; + builder = match builder.header("api-key", api_key.clone().as_str()) { + Ok(b) => b, + Err(_) => { + return Err(anyhow!(format!( + "Error creating streamed client to {}", + provider_name + ))) + } + }; + + if let Some(org_id) = organization_id { + builder = builder + .header("OpenAI-Organization", org_id.as_str()) + .map_err(|_| { + anyhow!(format!( + "Error creating streamed client to {}", + provider_name + )) + })?; + } + + let mut body = json!({ + "messages": messages, + "temperature": temperature, + "top_p": top_p, + "n": n, + "stream": true, + "stream_options": HashMap::from([("include_usage", true)]), + }); + if let Some(presence_penalty) = presence_penalty { + body["presence_penalty"] = json!(presence_penalty); + } + if let Some(frequency_penalty) = frequency_penalty { + body["frequency_penalty"] = json!(frequency_penalty); + } + if user.is_some() { + body["user"] = json!(user); + } + if model_id.is_some() { + body["model"] = json!(model_id); + } + if let Some(mt) = max_tokens { + body["max_tokens"] = mt.into(); + } + if !stop.is_empty() { + body["stop"] = json!(stop); + } + + if tools.len() > 0 { + body["tools"] = json!(tools); + } + if let Some(tool_choice) = tool_choice { + body["tool_choice"] = json!(tool_choice); + } + if let Some(response_format) = response_format { + body["response_format"] = json!({ + "type": response_format, + }); + } + if let Some(reasoning_effort) = reasoning_effort { + body["reasoning_effort"] = json!(reasoning_effort); + } + if let Some(logprobs) = logprobs { + body["logprobs"] = json!(logprobs); + } + if let Some(top_logprobs) = top_logprobs { + body["top_logprobs"] = json!(top_logprobs); + } + + let client = builder + .body(body.to_string()) + .reconnect( + es::ReconnectOptions::reconnect(true) + .retry_initial(false) + .delay(Duration::from_secs(1)) + .backoff_factor(2) + .delay_max(Duration::from_secs(8)) + .build(), + ) + .build(); + + let mut stream = client.stream(); + + let chunks: Arc>> = Arc::new(Mutex::new(Vec::new())); + let mut usage = None; + let mut request_id: Option = None; + + 'stream: loop { + 
match stream.try_next().await { + Ok(e) => match e { + Some(es::SSE::Connected((_, headers))) => { + request_id = match headers.get("x-request-id") { + Some(v) => Some(v.to_string()), + None => None, + }; + } + Some(es::SSE::Comment(_)) => { + println!("UNEXPECTED COMMENT"); + } + Some(es::SSE::Event(e)) => match e.data.as_str() { + "[DONE]" => { + break 'stream; + } + _ => { + let index = { + let guard = chunks.lock(); + guard.len() + }; + + let chunk: ChatChunk = match serde_json::from_str(e.data.as_str()) { + Ok(c) => c, + Err(err) => { + let error: Result = + serde_json::from_str(e.data.as_str()); + match error { + Ok(error) => { + match error.retryable_streamed(StatusCode::OK) && index == 0 + { + true => Err(ModelError { + request_id: request_id.clone(), + message: error + .with_provider(&provider_name) + .message(), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(500), + factor: 2, + retries: 3, + }), + })?, + false => Err(ModelError { + request_id: request_id.clone(), + message: error + .with_provider(&provider_name) + .message(), + retryable: None, + })?, + } + break 'stream; + } + Err(_) => { + Err(anyhow!(format!( + "{}Error: failed parsing streamed \ + completion from {} err={} data={}", + provider_name, + provider_name, + err, + e.data.as_str(), + )))?; + break 'stream; + } + } + } + }; + + // Store usage + match &chunk.usage { + Some(received_usage) => { + usage = Some(received_usage.clone()); + } + None => (), + }; + + // Only stream if choices is length 1 but should always be the case. + match event_sender.as_ref() { + Some(sender) => { + if chunk.choices.len() == 1 { + // we ignore the role for generating events + + // If we get `content` in the delta object we stream "tokens". + match chunk.choices[0].delta.get("content") { + None => (), + Some(content) => match content.as_str() { + None => (), + Some(s) => { + if s.len() > 0 { + let _ = sender.send(json!({ + "type": "tokens", + "content": { + "text": s, + }, + })); + } + } + }, + }; + + // Emit a `function_call` event per tool_call. 
+ if let Some(tool_calls) = chunk.choices[0] + .delta + .get("tool_calls") + .and_then(|v| v.as_array()) + { + tool_calls.iter().for_each(|tool_call| { + match tool_call.get("function") { + Some(f) => { + if let Some(Value::String(name)) = f.get("name") + { + let _ = sender.send(json!({ + "type": "function_call", + "content": { + "name": name, + }, + })); + } + } + _ => (), + } + }); + } + } + } + None => (), + }; + + if !chunk.choices.is_empty() { + chunks.lock().push(chunk); + } + } + }, + None => { + println!("UNEXPECTED NONE"); + break 'stream; + } + }, + Err(e) => { + match e { + es::Error::UnexpectedResponse(r) => { + let status = StatusCode::from_u16(r.status())?; + let headers = r.headers()?; + let request_id = match headers.get("x-request-id") { + Some(v) => Some(v.to_string()), + None => None, + }; + let b = r.body_bytes().await?; + + let error: Result = serde_json::from_slice(&b); + match error { + Ok(error) => { + match error.retryable_streamed(status) { + true => Err(ModelError { + request_id, + message: error.with_provider(&provider_name).message(), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(500), + factor: 2, + retries: 3, + }), + }), + false => Err(ModelError { + request_id, + message: error.with_provider(&provider_name).message(), + retryable: None, + }), + } + }?, + Err(_) => { + Err(anyhow!( + "Error streaming tokens from {}: status={} data={}", + &provider_name, + status, + String::from_utf8_lossy(&b) + ))?; + } + } + } + _ => { + Err(anyhow!( + "Error streaming tokens from {}: {:?}", + &provider_name, + e + ))?; + } + } + break 'stream; + } + } + } + + let mut completion = { + let guard = chunks.lock(); + let f = match guard.len() { + 0 => Err(anyhow!("No chunks received from {}", provider_name)), + _ => Ok(guard[0].clone()), + }?; + + // merge logprobs from all choices of all chunks + let logprobs: Vec = guard + .iter() + .flat_map(|chunk| { + chunk + .choices + .iter() + .filter_map(|choice| choice.logprobs.as_ref().map(|lp| lp.content.clone())) + }) + .flatten() + .collect(); + + let mut c = OpenAIChatCompletion { + id: f.id.clone(), + object: f.object.clone(), + created: f.created, + choices: f + .choices + .iter() + .map(|c| OpenAIChatChoice { + message: OpenAICompletionChatMessage { + content: Some("".to_string()), + name: None, + role: OpenAIChatMessageRole::System, + tool_calls: None, + tool_call_id: None, + }, + index: c.index, + finish_reason: None, + logprobs: match logprobs.len() { + 0 => None, + _ => Some(OpenAIChatChoiceLogprobs { + content: logprobs.clone(), + }), + }, + }) + .collect::>(), + usage, + }; + + for i in 0..guard.len() { + let a = guard[i].clone(); + if a.choices.len() != f.choices.len() { + Err(anyhow!("Inconsistent number of choices in streamed chunks"))?; + } + for j in 0..a.choices.len() { + match a.choices.get(j).unwrap().finish_reason.clone() { + None => (), + Some(f) => c.choices[j].finish_reason = Some(f), + }; + + match a.choices[j].delta.get("role") { + None => (), + Some(role) => match role.as_str() { + None => (), + Some(r) => { + c.choices[j].message.role = OpenAIChatMessageRole::from_str(r)?; + } + }, + }; + + match a.choices[j].delta.get("content") { + None => (), + Some(content) => match content.as_str() { + None => (), + Some(s) => { + c.choices[j].message.content = Some(format!( + "{}{}", + c.choices[j] + .message + .content + .as_ref() + .unwrap_or(&String::new()), + s + )); + } + }, + }; + + if let Some(tool_calls) = a.choices[j] + .delta + .get("tool_calls") + .and_then(|v| v.as_array()) + 
{ + for tool_call in tool_calls { + match ( + tool_call.get("type").and_then(|v| v.as_str()), + tool_call.get("id").and_then(|v| v.as_str()), + tool_call.get("function"), + ) { + (Some("function"), id, Some(f)) => { + if let Some(Value::String(name)) = f.get("name") { + c.choices[j] + .message + .tool_calls + .get_or_insert_with(Vec::new) + .push(OpenAIToolCall { + // Set None if id is empty + id: id.filter(|s| !s.is_empty()).map(|s| s.to_string()), + r#type: OpenAIToolType::Function, + function: OpenAIToolCallFunction { + name: name.clone(), + arguments: match f.get("arguments") { + Some(Value::String(a)) => a.clone(), + _ => String::new(), + }, + }, + }); + } + } + (None, None, Some(f)) => { + if let (Some(Value::Number(idx)), Some(Value::String(a))) = + (tool_call.get("index"), f.get("arguments")) + { + let index: usize = idx + .as_u64() + .ok_or_else(|| anyhow!("Missing index for tools"))? + .try_into() + .map_err(|e| { + anyhow!("Invalid index value for tools: {:?}", e) + })?; + + let tool_calls = c.choices[j] + .message + .tool_calls + .as_mut() + .ok_or(anyhow!("Missing tool calls"))?; + + if index >= tool_calls.len() { + return Err(anyhow!( + "Index out-of-bound for tool_calls: {}", + index + )); + } + + tool_calls[index].function.arguments += a; + } + } + _ => (), + } + } + } + } + } + c + }; + + // for all messages, edit the content and strip leading and trailing spaces and \n + for m in completion.choices.iter_mut() { + m.message.content = match m.message.content.as_ref() { + None => None, + Some(c) => Some(c.trim().to_string()), + }; + } + + Ok((completion, request_id)) +} + +async fn chat_completion( + uri: Uri, + api_key: String, + organization_id: Option, + model_id: Option, + messages: &Vec, + tools: Vec, + tool_choice: Option, + temperature: f32, + top_p: f32, + n: usize, + stop: &Vec, + max_tokens: Option, + presence_penalty: Option, + frequency_penalty: Option, + response_format: Option, + reasoning_effort: Option, + logprobs: Option, + top_logprobs: Option, + user: Option, + provider_name: String, +) -> Result<(OpenAIChatCompletion, Option)> { + let mut body = json!({ + "messages": messages, + "temperature": temperature, + "top_p": top_p, + "n": n, + }); + if let Some(presence_penalty) = presence_penalty { + body["presence_penalty"] = json!(presence_penalty); + } + if let Some(frequency_penalty) = frequency_penalty { + body["frequency_penalty"] = json!(frequency_penalty); + } + if user.is_some() { + body["user"] = json!(user); + } + if let Some(model_id) = model_id { + body["model"] = json!(model_id); + } + if let Some(mt) = max_tokens { + body["max_tokens"] = mt.into(); + } + if !stop.is_empty() { + body["stop"] = json!(stop); + } + + if let Some(response_format) = response_format { + body["response_format"] = json!({ + "type": response_format, + }); + } + if tools.len() > 0 { + body["tools"] = json!(tools); + } + if let Some(tool_choice) = tool_choice { + body["tool_choice"] = json!(tool_choice); + } + if let Some(reasoning_effort) = reasoning_effort { + body["reasoning_effort"] = json!(reasoning_effort); + } + if let Some(logprobs) = logprobs { + body["logprobs"] = json!(logprobs); + } + if let Some(top_logprobs) = top_logprobs { + body["top_logprobs"] = json!(top_logprobs); + } + + let mut req = reqwest::Client::new() + .post(uri.to_string()) + .header("Content-Type", "application/json") + // This one is for `openai`. + .header("Authorization", format!("Bearer {}", api_key.clone())) + // This one is for `azure_openai`. 
+ .header("api-key", api_key.clone()); + + if let Some(organization_id) = organization_id { + req = req.header( + "OpenAI-Organization", + &format!("{}", organization_id.clone()), + ); + } + + let req = req.json(&body); + + let res = match timeout(Duration::new(180, 0), req.send()).await { + Ok(Ok(res)) => res, + Ok(Err(e)) => Err(e)?, + Err(_) => Err(anyhow!(format!( + "Timeout sending request to {} after 180s", + provider_name + )))?, + }; + + let res_headers = res.headers(); + let request_id = match res_headers.get("x-request-id") { + Some(request_id) => Some(request_id.to_str()?.to_string()), + None => None, + }; + + let body = match timeout(Duration::new(180, 0), res.bytes()).await { + Ok(Ok(body)) => body, + Ok(Err(e)) => Err(e)?, + Err(_) => Err(anyhow!(format!( + "Timeout reading response from {} after 180s", + provider_name + )))?, + }; + + let mut b: Vec = vec![]; + body.reader().read_to_end(&mut b)?; + let c: &[u8] = &b; + + let mut completion: OpenAIChatCompletion = match serde_json::from_slice(c) { + Ok(c) => Ok(c), + Err(_) => { + let error: OpenAIError = serde_json::from_slice(c)?; + match error.retryable() { + true => Err(ModelError { + request_id: request_id.clone(), + message: error.with_provider(&provider_name).message(), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(500), + factor: 2, + retries: 3, + }), + }), + false => Err(ModelError { + request_id: request_id.clone(), + message: error.with_provider(&provider_name).message(), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(500), + factor: 1, + retries: 1, + }), + }), + } + } + }?; + + // for all messages, edit the content and strip leading and trailing spaces and \n + for m in completion.choices.iter_mut() { + m.message.content = match m.message.content.as_ref() { + None => None, + Some(c) => Some(c.trim().to_string()), + }; + } + + Ok((completion, request_id)) +} + +fn logprobs_from_choices(choices: &Vec) -> Option> { + let lp: Vec = choices + .iter() + .filter_map(|choice| choice.logprobs.as_ref()) + .flat_map(|lp| { + lp.content.iter().map(|content_logprob| LLMChatLogprob { + token: content_logprob.token.clone(), + logprob: content_logprob.logprob, + top_logprobs: match content_logprob.top_logprobs.len() { + 0 => None, + _ => Some( + content_logprob + .top_logprobs + .iter() + .map(|top_logprob| top_logprob.clone().into()) + .collect(), + ), + }, + }) + }) + .collect(); + + match lp.len() { + 0 => None, + _ => Some(lp), + } +} diff --git a/core/src/providers/togetherai.rs b/core/src/providers/togetherai.rs index 7a3911da98f7..8f77ed9e0186 100644 --- a/core/src/providers/togetherai.rs +++ b/core/src/providers/togetherai.rs @@ -1,12 +1,7 @@ -use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage}; +use crate::providers::chat_messages::ChatMessage; use crate::providers::embedder::Embedder; use crate::providers::llm::ChatFunction; -use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; -use crate::providers::openai::{ - chat_completion, logprobs_from_choices, streamed_chat_completion, to_openai_messages, - OpenAIChatMessage, OpenAIChatMessageContent, OpenAIContentBlock, OpenAITextContent, - OpenAITextContentType, OpenAITool, OpenAIToolChoice, -}; +use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLM}; use crate::providers::provider::{Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, o200k_base_singleton, CoreBPE}; use crate::providers::tiktoken::tiktoken::{decode_async, 
encode_async}; @@ -18,10 +13,13 @@ use async_trait::async_trait; use hyper::Uri; use parking_lot::RwLock; use serde_json::Value; -use std::str::FromStr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; +use super::openai_compatible_helpers::{ + openai_compatible_chat_completion, TransformSystemMessages, +}; + pub struct TogetherAILLM { id: String, api_key: Option, @@ -115,7 +113,7 @@ impl LLM for TogetherAILLM { top_p: Option, n: usize, stop: &Vec, - mut max_tokens: Option, + max_tokens: Option, presence_penalty: Option, frequency_penalty: Option, logprobs: Option, @@ -123,138 +121,30 @@ impl LLM for TogetherAILLM { _extras: Option, event_sender: Option>, ) -> Result { - if let Some(m) = max_tokens { - if m == -1 { - max_tokens = None; - } - } - - let tool_choice = match function_call.as_ref() { - Some(fc) => Some(OpenAIToolChoice::from_str(fc)?), - None => None, - }; - - let tools = functions - .iter() - .map(OpenAITool::try_from) - .collect::, _>>()?; - - // TogetherAI doesn't work with the new chat message content format. - // We have to modify the messages contents to use the "String" format. - let openai_messages = to_openai_messages(messages, &self.id)? - .into_iter() - .filter_map(|m| match m.content { - None => Some(m), - Some(OpenAIChatMessageContent::String(_)) => Some(m), - Some(OpenAIChatMessageContent::Structured(contents)) => { - // Find the first text content, and use it to make a string content. - let content = contents.into_iter().find_map(|c| match c { - OpenAIContentBlock::TextContent(OpenAITextContent { - r#type: OpenAITextContentType::Text, - text, - .. - }) => Some(OpenAIChatMessageContent::String(text)), - _ => None, - }); - - Some(OpenAIChatMessage { - role: m.role, - name: m.name, - tool_call_id: m.tool_call_id, - tool_calls: m.tool_calls, - content, - }) - } - }) - .collect::>(); - - let is_streaming = event_sender.is_some(); - - let (c, request_id) = if is_streaming { - streamed_chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - None, - None, - logprobs, - top_logprobs, - None, - event_sender.clone(), - ) - .await? - } else { - chat_completion( - self.chat_uri()?, - self.api_key.clone().unwrap(), - None, - Some(self.id.clone()), - &openai_messages, - tools, - tool_choice, - temperature, - match top_p { - Some(t) => t, - None => 1.0, - }, - n, - stop, - max_tokens, - match presence_penalty { - Some(p) => p, - None => 0.0, - }, - match frequency_penalty { - Some(f) => f, - None => 0.0, - }, - None, - None, - logprobs, - top_logprobs, - None, - ) - .await? 
- }; - - assert!(c.choices.len() > 0); - - Ok(LLMChatGeneration { - created: utils::now(), - provider: ProviderID::OpenAI.to_string(), - model: self.id.clone(), - completions: c - .choices - .iter() - .map(|c| AssistantChatMessage::try_from(&c.message)) - .collect::>>()?, - usage: c.usage.map(|usage| LLMTokenUsage { - prompt_tokens: usage.prompt_tokens, - completion_tokens: usage.completion_tokens.unwrap_or(0), - }), - provider_request_id: request_id, - logprobs: logprobs_from_choices(&c.choices), - }) + openai_compatible_chat_completion( + self.chat_uri()?, + self.id.clone(), + self.api_key.clone().unwrap(), + messages, + functions, + function_call, + temperature, + top_p, + n, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + logprobs, + top_logprobs, + None, + event_sender, + false, // don't disable provider streaming + TransformSystemMessages::Keep, + "TogetherAI".to_string(), + true, // squash text contents (togetherai doesn't support structured messages) + ) + .await } } diff --git a/core/src/search_stores/search_store.rs b/core/src/search_stores/search_store.rs index fe716e373d53..e23fc6fe3fff 100644 --- a/core/src/search_stores/search_store.rs +++ b/core/src/search_stores/search_store.rs @@ -110,6 +110,13 @@ impl SearchStore for ElasticsearchSearchStore { )); } + // check there is at least one data source view filter + // !! do not remove; without data source view filter this endpoint is + // dangerous as any data from any workspace can be retrieved + if filter.data_source_views.is_empty() { + return Err(anyhow::anyhow!("No data source views provided")); + } + // Build filter conditions using elasticsearch-dsl let filter_conditions: Vec = filter .data_source_views diff --git a/extension/package-lock.json b/extension/package-lock.json index 7b2297218462..b257246b0d46 100644 --- a/extension/package-lock.json +++ b/extension/package-lock.json @@ -10,7 +10,7 @@ "license": "ISC", "dependencies": { "@dust-tt/client": "^1.0.24", - "@dust-tt/sparkle": "^0.2.365", + "@dust-tt/sparkle": "^0.2.367", "@tailwindcss/forms": "^0.5.9", "@tiptap/extension-character-count": "^2.9.1", "@tiptap/extension-mention": "^2.9.1", @@ -262,9 +262,9 @@ } }, "node_modules/@dust-tt/sparkle": { - "version": "0.2.365", - "resolved": "https://registry.npmjs.org/@dust-tt/sparkle/-/sparkle-0.2.365.tgz", - "integrity": "sha512-9IKPOaWrcj5XbOXHK9m7yi3uzzsSQKvNkDc+IsOOZbE4HvnmCavWry+W6iFPMZdxCsFUHi4kJsX/rK1/mbrtog==", + "version": "0.2.367", + "resolved": "https://registry.npmjs.org/@dust-tt/sparkle/-/sparkle-0.2.367.tgz", + "integrity": "sha512-nHXvJJPrrhHr0B0bKLoD5aKEheV/miE3XW9SUjhuqSXRUcFoqxtcLSksDYC9+0u9y+Y7ucb+ppD8/LZ0nzK4qQ==", "dependencies": { "@emoji-mart/data": "^1.1.2", "@emoji-mart/react": "^1.1.1", @@ -292,6 +292,7 @@ "react-dropdown-menu": "^0.0.2", "react-katex": "^3.0.1", "react-markdown": "^8.0.7", + "react-resizable-panels": "^2.1.7", "react-syntax-highlighter": "^15.6.1", "rehype-katex": "^7.0.1", "remark-directive": "^2.0.1", @@ -10141,6 +10142,15 @@ } } }, + "node_modules/react-resizable-panels": { + "version": "2.1.7", + "resolved": "https://registry.npmjs.org/react-resizable-panels/-/react-resizable-panels-2.1.7.tgz", + "integrity": "sha512-JtT6gI+nURzhMYQYsx8DKkx6bSoOGFp7A3CwMrOb8y5jFHFyqwo9m68UhmXRw57fRVJksFn1TSlm3ywEQ9vMgA==", + "peerDependencies": { + "react": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc", + "react-dom": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc" + } + }, "node_modules/react-router": { "version": "6.26.2", "resolved": 
"https://registry.npmjs.org/react-router/-/react-router-6.26.2.tgz", diff --git a/extension/package.json b/extension/package.json index e837d032f20c..c6346a11e4a7 100644 --- a/extension/package.json +++ b/extension/package.json @@ -50,7 +50,7 @@ }, "dependencies": { "@dust-tt/client": "^1.0.24", - "@dust-tt/sparkle": "^0.2.365", + "@dust-tt/sparkle": "^0.2.367", "@tailwindcss/forms": "^0.5.9", "@tiptap/extension-character-count": "^2.9.1", "@tiptap/extension-mention": "^2.9.1", diff --git a/front/admin/cli.ts b/front/admin/cli.ts index 0f290d8ddbc1..aa1e47c0d4ad 100644 --- a/front/admin/cli.ts +++ b/front/admin/cli.ts @@ -1,3 +1,12 @@ +import { DustAPI } from "@dust-tt/client"; +import { + assertNever, + ConnectorsAPI, + removeNulls, + SUPPORTED_MODEL_CONFIGS, +} from "@dust-tt/types"; +import parseArgs from "minimist"; + import { getConversation } from "@app/lib/api/assistant/conversation"; import { renderConversationForModel } from "@app/lib/api/assistant/generation"; import { getTextRepresentationFromMessages } from "@app/lib/api/assistant/utils"; @@ -11,7 +20,7 @@ import { internalSubscribeWorkspaceToFreeNoPlan, internalSubscribeWorkspaceToFreePlan, } from "@app/lib/plans/subscription"; -import { DustProdActionRegistry } from "@app/lib/registry"; +import { getDustProdActionRegistry } from "@app/lib/registry"; import { DataSourceResource } from "@app/lib/resources/data_source_resource"; import { GroupResource } from "@app/lib/resources/group_resource"; import { LabsTranscriptsConfigurationResource } from "@app/lib/resources/labs_transcripts_resource"; @@ -27,14 +36,6 @@ import { stopRetrieveTranscriptsWorkflow, } from "@app/temporal/labs/client"; import { REGISTERED_CHECKS } from "@app/temporal/production_checks/activities"; -import { DustAPI } from "@dust-tt/client"; -import { - assertNever, - ConnectorsAPI, - removeNulls, - SUPPORTED_MODEL_CONFIGS, -} from "@dust-tt/types"; -import parseArgs from "minimist"; // `cli` takes an object type and a command as first two arguments and then a list of arguments. 
const workspace = async (command: string, args: parseArgs.ParsedArgs) => { @@ -63,7 +64,6 @@ const workspace = async (command: string, args: parseArgs.ParsedArgs) => { }); args.wId = w.sId; - await workspace("show", args); return; } @@ -193,7 +193,7 @@ const workspace = async (command: string, args: parseArgs.ParsedArgs) => { default: console.log(`Unknown workspace command: ${command}`); console.log( - "Possible values: `find`, `show`, `create`, `set-limits`, `upgrade`, `downgrade`" + "Possible values: `find`, `create`, `set-limits`, `upgrade`, `downgrade`" ); } }; @@ -445,7 +445,7 @@ const transcripts = async (command: string, args: parseArgs.ParsedArgs) => { const registry = async (command: string) => { switch (command) { case "dump": { - console.log(JSON.stringify(DustProdActionRegistry)); + console.log(JSON.stringify(getDustProdActionRegistry())); return; } @@ -507,7 +507,7 @@ const productionCheck = async (command: string, args: parseArgs.ParsedArgs) => { args.url ); - const actions = Object.values(DustProdActionRegistry); + const actions = Object.values(getDustProdActionRegistry()); const res = await api.checkApps( { diff --git a/front/admin/copy_apps.sh b/front/admin/copy_apps.sh index 95d721b539bf..d72a9e59839e 100755 --- a/front/admin/copy_apps.sh +++ b/front/admin/copy_apps.sh @@ -51,7 +51,8 @@ function import { } DEVELOPMENT_DUST_APPS_WORKSPACE_ID='78bda07b39' -DUST_APPS_WORKSPACE_NUMERIC_ID=$(npx tsx ${DIR}/init_dust_apps.ts) +npx tsx ${DIR}/init_dust_apps.ts --name dust-apps +DUST_APPS_WORKSPACE_NUMERIC_ID=5069 mkdir -p /tmp/dust-apps @@ -75,7 +76,7 @@ then # Get projects matching the current specifications PROJECTS=$(psql $CORE_DATABASE_URI -c "copy (select distinct(project) from specifications where hash in (${IN_CLAUSE})) to stdout" | sed "s/.*/'&'/" | paste -sd, -) # Get appIds matching the specifications - LOCAL_APP_IDS=$(psql $FRONT_DATABASE_URI -c "copy (select distinct(\"sId\") from apps where \"dustAPIProjectId\" in (${PROJECTS}) and visibility!='deleted' and \"workspaceId\"=${DUST_APPS_WORKSPACE_NUMERIC_ID} order by \"sId\") to stdout" | paste -sd\ -) + LOCAL_APP_IDS=$(psql $FRONT_DATABASE_URI -c "copy (select distinct(\"sId\") from apps where \"dustAPIProjectId\" in (${PROJECTS}) and \"deletedAt\" is null and \"workspaceId\"=${DUST_APPS_WORKSPACE_NUMERIC_ID} order by \"sId\") to stdout" | paste -sd\ -) # Check if any app is missing MISSING=false @@ -108,7 +109,7 @@ PRODBOX_POD_NAME=$(kubectl get pods |grep prodbox|grep Running |cut -d \ -f1) # ---- front VAULT_ID=$(psql ${FRONT_DATABASE_URI} -c "COPY (SELECT id from vaults where \"workspaceId\"=${DUST_APPS_WORKSPACE_NUMERIC_ID} and name='Public Dust Apps') TO STDOUT") -fetch FRONT apps "id createdAt updatedAt sId name description visibility savedSpecification savedConfig savedRun dustAPIProjectId ${DUST_APPS_WORKSPACE_NUMERIC_ID} ${VAULT_ID}" "\\\"workspaceId\\\"=5069" +fetch FRONT apps "id createdAt updatedAt sId name description visibility savedSpecification savedConfig savedRun dustAPIProjectId ${DUST_APPS_WORKSPACE_NUMERIC_ID} ${VAULT_ID}" "\\\"workspaceId\\\"=5069 AND \\\"vaultId\\\"=93077 and \\\"deletedAt\\\" is null" PROJECT_IDS=$(cut -f 11 /tmp/dust-apps/FRONT_apps.csv |paste -sd "," -) fetch FRONT datasets "id createdAt updatedAt name description schema appId ${DUST_APPS_WORKSPACE_NUMERIC_ID}" "\\\"workspaceId\\\"=5069" @@ -116,7 +117,7 @@ fetch FRONT datasets "id createdAt updatedAt name description schema appId ${DUS import FRONT apps "id createdAt updatedAt sId name description visibility 
savedSpecification savedConfig savedRun dustAPIProjectId workspaceId vaultId" "updatedAt name description visibility savedSpecification savedConfig savedRun dustAPIProjectId" # ---- datasets -import FRONT datasets "id createdAt updatedAt name description schema appId workspaceId" "updatedAt name description schema" +import FRONT datasets "id createdAt updatedAt name description schema appId workspaceId" "updatedAt name description schema" "" "and __copy.\"appId\" in (select \"id\" from apps)" # ---- core diff --git a/front/admin/init_dust_apps.ts b/front/admin/init_dust_apps.ts index 6e32d2a0e989..48b0e9fa9628 100644 --- a/front/admin/init_dust_apps.ts +++ b/front/admin/init_dust_apps.ts @@ -1,3 +1,7 @@ +import { concurrentExecutor, isDevelopment } from "@dust-tt/types"; +import _ from "lodash"; +import parseArgs from "minimist"; + import { Authenticator } from "@app/lib/auth"; import { Workspace } from "@app/lib/models/workspace"; import { internalSubscribeWorkspaceToFreePlan } from "@app/lib/plans/subscription"; @@ -5,22 +9,34 @@ import { GroupResource } from "@app/lib/resources/group_resource"; import { MembershipResource } from "@app/lib/resources/membership_resource"; import { SpaceResource } from "@app/lib/resources/space_resource"; import { UserModel } from "@app/lib/resources/storage/models/user"; +import { generateRandomModelSId } from "@app/lib/resources/string_ids"; import { UserResource } from "@app/lib/resources/user_resource"; import { renderLightWorkspaceType } from "@app/lib/workspace"; +import logger from "@app/logger/logger"; + +const DEFAULT_WORKSPACE_NAME = "dust-apps"; +const DEFAULT_SPACE_NAME = "Public Dust Apps"; async function main() { - let w = await Workspace.findOne({ where: { sId: "78bda07b39" } }); - if (w) { - console.log(w.id); - process.exit(0); - } + const argv = parseArgs(process.argv.slice(2)); - w = await Workspace.create({ - id: 5069, - sId: "78bda07b39", - name: "dust-apps", - }); + const where = _.pick(argv, ["name", "sId"]); + if (!where.name && !where.sId) { + throw new Error("Please provide name and/or sId for the workspace"); + } + let w = await Workspace.findOne({ where }); + if (!w) { + console.log("Creating workspace"); + w = await Workspace.create({ + sId: argv.sId || generateRandomModelSId(), + name: argv.name || DEFAULT_WORKSPACE_NAME, + }); + await internalSubscribeWorkspaceToFreePlan({ + workspaceId: w.sId, + planCode: "FREE_UPGRADED_PLAN", + }); + } const lightWorkspace = renderLightWorkspaceType({ workspace: w }); const { systemGroup, globalGroup } = @@ -34,34 +50,39 @@ async function main() { globalGroup, }); - const group = await GroupResource.makeNew({ - name: `Group for space Public Dust Apps`, - workspaceId: w.id, - kind: "regular", - }); - - await SpaceResource.makeNew( - { id: 93077, name: "Public Dust Apps", kind: "public", workspaceId: w.id }, - [group] - ); + const spaces = await SpaceResource.listWorkspaceSpaces(auth); + let space = spaces.find((s) => s.isPublic()); + if (!space) { + console.log("Creating group"); + const group = await GroupResource.makeNew({ + name: `Group for space ${DEFAULT_SPACE_NAME}`, + workspaceId: w.id, + kind: "regular", + }); - await internalSubscribeWorkspaceToFreePlan({ - workspaceId: w.sId, - planCode: "FREE_UPGRADED_PLAN", - }); + if (isDevelopment()) { + const users = await UserModel.findAll(); + await concurrentExecutor( + users, + async (user) => + MembershipResource.createMembership({ + user: new UserResource(UserModel, user.get()), + workspace: lightWorkspace, + role: "admin", + }), + { 
concurrency: 5 } + ); + } - const users = await UserModel.findAll(); - await Promise.all( - users.map(async (user) => - MembershipResource.createMembership({ - user: new UserResource(UserModel, user.get()), - workspace: lightWorkspace, - role: "admin", - }) - ) - ); + console.log("Creating space"); + space = await SpaceResource.makeNew( + { name: DEFAULT_SPACE_NAME, kind: "public", workspaceId: w.id }, + [group] + ); + } - console.log(w.id); + console.log(`export DUST_APPS_WORKSPACE_ID=${w.sId}`); + console.log(`export DUST_APPS_SPACE_ID=${space.sId}`); } main() diff --git a/front/components/UserMenu.tsx b/front/components/UserMenu.tsx index ffa00a5b7632..e263880bd45c 100644 --- a/front/components/UserMenu.tsx +++ b/front/components/UserMenu.tsx @@ -41,7 +41,7 @@ export function UserMenu({ const forceRoleUpdate = useMemo( () => async (role: "user" | "builder" | "admin") => { - const result = await forceUserRole(user, owner, role); + const result = await forceUserRole(user, owner, role, featureFlags); if (result.isOk()) { sendNotification({ title: "Success !", @@ -59,7 +59,7 @@ export function UserMenu({ }); } }, - [owner, sendNotification, user] + [owner, sendNotification, user, featureFlags] ); return ( @@ -103,7 +103,7 @@ export function UserMenu({ )} - {showDebugTools(owner) && ( + {showDebugTools(featureFlags) && ( <> {router.route === "/w/[wId]/assistant/[cId]" && ( diff --git a/front/components/assistant/conversation/FeedbackSelector.tsx b/front/components/assistant/conversation/FeedbackSelector.tsx index cad4db181653..2edcb0de9685 100644 --- a/front/components/assistant/conversation/FeedbackSelector.tsx +++ b/front/components/assistant/conversation/FeedbackSelector.tsx @@ -1,11 +1,15 @@ -import { Button } from "@dust-tt/sparkle"; -import { Checkbox } from "@dust-tt/sparkle"; -import { Page } from "@dust-tt/sparkle"; -import { PopoverContent, PopoverRoot, PopoverTrigger } from "@dust-tt/sparkle"; -import { Spinner } from "@dust-tt/sparkle"; -import { TextArea } from "@dust-tt/sparkle"; -import { Tooltip } from "@dust-tt/sparkle"; -import { HandThumbDownIcon, HandThumbUpIcon } from "@dust-tt/sparkle"; +import { + Button, + Checkbox, + HandThumbDownIcon, + HandThumbUpIcon, + Page, + PopoverContent, + PopoverRoot, + PopoverTrigger, + Spinner, + TextArea, +} from "@dust-tt/sparkle"; import React, { useCallback, useEffect, useRef } from "react"; export type ThumbReaction = "up" | "down"; @@ -130,38 +134,30 @@ export function FeedbackSelector({
- Not saving the reaction until then. - className={ - feedback?.thumb === "up" ? "" : "text-muted-foreground" - } - /> +
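A note on the UserMenu hunk above: `showDebugTools` now receives the workspace's feature flags instead of the owner, and the `types/src/shared/feature_flags.ts` hunk at the end of this diff adds a "show_debug_tools" flag. A minimal sketch of how such a helper could gate on the new flag is below; the function body, its location, and the import path are assumptions, only the call sites and the flag name come from the diff.

// Hypothetical sketch (not part of this patch): gate debug tools on the new feature flag.
import type { WhitelistableFeature } from "@dust-tt/types";

export function showDebugTools(featureFlags: WhitelistableFeature[]): boolean {
  return featureFlags.includes("show_debug_tools");
}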
diff --git a/front/components/poke/plugins/RunPluginDialog.tsx b/front/components/poke/plugins/RunPluginDialog.tsx index c308bc09d8ee..6896e5d8f91a 100644 --- a/front/components/poke/plugins/RunPluginDialog.tsx +++ b/front/components/poke/plugins/RunPluginDialog.tsx @@ -1,4 +1,12 @@ -import { Spinner } from "@dust-tt/sparkle"; +import { + NewDialog, + NewDialogContainer, + NewDialogContent, + NewDialogDescription, + NewDialogHeader, + NewDialogTitle, + Spinner, +} from "@dust-tt/sparkle"; import type { PluginWorkspaceResource } from "@dust-tt/types"; import { AlertCircle } from "lucide-react"; import { useCallback, useState } from "react"; @@ -9,10 +17,6 @@ import { PokeAlertDescription, PokeAlertTitle, } from "@app/components/poke/shadcn/ui/alert"; -import { - PokeDialog, - PokeDialogContent, -} from "@app/components/poke/shadcn/ui/dialog"; import type { PluginListItem, PluginResponse } from "@app/lib/api/poke/types"; import { usePokePluginManifest, useRunPokePlugin } from "@app/poke/swr/plugins"; @@ -63,64 +67,66 @@ export function RunPluginDialog({ ); return ( - - -
-

Run {plugin.name} plugin

-

- {plugin.description} -

-
- {isLoading ? ( - - ) : !manifest ? ( - - - Error - - Plugin could not be loaded. - - - ) : ( - <> - {error && ( - - Error - {error} - - )} - {result && result.display === "text" && ( - - Success - - {result.value} - Make sure to reload. - - - )} - {result && result.display === "json" && ( -
-
Result:
-
-
-                    {JSON.stringify(result.value, null, 2)}
-                  
+ + + + Run {plugin.name} plugin + {plugin.description} + + + {isLoading ? ( + + ) : !manifest ? ( + + + Error + + Plugin could not be loaded. + + + ) : ( + <> + {error && ( + + Error + {error} + + )} + {result && result.display === "text" && ( + + Success + + {result.value} - Make sure to reload. + + + )} + {result && result.display === "json" && ( +
+
Result:
+
+
+                      {JSON.stringify(result.value, null, 2)}
+                    
+
-
- )} - - {manifest.warning && ( - - Warning - {manifest.warning} - - )} - - )} - - + )} + + {manifest.warning && ( + + Warning + + {manifest.warning} + + + )} + + )} + + + ); } diff --git a/front/components/poke/shadcn/ui/command.tsx b/front/components/poke/shadcn/ui/command.tsx index 5f131d058fe4..423d95773334 100644 --- a/front/components/poke/shadcn/ui/command.tsx +++ b/front/components/poke/shadcn/ui/command.tsx @@ -1,16 +1,16 @@ "use client"; -import { MagnifyingGlassIcon } from "@dust-tt/sparkle"; +import { + MagnifyingGlassIcon, + NewDialog, + NewDialogContent, +} from "@dust-tt/sparkle"; import type { DialogProps } from "@radix-ui/react-dialog"; import { Command as CommandPrimitive } from "cmdk"; import Link from "next/link"; import * as React from "react"; import { cn } from "@app/components/poke/shadcn/lib/utils"; -import { - PokeDialog, - PokeDialogContent, -} from "@app/components/poke/shadcn/ui/dialog"; const CommandContext = React.createContext<{ selectedIndex: number; @@ -117,16 +117,16 @@ const CommandDialog = ({ return ( - - + + {children} - - + + ); }; diff --git a/front/components/poke/shadcn/ui/dialog.tsx b/front/components/poke/shadcn/ui/dialog.tsx deleted file mode 100644 index 2bf3974be804..000000000000 --- a/front/components/poke/shadcn/ui/dialog.tsx +++ /dev/null @@ -1,120 +0,0 @@ -import * as DialogPrimitive from "@radix-ui/react-dialog"; -import { XIcon } from "lucide-react"; -import * as React from "react"; - -import { cn } from "@app/components/poke/shadcn/lib/utils"; - -const Dialog = DialogPrimitive.Root; - -const DialogTrigger = DialogPrimitive.Trigger; - -const DialogPortal = DialogPrimitive.Portal; - -const DialogClose = DialogPrimitive.Close; - -const DialogOverlay = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - -)); -DialogOverlay.displayName = DialogPrimitive.Overlay.displayName; - -const DialogContent = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, children, ...props }, ref) => ( - - - - {children} - - - Close - - - -)); -DialogContent.displayName = DialogPrimitive.Content.displayName; - -const DialogHeader = ({ - className, - ...props -}: React.HTMLAttributes) => ( -
-); -DialogHeader.displayName = "DialogHeader"; - -const DialogFooter = ({ - className, - ...props -}: React.HTMLAttributes) => ( -
-); -DialogFooter.displayName = "DialogFooter"; - -const DialogTitle = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - -)); -DialogTitle.displayName = DialogPrimitive.Title.displayName; - -const DialogDescription = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - -)); -DialogDescription.displayName = DialogPrimitive.Description.displayName; - -export { - Dialog as PokeDialog, - DialogClose as PokeDialogClose, - DialogContent as PokeDialogContent, - DialogDescription as PokeDialogDescription, - DialogFooter as PokeDialogFooter, - DialogHeader as PokeDialogHeader, - DialogOverlay as PokeDialogOverlay, - DialogPortal as PokeDialogPortal, - DialogTitle as PokeDialogTitle, - DialogTrigger as PokeDialogTrigger, -}; diff --git a/front/components/poke/subscriptions/EnterpriseUpgradeDialog.tsx b/front/components/poke/subscriptions/EnterpriseUpgradeDialog.tsx index a861171b3ca3..acde59fb5d3a 100644 --- a/front/components/poke/subscriptions/EnterpriseUpgradeDialog.tsx +++ b/front/components/poke/subscriptions/EnterpriseUpgradeDialog.tsx @@ -1,4 +1,13 @@ -import { Spinner } from "@dust-tt/sparkle"; +import { + NewDialog, + NewDialogContent, + NewDialogDescription, + NewDialogFooter, + NewDialogHeader, + NewDialogTitle, + NewDialogTrigger, + Spinner, +} from "@dust-tt/sparkle"; import type { EnterpriseUpgradeFormType, WorkspaceType } from "@dust-tt/types"; import { EnterpriseUpgradeFormSchema, removeNulls } from "@dust-tt/types"; import { ioTsResolver } from "@hookform/resolvers/io-ts"; @@ -7,15 +16,6 @@ import { useState } from "react"; import { useForm } from "react-hook-form"; import { PokeButton } from "@app/components/poke/shadcn/ui/button"; -import { - PokeDialog, - PokeDialogContent, - PokeDialogDescription, - PokeDialogFooter, - PokeDialogHeader, - PokeDialogTitle, - PokeDialogTrigger, -} from "@app/components/poke/shadcn/ui/dialog"; import { PokeForm } from "@app/components/poke/shadcn/ui/form"; import { InputField, @@ -95,18 +95,18 @@ export default function EnterpriseUpgradeDialog({ }; return ( - - + + 🏢 Upgrade to Enterprise - - - - Upgrade {owner.name} to Enterprise. - + + + + Upgrade {owner.name} to Enterprise. + Select the enterprise plan and provide the Stripe subscription id of the customer. - - + + {error &&
{error}
} {isSubmitting && } {!isSubmitting && ( @@ -135,18 +135,18 @@ export default function EnterpriseUpgradeDialog({ />
- + Upgrade - + )} - - + + ); } diff --git a/front/components/spaces/AddToSpaceDialog.tsx b/front/components/spaces/AddToSpaceDialog.tsx index 6e721ba6f9af..9a584e2f9dc7 100644 --- a/front/components/spaces/AddToSpaceDialog.tsx +++ b/front/components/spaces/AddToSpaceDialog.tsx @@ -1,14 +1,19 @@ import { Button, - Dialog, DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger, + NewDialog, + NewDialogContainer, + NewDialogContent, + NewDialogFooter, + NewDialogHeader, + NewDialogTitle, ScrollArea, ScrollBar, + useSendNotification, } from "@dust-tt/sparkle"; -import { useSendNotification } from "@dust-tt/sparkle"; import type { APIError, DataSourceViewContentNode, @@ -16,7 +21,7 @@ import type { LightWorkspaceType, SpaceType, } from "@dust-tt/types"; -import { useEffect, useState } from "react"; +import React, { useEffect, useState } from "react"; import { useDataSourceViews } from "@app/lib/swr/data_source_views"; import { useSpaces } from "@app/lib/swr/spaces"; @@ -133,45 +138,64 @@ export const AddToSpaceDialog = ({ }; return ( - onClose(false)} - onValidate={addToSpace} - title="Add to Space" - validateLabel="Save" + { + if (!open) { + onClose(false); + } + }} > - {availableSpaces.length === 0 ? ( -
- This data is already available in all spaces. -
- ) : ( - - -
+ + + + {availableSpaces.map((currentSpace) => ( + setSpace(currentSpace)} + /> + ))} + + + + + )} + + + + ); }; diff --git a/front/components/spaces/EditSpaceManagedDatasourcesViews.tsx b/front/components/spaces/EditSpaceManagedDatasourcesViews.tsx index c79daa6f069c..56ba4ae219a7 100644 --- a/front/components/spaces/EditSpaceManagedDatasourcesViews.tsx +++ b/front/components/spaces/EditSpaceManagedDatasourcesViews.tsx @@ -1,12 +1,17 @@ import { Button, ContentMessage, - Dialog, InformationCircleIcon, + NewDialog, + NewDialogContainer, + NewDialogContent, + NewDialogFooter, + NewDialogHeader, + NewDialogTitle, PlusIcon, Tooltip, + useSendNotification, } from "@dust-tt/sparkle"; -import { useSendNotification } from "@dust-tt/sparkle"; import type { APIError, DataSourceViewSelectionConfigurations, @@ -120,32 +125,34 @@ export function EditSpaceManagedDataSourcesViews({ alertDialog: true, children: (
-

The following data sources are currently in use:

- - {deletedViewsWithUsage.map((view) => ( -

- {getDisplayNameForDataSource(view.dataSource)}{" "} - - (used by {view.usage.count} assistant - {view.usage.count > 1 ? "s" : ""}) - -

- ))} - -

- Deleting these data sources will affect the assistants using - them. These assistants will no longer have access to this data - and may not work as expected. -

+ Deleting these data sources will affect the assistants using them. + These assistants will no longer have access to this data and may + not work as expected.
-

Are you sure you want to remove them?

+
+ The following data sources are currently in use: +
    + {deletedViewsWithUsage.map((view) => ( +
  • + {getDisplayNameForDataSource(view.dataSource)}{" "} + + (used by {view.usage.count} assistant + {view.usage.count > 1 ? "s" : ""}) + +
  • + ))} +
+
+
+ Are you sure you want to remove them? +
), }); @@ -271,6 +278,16 @@ export function EditSpaceManagedDataSourcesViews({ return false; } + function handleCloseDataSourcesModal() { + setShowDataSourcesModal(false); + } + + function handleGoToConnectionsManagement() { + void router.push( + `/w/${owner?.sId}/spaces/${systemSpace?.sId}/categories/managed` + ); + } + const addToSpaceButton = (
-
+ + { + await handleGenerate({ + name: newApiKeyName, + group: newApiKeyGroup, + }); + }, + }} + /> + +
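The dialog hunks in this patch (RunPluginDialog, EnterpriseUpgradeDialog, AddToSpaceDialog, the API-key dialog right above, and the dev-secrets page below) all replace the old Dialog / PokeDialog components with Sparkle's NewDialog family, but the JSX is hard to follow in the extracted hunks. Here is a minimal sketch of the target composition, assuming only the component and prop names visible in the hunks (open, onOpenChange, and NewDialogFooter's leftButtonProps / rightButtonProps); it is an illustration, not code from the patch.

import {
  NewDialog,
  NewDialogContainer,
  NewDialogContent,
  NewDialogDescription,
  NewDialogFooter,
  NewDialogHeader,
  NewDialogTitle,
} from "@dust-tt/sparkle";

// Illustrative dialog: controlled via `open`, closed through onOpenChange(false).
function ExampleDialog({ open, onClose }: { open: boolean; onClose: () => void }) {
  return (
    <NewDialog
      open={open}
      onOpenChange={(isOpen) => {
        if (!isOpen) {
          onClose();
        }
      }}
    >
      <NewDialogContent>
        <NewDialogHeader>
          <NewDialogTitle>Example title</NewDialogTitle>
          <NewDialogDescription>One-line description of the action.</NewDialogDescription>
        </NewDialogHeader>
        <NewDialogContainer>{/* dialog body goes here */}</NewDialogContainer>
        <NewDialogFooter
          leftButtonProps={{ label: "Cancel", variant: "outline", onClick: onClose }}
          rightButtonProps={{ label: "Save", variant: "primary", onClick: onClose }}
        />
      </NewDialogContent>
    </NewDialog>
  );
}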
    diff --git a/front/pages/w/[wId]/developers/dev-secrets.tsx b/front/pages/w/[wId]/developers/dev-secrets.tsx index 6f1b511b1d08..135cdffd91bb 100644 --- a/front/pages/w/[wId]/developers/dev-secrets.tsx +++ b/front/pages/w/[wId]/developers/dev-secrets.tsx @@ -2,12 +2,17 @@ import { BookOpenIcon, BracesIcon, Button, - Dialog, Input, + NewDialog, + NewDialogContainer, + NewDialogContent, + NewDialogFooter, + NewDialogHeader, + NewDialogTitle, Page, PlusIcon, + useSendNotification, } from "@dust-tt/sparkle"; -import { useSendNotification } from "@dust-tt/sparkle"; import type { DustAppSecretType, SubscriptionType, @@ -122,53 +127,94 @@ export default function SecretsPage({ return ( <> {secretToRevoke ? ( - handleRevoke(secretToRevoke)} - onCancel={() => setSecretToRevoke(null)} + { + if (!open) { + setSecretToRevoke(null); + } + }} > -

    - Are you sure you want to delete the secret{" "} - {secretToRevoke?.name}? -

    -
    + + + Delete {secretToRevoke?.name} + + + Are you sure you want to delete the secret{" "} + {secretToRevoke?.name}? + + setSecretToRevoke(null), + }} + rightButtonProps={{ + label: "Delete", + variant: "warning", + onClick: () => handleRevoke(secretToRevoke), + }} + /> + + ) : null} - handleGenerate(newDustAppSecret)} - onCancel={() => setIsNewSecretPromptOpen(false)} - > - - setNewDustAppSecret({ - ...newDustAppSecret, - name: cleanSecretName(e.target.value), - }) + { + if (!open) { + setIsNewSecretPromptOpen(false); } - /> -

    - Secret names must be alphanumeric and underscore characters only. -

    -
    + }} + > + + + + {isInputNameDisabled ? "Update" : "New"} Developer Secret + + + + + setNewDustAppSecret({ + ...newDustAppSecret, + name: cleanSecretName(e.target.value), + }) + } + /> + + setNewDustAppSecret({ + ...newDustAppSecret, + value: e.target.value, + }) + } + /> +

    +
    + setIsNewSecretPromptOpen(false), + }} + rightButtonProps={{ + label: isInputNameDisabled ? "Update" : "Create", + variant: "primary", + onClick: () => handleGenerate(newDustAppSecret), + }} + /> +
    +
    - - setNewDustAppSecret({ ...newDustAppSecret, value: e.target.value }) - } - /> -

    - Secret values are encrypted and stored securely in our database. -

    -
    (async (context, auth) => { const owner = auth.getNonNullableWorkspace(); @@ -118,6 +121,14 @@ export const getServerSideProps = withDefaultUserAuthRequirements< } } + const isDustAppsSpace = + owner.sId === config.getDustAppsWorkspaceId() && + space.sId === config.getDustAppsSpaceId(); + + const registryApps = isDustAppsSpace + ? Object.values(getDustProdActionRegistry()).map((action) => action.app) + : undefined; + return { props: { category: context.query.category as DataSourceViewCategory, @@ -130,6 +141,7 @@ export const getServerSideProps = withDefaultUserAuthRequirements< space: space.toJSON(), systemSpace: systemSpace.toJSON(), integrations, + registryApps, }, }; }); @@ -143,6 +155,7 @@ export default function Space({ space, systemSpace, integrations, + registryApps, }: InferGetServerSidePropsType) { const router = useRouter(); return ( @@ -179,6 +192,7 @@ export default function Space({ onSelect={(sId) => { void router.push(`/w/${owner.sId}/spaces/${space.sId}/apps/${sId}`); }} + registryApps={registryApps} /> ) : ( (); export type WhitelistableFeature = z.infer; @@ -2145,6 +2146,8 @@ export type UpsertFolderResponseType = z.infer< typeof UpsertFolderResponseSchema >; +const ProviderVisibilitySchema = FlexibleEnumSchema<"public" | "private">(); + export const UpsertDataSourceFolderRequestSchema = z.object({ timestamp: z.number(), parents: z.array(z.string()).nullable().optional(), @@ -2152,6 +2155,7 @@ export const UpsertDataSourceFolderRequestSchema = z.object({ title: z.string(), mime_type: z.string(), source_url: z.string().nullable().optional(), + provider_visibility: ProviderVisibilitySchema.nullable().optional(), }); export type UpsertDataSourceFolderRequestType = z.infer< typeof UpsertDataSourceFolderRequestSchema diff --git a/sparkle/package-lock.json b/sparkle/package-lock.json index 965714bfbebe..e38cdae6da55 100644 --- a/sparkle/package-lock.json +++ b/sparkle/package-lock.json @@ -1,12 +1,12 @@ { "name": "@dust-tt/sparkle", - "version": "0.2.366", + "version": "0.2.367", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@dust-tt/sparkle", - "version": "0.2.366", + "version": "0.2.367", "license": "ISC", "dependencies": { "@emoji-mart/data": "^1.1.2", diff --git a/sparkle/package.json b/sparkle/package.json index bb063266149a..4c032b3ee356 100644 --- a/sparkle/package.json +++ b/sparkle/package.json @@ -1,6 +1,6 @@ { "name": "@dust-tt/sparkle", - "version": "0.2.366", + "version": "0.2.367", "scripts": { "build": "rm -rf dist && npm run tailwind && npm run build:esm && npm run build:cjs", "tailwind": "tailwindcss -i ./src/styles/tailwind.css -o dist/sparkle.css", diff --git a/sparkle/src/components/Citation.tsx b/sparkle/src/components/Citation.tsx index a84365c05019..2ff8d5488a36 100644 --- a/sparkle/src/components/Citation.tsx +++ b/sparkle/src/components/Citation.tsx @@ -176,11 +176,7 @@ const CitationIcons = React.forwardRef< return (
    {children} diff --git a/types/src/front/lib/connectors_api.ts b/types/src/front/lib/connectors_api.ts index 1d594e6138d4..35af0abc3a4a 100644 --- a/types/src/front/lib/connectors_api.ts +++ b/types/src/front/lib/connectors_api.ts @@ -14,7 +14,7 @@ import { Err, Ok, Result } from "../../shared/result"; export type ConnectorsAPIResponse = Result; export type ConnectorSyncStatus = "succeeded" | "failed"; -const CONNECTORS_ERROR_TYPES = [ +export const CONNECTORS_ERROR_TYPES = [ "oauth_token_revoked", "third_party_internal_error", "webcrawling_error", @@ -56,6 +56,8 @@ export type ConnectorType = { */ export type ConnectorPermission = "read" | "write" | "read_write" | "none"; export type ContentNodeType = "file" | "folder" | "database" | "channel"; +// currently used for Slack, for which channels can be public or private +export type ProviderVisibility = "public" | "private"; /* * This constant defines the priority order for sorting content nodes by their type. @@ -69,9 +71,6 @@ export const contentNodeTypeSortOrder: Record = { channel: 4, }; -// currently used for slack, for which channels can be public or private -export type ProviderVisibility = "private" | "public"; - /** * A ContentNode represents a connector related node. As an example: * - Notion: Top-level pages (possibly manually added lower level ones) diff --git a/types/src/front/lib/core_api.ts b/types/src/front/lib/core_api.ts index b61a564b2b33..c40a118ed148 100644 --- a/types/src/front/lib/core_api.ts +++ b/types/src/front/lib/core_api.ts @@ -27,6 +27,7 @@ import { import { LightWorkspaceType } from "../../front/user"; import { LoggerInterface } from "../../shared/logger"; import { Err, Ok, Result } from "../../shared/result"; +import { ProviderVisibility } from "./connectors_api"; export const MAX_CHUNK_SIZE = 512; @@ -1591,6 +1592,7 @@ export class CoreAPI { title, mimeType, sourceUrl, + providerVisibility, }: { projectId: string; dataSourceId: string; @@ -1601,6 +1603,7 @@ export class CoreAPI { title: string; mimeType: string; sourceUrl?: string | null; + providerVisibility: ProviderVisibility | null | undefined; }): Promise> { const response = await this._fetchWithError( `${this._url}/projects/${projectId}/data_sources/${encodeURIComponent( @@ -1619,6 +1622,7 @@ export class CoreAPI { parents, mime_type: mimeType, source_url: sourceUrl, + provider_visibility: providerVisibility, }), } ); diff --git a/types/src/shared/feature_flags.ts b/types/src/shared/feature_flags.ts index e9c2d47200cf..b1bf14c27ec9 100644 --- a/types/src/shared/feature_flags.ts +++ b/types/src/shared/feature_flags.ts @@ -17,6 +17,7 @@ export const WHITELISTABLE_FEATURES = [ "conversations_jit_actions", "disable_run_logs", "labs_trackers", + "show_debug_tools", ] as const; export type WhitelistableFeature = (typeof WHITELISTABLE_FEATURES)[number]; export function isWhitelistableFeature(