diff --git a/front/admin/db.ts b/front/admin/db.ts index a6dcca374673..2f8ce3bafdfd 100644 --- a/front/admin/db.ts +++ b/front/admin/db.ts @@ -1,7 +1,5 @@ import { AgentConfiguration, - AgentDatabaseQueryAction, - AgentDatabaseQueryConfiguration, AgentDataSourceConfiguration, AgentDustAppRunAction, AgentDustAppRunConfiguration, @@ -39,6 +37,11 @@ import { XP1Run, XP1User, } from "@app/lib/models"; +import { + AgentTablesQueryAction, + AgentTablesQueryConfiguration, + AgentTablesQueryConfigurationTable, +} from "@app/lib/models/assistant/actions/tables_query"; import { AgentUserRelation, GlobalAgentSettings, @@ -71,8 +74,9 @@ async function main() { await AgentDustAppRunConfiguration.sync({ alter: true }); await AgentDustAppRunAction.sync({ alter: true }); - await AgentDatabaseQueryConfiguration.sync({ alter: true }); - await AgentDatabaseQueryAction.sync({ alter: true }); + await AgentTablesQueryConfiguration.sync({ alter: true }); + await AgentTablesQueryConfigurationTable.sync({ alter: true }); + await AgentTablesQueryAction.sync({ alter: true }); await AgentGenerationConfiguration.sync({ alter: true }); await AgentRetrievalConfiguration.sync({ alter: true }); diff --git a/front/components/assistant/AssistantDetails.tsx b/front/components/assistant/AssistantDetails.tsx index c7cddd553b0c..4acff146b5f6 100644 --- a/front/components/assistant/AssistantDetails.tsx +++ b/front/components/assistant/AssistantDetails.tsx @@ -6,23 +6,29 @@ import { CommandLineIcon, Modal, PlusIcon, + ServerIcon, TrashIcon, XMarkIcon, } from "@dust-tt/sparkle"; import type { + AgentConfigurationType, AgentUsageType, AgentUserListStatus, ConnectorProvider, + CoreAPITable, + DataSourceConfiguration, + DustAppRunConfigurationType, LightAgentConfigurationType, + TablesQueryConfigurationType, + WorkspaceType, +} from "@dust-tt/types"; +import { + isDustAppRunConfiguration, + isRetrievalConfiguration, + isTablesQueryConfiguration, } from "@dust-tt/types"; -import type { DustAppRunConfigurationType } from "@dust-tt/types"; -import type { DataSourceConfiguration } from "@dust-tt/types"; -import type { AgentConfigurationType } from "@dust-tt/types"; -import type { WorkspaceType } from "@dust-tt/types"; -import { isDustAppRunConfiguration } from "@dust-tt/types"; -import { isRetrievalConfiguration } from "@dust-tt/types"; import Link from "next/link"; -import { useContext, useState } from "react"; +import { useCallback, useContext, useEffect, useState } from "react"; import ReactMarkdown from "react-markdown"; import { DeleteAssistantDialog } from "@app/components/assistant/AssistantActions"; @@ -125,6 +131,11 @@ export function AssistantDetails({ + ) : isTablesQueryConfiguration(action) ? ( +
+
Tables
+ +
) : null ) : null; @@ -379,3 +390,82 @@ function ButtonsSection({ ); } + +function TablesQuerySection({ + tablesQueryConfig, +}: { + tablesQueryConfig: TablesQueryConfigurationType; +}) { + const [tables, setTables] = useState(null); + const [isLoading, setIsLoading] = useState(false); + const [isError, setIsError] = useState(false); + + const getTables = useCallback(async () => { + const tableEndpoints = tablesQueryConfig.tables.map( + (t) => + `/api/w/${t.workspaceId}/data_sources/${t.dataSourceId}/tables/${t.tableId}` + ); + + const results = await Promise.all( + tableEndpoints.map((endpoint) => + fetch(endpoint, { + method: "GET", + headers: { "Content-Type": "application/json" }, + }) + ) + ); + + const tablesParsed = []; + for (const res of results) { + if (!res.ok) { + throw new Error((await res.json()).error.message); + } + tablesParsed.push((await res.json()).table); + } + + setTables(tablesParsed); + }, [tablesQueryConfig.tables]); + + useEffect(() => { + if (!tablesQueryConfig.tables || isLoading || isError || tables?.length) { + return; + } + setIsLoading(true); + getTables() + .catch(() => setIsError(true)) + .finally(() => setIsLoading(false)); + }, [getTables, isLoading, tablesQueryConfig.tables, tables?.length, isError]); + + if (isLoading) { + return ( +
+
Loading...
; +
+ ); + } + + if (!tables) { + return ( +
+
Error loading tables.
; +
+ ); + } + + return ( +
+
The following tables are queried before answering:
+ {tables.map((t) => ( +
+
+ +
+
{t.name}
+
+ ))} +
+ ); +} diff --git a/front/components/assistant/conversation/AgentAction.tsx b/front/components/assistant/conversation/AgentAction.tsx index 83257ca31fac..dbf60c2c6212 100644 --- a/front/components/assistant/conversation/AgentAction.tsx +++ b/front/components/assistant/conversation/AgentAction.tsx @@ -1,13 +1,13 @@ import type { AgentActionType } from "@dust-tt/types"; import { - isDatabaseQueryActionType, isDustAppRunActionType, + isRetrievalActionType, + isTablesQueryActionType, } from "@dust-tt/types"; -import { isRetrievalActionType } from "@dust-tt/types"; -import DatabaseQueryAction from "@app/components/assistant/conversation/DatabaseQueryAction"; import DustAppRunAction from "@app/components/assistant/conversation/DustAppRunAction"; import RetrievalAction from "@app/components/assistant/conversation/RetrievalAction"; +import TablesQueryAction from "@app/components/assistant/conversation/TablesQueryAction"; export function AgentAction({ action }: { action: AgentActionType }) { if (isRetrievalActionType(action)) { @@ -24,10 +24,10 @@ export function AgentAction({ action }: { action: AgentActionType }) { ); } - if (isDatabaseQueryActionType(action)) { + if (isTablesQueryActionType(action)) { return (
- +
); } diff --git a/front/components/assistant/conversation/AgentMessage.tsx b/front/components/assistant/conversation/AgentMessage.tsx index e25451f53edf..6c048fed66db 100644 --- a/front/components/assistant/conversation/AgentMessage.tsx +++ b/front/components/assistant/conversation/AgentMessage.tsx @@ -128,8 +128,8 @@ export function AgentMessage({ case "retrieval_params": case "dust_app_run_params": case "dust_app_run_block": - case "database_query_params": - case "database_query_output": + case "tables_query_params": + case "tables_query_output": setStreamedAgentMessage((m) => { return { ...m, action: event.action }; }); diff --git a/front/components/assistant/conversation/DatabaseQueryAction.tsx b/front/components/assistant/conversation/TablesQueryAction.tsx similarity index 93% rename from front/components/assistant/conversation/DatabaseQueryAction.tsx rename to front/components/assistant/conversation/TablesQueryAction.tsx index 955359d14bd6..7722e2cd84e5 100644 --- a/front/components/assistant/conversation/DatabaseQueryAction.tsx +++ b/front/components/assistant/conversation/TablesQueryAction.tsx @@ -6,7 +6,7 @@ import { Spinner, Tooltip, } from "@dust-tt/sparkle"; -import type { DatabaseQueryActionType } from "@dust-tt/types"; +import type { TablesQueryActionType } from "@dust-tt/types"; import dynamic from "next/dynamic"; import { useState } from "react"; import { amber, emerald, slate } from "tailwindcss/colors"; @@ -16,20 +16,20 @@ const SyntaxHighlighter = dynamic( { ssr: false } ); -export default function DatabaseQueryAction({ - databaseQueryAction, +export default function TablesQueryAction({ + tablesQueryAction, }: { - databaseQueryAction: DatabaseQueryActionType; + tablesQueryAction: TablesQueryActionType; }) { const [isOutputExpanded, setIsOutputExpanded] = useState(false); // Extracting question from the params - const params = databaseQueryAction.params; + const params = tablesQueryAction.params; const question = typeof params?.question === "string" ? params.question : null; // Extracting query and result from the output - const output = databaseQueryAction.output; + const output = tablesQueryAction.output; const query = typeof output?.query === "string" ? output.query : null; const noQuery = output?.no_query === true; const results = output?.results; diff --git a/front/components/assistant_builder/AssistantBuilder.tsx b/front/components/assistant_builder/AssistantBuilder.tsx index 4ad536727bc8..ef2dfd6fd807 100644 --- a/front/components/assistant_builder/AssistantBuilder.tsx +++ b/front/components/assistant_builder/AssistantBuilder.tsx @@ -47,6 +47,7 @@ import { DeleteAssistantDialog } from "@app/components/assistant/AssistantAction import { AvatarPicker } from "@app/components/assistant_builder/AssistantBuilderAvatarPicker"; import AssistantBuilderDataSourceModal from "@app/components/assistant_builder/AssistantBuilderDataSourceModal"; import AssistantBuilderDustAppModal from "@app/components/assistant_builder/AssistantBuilderDustAppModal"; +import AssistantBuilderTablesModal from "@app/components/assistant_builder/AssistantBuilderTablesModal"; import DataSourceSelectionSection from "@app/components/assistant_builder/DataSourceSelectionSection"; import DustAppSelectionSection from "@app/components/assistant_builder/DustAppSelectionSection"; import { @@ -56,6 +57,7 @@ import { SPIRIT_AVATARS_BASE_PATH, TIME_FRAME_UNIT_TO_LABEL, } from "@app/components/assistant_builder/shared"; +import TablesSelectionSection from "@app/components/assistant_builder/TablesSelectionSection"; import { TeamSharingSection } from "@app/components/assistant_builder/TeamSharingSection"; import DataSourceResourceSelectorTree from "@app/components/DataSourceResourceSelectorTree"; import AppLayout from "@app/components/sparkle/AppLayout"; @@ -102,7 +104,7 @@ const BASIC_ACTION_MODES = ["GENERIC", "RETRIEVAL_SEARCH"] as const; const ADVANCED_ACTION_MODES = [ "RETRIEVAL_EXHAUSTIVE", "DUST_APP_RUN", - "DATABASE_QUERY", + "TABLES_QUERY", ] as const; type ActionMode = @@ -114,7 +116,7 @@ const ACTION_MODE_TO_LABEL: Record = { RETRIEVAL_SEARCH: "Search in data sources", RETRIEVAL_EXHAUSTIVE: "Use most recent in data sources", DUST_APP_RUN: "Run a Dust app", - DATABASE_QUERY: "Query a database", + TABLES_QUERY: "Query a set of tables", }; // Retrieval Action @@ -147,13 +149,12 @@ export type AssistantBuilderDustAppConfiguration = { app: AppType; }; -// Database Query Action +// Tables Query Action -export type AssistantBuilderDatabaseQueryConfiguration = { +export type AssistantBuilderTableConfiguration = { dataSourceId: string; - dataSourceWorkspaceId: string; - databaseId: string; - databaseName: string; + workspaceId: string; + tableId: string; }; // Builder State @@ -169,7 +170,7 @@ type AssistantBuilderState = { unit: TimeframeUnit; }; dustAppConfiguration: AssistantBuilderDustAppConfiguration | null; - databaseQueryConfiguration: AssistantBuilderDatabaseQueryConfiguration | null; + tablesQueryConfiguration: Record; handle: string | null; description: string | null; scope: Exclude; @@ -192,7 +193,7 @@ export type AssistantBuilderInitialState = { | null; timeFrame: AssistantBuilderState["timeFrame"] | null; dustAppConfiguration: AssistantBuilderState["dustAppConfiguration"]; - databaseQueryConfiguration: AssistantBuilderState["databaseQueryConfiguration"]; + tablesQueryConfiguration: AssistantBuilderState["tablesQueryConfiguration"]; handle: string; description: string; scope: Exclude; @@ -231,7 +232,7 @@ const DEFAULT_ASSISTANT_STATE: AssistantBuilderState = { unit: "month", }, dustAppConfiguration: null, - databaseQueryConfiguration: null, + tablesQueryConfiguration: {}, handle: null, scope: "private", description: null, @@ -292,8 +293,8 @@ export default function AssistantBuilder({ ...DEFAULT_ASSISTANT_STATE.timeFrame, }, dustAppConfiguration: initialBuilderState.dustAppConfiguration, - databaseQueryConfiguration: - initialBuilderState.databaseQueryConfiguration, + tablesQueryConfiguration: + initialBuilderState.tablesQueryConfiguration, handle: initialBuilderState.handle, description: initialBuilderState.description, scope: initialBuilderState.scope, @@ -321,6 +322,8 @@ export default function AssistantBuilder({ const [showDustAppsModal, setShowDustAppsModal] = useState(false); + const [showTableModal, setShowTableModal] = useState(false); + const [edited, setEdited] = useState(defaultIsEdited ?? false); const [isSavingOrDeleting, setIsSavingOrDeleting] = useState(false); const [showDeletionModal, setShowDeletionModal] = useState(false); @@ -481,8 +484,8 @@ export default function AssistantBuilder({ } } - if (builderState.actionMode === "DATABASE_QUERY") { - if (!builderState.databaseQueryConfiguration) { + if (builderState.actionMode === "TABLES_QUERY") { + if (!builderState.tablesQueryConfiguration) { valid = false; } } @@ -496,7 +499,7 @@ export default function AssistantBuilder({ configuredDataSourceCount, builderState.timeFrame.value, builderState.dustAppConfiguration, - builderState.databaseQueryConfiguration, + builderState.tablesQueryConfiguration, assistantHandleIsAvailable, assistantHandleIsValid, ]); @@ -585,14 +588,11 @@ export default function AssistantBuilder({ } break; - case "DATABASE_QUERY": - if (builderState.databaseQueryConfiguration) { - const config = builderState.databaseQueryConfiguration; + case "TABLES_QUERY": + if (builderState.tablesQueryConfiguration) { actionParam = { - type: "database_query_configuration", - dataSourceWorkspaceId: config.dataSourceWorkspaceId, - dataSourceId: config.dataSourceId, - databaseId: config.databaseId, + type: "tables_query_configuration", + tables: Object.values(builderState.tablesQueryConfiguration), }; } break; @@ -726,6 +726,23 @@ export default function AssistantBuilder({ })); }} /> + setShowTableModal(isOpen)} + owner={owner} + dataSources={configurableDataSources} + onSave={(t) => { + setEdited(true); + setBuilderState((state) => ({ + ...state, + tablesQueryConfiguration: { + ...state.tablesQueryConfiguration, + [`${t.workspaceId}/${t.dataSourceId}/${t.tableId}`]: t, + }, + })); + }} + tablesQueryConfiguration={builderState.tablesQueryConfiguration} + /> {ADVANCED_ACTION_MODES.filter((key) => { return ( - key !== "DATABASE_QUERY" || + key !== "TABLES_QUERY" || isActivatedStructuredDB(owner) ); }).map((key) => ( @@ -1234,6 +1251,35 @@ export default function AssistantBuilder({ canSelectDustApp={dustApps.length !== 0} /> + +
+ The assistant will generate a SQL query from your request, + execute it on the tables selected and use the results to + generate an answer. +
+ { + setShowTableModal(true); + }} + onDelete={(key) => { + setEdited(true); + setBuilderState((state) => { + const tablesQueryConfiguration = + state.tablesQueryConfiguration; + delete tablesQueryConfiguration[key]; + return { + ...state, + tablesQueryConfiguration, + }; + }); + }} + canSelectTable={dataSources.length !== 0} + /> +
diff --git a/front/components/assistant_builder/AssistantBuilderTablesModal.tsx b/front/components/assistant_builder/AssistantBuilderTablesModal.tsx new file mode 100644 index 000000000000..45e4d71d779e --- /dev/null +++ b/front/components/assistant_builder/AssistantBuilderTablesModal.tsx @@ -0,0 +1,217 @@ +import { + Button, + CloudArrowDownIcon, + Item, + Modal, + Page, + ServerIcon, +} from "@dust-tt/sparkle"; +import type { CoreAPITable, DataSourceType } from "@dust-tt/types"; +import type { WorkspaceType } from "@dust-tt/types"; +import { Transition } from "@headlessui/react"; +import * as React from "react"; +import { useState } from "react"; + +import type { AssistantBuilderTableConfiguration } from "@app/components/assistant_builder/AssistantBuilder"; +import { CONNECTOR_CONFIGURATIONS } from "@app/lib/connector_providers"; +import { useTables } from "@app/lib/swr"; + +export default function AssistantBuilderTablesModal({ + isOpen, + setOpen, + onSave, + owner, + dataSources, + tablesQueryConfiguration, +}: { + isOpen: boolean; + setOpen: (isOpen: boolean) => void; + onSave: (params: AssistantBuilderTableConfiguration) => void; + owner: WorkspaceType; + dataSources: DataSourceType[]; + tablesQueryConfiguration: Record; +}) { + const [selectedDataSource, setSelectedDataSource] = + useState(null); + + const [selectedTable, setSelectedTable] = + useState(null); + + const onClose = () => { + setOpen(false); + setTimeout(() => { + setSelectedDataSource(null); + setSelectedTable(null); + }, 200); + }; + + return ( + { + if (selectedTable) { + onSave(selectedTable); + } + }} + hasChanged={!!selectedTable} + variant="full-screen" + title="Select Tables" + > +
+ {!selectedDataSource ? ( + { + setSelectedDataSource(ds); + }} + /> + ) : ( + { + const config = { + workspaceId: owner.sId, + dataSourceId: table.data_source_id, + tableId: table.table_id, + }; + setSelectedTable(config); + onSave(config); + onClose(); + }} + onBack={() => { + setSelectedDataSource(null); + }} + tablesQueryConfiguration={tablesQueryConfiguration} + /> + )} +
+
+ ); +} + +function PickDataSource({ + dataSources, + onPick, +}: { + dataSources: DataSourceType[]; + onPick: (dataSource: DataSourceType) => void; +}) { + return ( + + + + + {dataSources + .sort( + (a, b) => + (b.connectorProvider ? 1 : 0) - (a.connectorProvider ? 1 : 0) + ) + .map((ds) => { + return ( + { + onPick(ds); + }} + /> + ); + })} + + + ); +} + +const PickTable = ({ + owner, + dataSource, + onPick, + onBack, + tablesQueryConfiguration, +}: { + owner: WorkspaceType; + dataSource: DataSourceType; + onPick: (table: CoreAPITable) => void; + onBack?: () => void; + tablesQueryConfiguration: Record; +}) => { + const { tables } = useTables({ + workspaceId: owner.sId, + dataSourceName: dataSource.name, + }); + + const tablesToDisplay = tables.filter( + (t) => + !tablesQueryConfiguration?.[ + `${owner.sId}/${dataSource.name}/${t.table_id}` + ] + ); + const isAllSelected = !!tables.length && !tablesToDisplay.length; + + return ( + + + + + {isAllSelected && ( +
+
+ All tables from this DataSource are already selected. +
+
+ )} + + {tables.length === 0 && ( +
+
+ No tables found in this Data Source. +
+
+ )} + + {!!tablesToDisplay.length && + tablesToDisplay + .sort((a, b) => (b.name ? 1 : 0) - (a.name ? 1 : 0)) + .map((table) => { + return ( + { + onPick(table); + }} + /> + ); + })} + +
+
+
+
+ ); +}; diff --git a/front/components/assistant_builder/TablesSelectionSection.tsx b/front/components/assistant_builder/TablesSelectionSection.tsx new file mode 100644 index 000000000000..e66207d0f2dd --- /dev/null +++ b/front/components/assistant_builder/TablesSelectionSection.tsx @@ -0,0 +1,99 @@ +import { + Button, + ContextItem, + PlusIcon, + ServerIcon, + TrashIcon, +} from "@dust-tt/sparkle"; +import { Transition } from "@headlessui/react"; + +import type { AssistantBuilderTableConfiguration } from "@app/components/assistant_builder/AssistantBuilder"; +import { EmptyCallToAction } from "@app/components/EmptyCallToAction"; + +export default function TablesSelectionSection({ + show, + tablesQueryConfiguration, + openTableModal, + onDelete, + canSelectTable, +}: { + show: boolean; + tablesQueryConfiguration: Record; + openTableModal: () => void; + onDelete?: (sId: string) => void; + canSelectTable: boolean; +}) { + return ( + { + window.scrollBy({ + left: 0, + top: 140, + behavior: "smooth", + }); + }} + > +
+
+
+ Select Tables: +
+
+ {Object.keys(tablesQueryConfiguration).length > 0 && ( +
+ {!Object.keys(tablesQueryConfiguration).length ? ( + + ) : ( + + {Object.values(tablesQueryConfiguration).map((t) => { + const tableKey = `${t.workspaceId}/${t.dataSourceId}/${t.tableId}`; + return ( + } + key={tableKey} + action={ + +
+ + ); +} diff --git a/front/lib/api/assistant/actions/database_query.ts b/front/lib/api/assistant/actions/tables_query.ts similarity index 58% rename from front/lib/api/assistant/actions/database_query.ts rename to front/lib/api/assistant/actions/tables_query.ts index 413b86b4699f..cd71e12208ce 100644 --- a/front/lib/api/assistant/actions/database_query.ts +++ b/front/lib/api/assistant/actions/tables_query.ts @@ -2,40 +2,41 @@ import type { AgentConfigurationType, AgentMessageType, ConversationType, - DatabaseQueryActionType, - DatabaseQueryErrorEvent, - DatabaseQueryOutputEvent, - DatabaseQueryParamsEvent, - DatabaseQuerySuccessEvent, + DustAppParameters, ModelMessageType, Result, + TablesQueryActionType, + TablesQueryErrorEvent, + TablesQueryOutputEvent, + TablesQueryParamsEvent, + TablesQuerySuccessEvent, UserMessageType, } from "@dust-tt/types"; import { cloneBaseConfig, DustProdActionRegistry, Err, - isDatabaseQueryConfiguration, + isTablesQueryConfiguration, Ok, } from "@dust-tt/types"; import { runActionStreamed } from "@app/lib/actions/server"; import { generateActionInputs } from "@app/lib/api/assistant/agent"; import type { Authenticator } from "@app/lib/auth"; -import { AgentDatabaseQueryAction } from "@app/lib/models"; +import { AgentTablesQueryAction } from "@app/lib/models/assistant/actions/tables_query"; import logger from "@app/logger/logger"; /** - * Model rendering of DatabaseQueryAction. + * Model rendering of TablesQueryAction. */ -export function renderDatabaseQueryActionForModel( - action: DatabaseQueryActionType +export function renderTablesQueryActionForModel( + action: TablesQueryActionType ): ModelMessageType { let content = ""; if (!action.output) { throw new Error( - "Output not set on DatabaseQuery action; execution is likely not finished." + "Output not set on TablesQuery action; execution is likely not finished." ); } content += `OUTPUT:\n`; @@ -43,18 +44,18 @@ export function renderDatabaseQueryActionForModel( return { role: "action" as const, - name: "DatabaseQuery", + name: "TablesQuery", content, }; } /** - * Generate the specification for the DatabaseQuery app. + * Generate the specification for the TablesQuery app. * This is the instruction given to the LLM to understand the task. */ -function getDatabaseQueryAppSpecification() { +function getTablesQueryAppSpecification() { return { - name: "query_database", + name: "query_Tables", description: "Generates a SQL query from a question in plain language, executes the generated query and return the results.", inputs: [ @@ -69,9 +70,9 @@ function getDatabaseQueryAppSpecification() { } /** - * Generate the parameters for the DatabaseQuery app. + * Generate the parameters for the TablesQuery app. */ -export async function generateDatabaseQueryAppParams( +export async function generateTablesQueryAppParams( auth: Authenticator, configuration: AgentConfigurationType, conversation: ConversationType, @@ -85,13 +86,13 @@ export async function generateDatabaseQueryAppParams( > > { const c = configuration.action; - if (!isDatabaseQueryConfiguration(c)) { + if (!isTablesQueryConfiguration(c)) { throw new Error( - "Unexpected action configuration received in `runQueryDatabase`" + "Unexpected action configuration received in `runQueryTables`" ); } - const spec = getDatabaseQueryAppSpecification(); + const spec = getTablesQueryAppSpecification(); const rawInputsRes = await generateActionInputs( auth, configuration, @@ -107,9 +108,9 @@ export async function generateDatabaseQueryAppParams( } /** - * Run the DatabaseQuery app. + * Run the TablesQuery app. */ -export async function* runDatabaseQuery({ +export async function* runTablesQuery({ auth, configuration, conversation, @@ -122,37 +123,25 @@ export async function* runDatabaseQuery({ userMessage: UserMessageType; agentMessage: AgentMessageType; }): AsyncGenerator< - | DatabaseQueryErrorEvent - | DatabaseQuerySuccessEvent - | DatabaseQueryParamsEvent - | DatabaseQueryOutputEvent + | TablesQueryErrorEvent + | TablesQuerySuccessEvent + | TablesQueryParamsEvent + | TablesQueryOutputEvent > { // Checking authorizations const owner = auth.workspace(); if (!owner) { - throw new Error("Unexpected unauthenticated call to `runQueryDatabase`"); + throw new Error("Unexpected unauthenticated call to `runQueryTables`"); } const c = configuration.action; - if (!isDatabaseQueryConfiguration(c)) { + if (!isTablesQueryConfiguration(c)) { throw new Error( - "Unexpected action configuration received in `runQueryDatabase`" + "Unexpected action configuration received in `runQueryTables`" ); } - if (owner.sId !== c.dataSourceWorkspaceId) { - yield { - type: "database_query_error", - created: Date.now(), - configurationId: configuration.sId, - messageId: agentMessage.sId, - error: { - code: "database_query_parameters_generation_error", - message: "Cannot access the database linked to this action.", - }, - }; - } // Generating inputs - const inputRes = await generateDatabaseQueryAppParams( + const inputRes = await generateTablesQueryAppParams( auth, configuration, conversation, @@ -160,13 +149,13 @@ export async function* runDatabaseQuery({ ); if (inputRes.isErr()) { yield { - type: "database_query_error", + type: "tables_query_error", created: Date.now(), configurationId: configuration.sId, messageId: agentMessage.sId, error: { - code: "database_query_parameters_generation_error", - message: `Error generating parameters for database_query: ${inputRes.error.message}`, + code: "tables_query_parameters_generation_error", + message: `Error generating parameters for tables_query: ${inputRes.error.message}`, }, }; return; @@ -175,66 +164,60 @@ export async function* runDatabaseQuery({ let output: Record = {}; // Creating action - const action = await AgentDatabaseQueryAction.create({ - dataSourceWorkspaceId: c.dataSourceWorkspaceId, - dataSourceId: c.dataSourceId, - databaseId: c.databaseId, - databaseQueryConfigurationId: configuration.sId, + const action = await AgentTablesQueryAction.create({ + tablesQueryConfigurationId: configuration.sId, params: input, output, }); yield { - type: "database_query_params", + type: "tables_query_params", created: Date.now(), configurationId: configuration.sId, messageId: agentMessage.sId, action: { id: action.id, - type: "database_query_action", - dataSourceWorkspaceId: action.dataSourceWorkspaceId, - dataSourceId: action.dataSourceId, - databaseId: action.databaseId, - params: action.params, - output: action.output, + type: "tables_query_action", + params: action.params as DustAppParameters, + output: action.output as Record, }, }; // Generating configuration const config = cloneBaseConfig( - DustProdActionRegistry["assistant-v2-query-database"].config + DustProdActionRegistry["assistant-v2-query-tables"].config ); - const database = { - workspace_id: c.dataSourceWorkspaceId, - data_source_id: c.dataSourceId, - database_id: c.databaseId, - }; + const tables = c.tables.map((t) => ({ + workspace_id: t.workspaceId, + table_id: t.tableId, + data_source_id: t.dataSourceId, + })); config.DATABASE_SCHEMA = { type: "database_schema", - database, + tables, }; config.DATABASE = { type: "database", - database, + tables, }; // Running the app const res = await runActionStreamed( auth, - "assistant-v2-query-database", + "assistant-v2-query-tables", config, [{ question: input.question }] ); if (res.isErr()) { yield { - type: "database_query_error", + type: "tables_query_error", created: Date.now(), configurationId: configuration.sId, messageId: agentMessage.sId, error: { - code: "database_query_error", - message: `Error running DatabaseQuery app: ${res.error.message}`, + code: "tables_query_error", + message: `Error running TablesQuery app: ${res.error.message}`, }, }; return; @@ -249,16 +232,16 @@ export async function* runDatabaseQuery({ conversationId: conversation.id, error: event.content.message, }, - "Error running query_database app" + "Error running query_tables app" ); yield { - type: "database_query_error", + type: "tables_query_error", created: Date.now(), configurationId: configuration.sId, messageId: agentMessage.sId, error: { - code: "database_query_error", - message: `Error running DatabaseQuery app: ${event.content.message}`, + code: "tables_query_error", + message: `Error running TablesQuery app: ${event.content.message}`, }, }; return; @@ -273,16 +256,16 @@ export async function* runDatabaseQuery({ conversationId: conversation.id, error: e.error, }, - "Error running query_database app" + "Error running query_tables app" ); yield { - type: "database_query_error", + type: "tables_query_error", created: Date.now(), configurationId: configuration.sId, messageId: agentMessage.sId, error: { - code: "database_query_error", - message: `Error executing DatabaseQuery app: ${e.error}`, + code: "tables_query_error", + message: `Error running TablesQuery app: ${e.error}`, }, }; return; @@ -296,19 +279,17 @@ export async function* runDatabaseQuery({ } else { tmpOutput = { no_query: true }; } + yield { - type: "database_query_output", + type: "tables_query_output", created: Date.now(), configurationId: configuration.sId, messageId: agentMessage.sId, action: { id: action.id, - type: "database_query_action", - dataSourceWorkspaceId: action.dataSourceWorkspaceId, - dataSourceId: action.dataSourceId, - databaseId: action.databaseId, - params: action.params, - output: tmpOutput, + type: "tables_query_action", + params: action.params as DustAppParameters, + output: tmpOutput as Record, }, }; } @@ -328,18 +309,15 @@ export async function* runDatabaseQuery({ }); yield { - type: "database_query_success", + type: "tables_query_success", created: Date.now(), configurationId: configuration.sId, messageId: agentMessage.sId, action: { id: action.id, - type: "database_query_action", - dataSourceWorkspaceId: action.dataSourceWorkspaceId, - dataSourceId: action.dataSourceId, - databaseId: action.databaseId, - params: action.params, - output: action.output, + type: "tables_query_action", + params: action.params as DustAppParameters, + output: action.output as Record, }, }; return; diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 5c423f138f95..d362cdddf02c 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -1,39 +1,37 @@ import type { AgentActionEvent, + AgentActionSpecification, AgentActionSuccessEvent, AgentConfigurationType, AgentErrorEvent, AgentGenerationCancelledEvent, AgentGenerationSuccessEvent, AgentMessageSuccessEvent, - DatabaseQueryParamsEvent, - GenerationTokensEvent, -} from "@dust-tt/types"; -import type { - AgentActionSpecification, - LightAgentConfigurationType, -} from "@dust-tt/types"; -import type { AgentMessageType, ConversationType, + GenerationTokensEvent, + LightAgentConfigurationType, + Result, + TablesQueryParamsEvent, UserMessageType, } from "@dust-tt/types"; -import type { Result } from "@dust-tt/types"; import { + cloneBaseConfig, + DustProdActionRegistry, + Err, GPT_3_5_TURBO_MODEL_CONFIG, GPT_4_32K_MODEL_CONFIG, GPT_4_MODEL_CONFIG, - isDatabaseQueryConfiguration, + isDustAppRunConfiguration, + isRetrievalConfiguration, + isTablesQueryConfiguration, + Ok, } from "@dust-tt/types"; -import { isDustAppRunConfiguration } from "@dust-tt/types"; -import { isRetrievalConfiguration } from "@dust-tt/types"; -import { cloneBaseConfig, DustProdActionRegistry } from "@dust-tt/types"; -import { Err, Ok } from "@dust-tt/types"; import { runActionStreamed } from "@app/lib/actions/server"; -import { runDatabaseQuery } from "@app/lib/api/assistant/actions/database_query"; import { runDustApp } from "@app/lib/api/assistant/actions/dust_app_run"; import { runRetrieval } from "@app/lib/api/assistant/actions/retrieval"; +import { runTablesQuery } from "@app/lib/api/assistant/actions/tables_query"; import { getAgentConfiguration } from "@app/lib/api/assistant/configuration"; import { constructPrompt, @@ -189,7 +187,7 @@ export async function* runAgent( | AgentGenerationSuccessEvent | AgentGenerationCancelledEvent | AgentMessageSuccessEvent - | DatabaseQueryParamsEvent, + | TablesQueryParamsEvent, void > { const fullConfiguration = await getAgentConfiguration( @@ -302,8 +300,8 @@ export async function* runAgent( return; } } - } else if (isDatabaseQueryConfiguration(fullConfiguration.action)) { - const eventStream = runDatabaseQuery({ + } else if (isTablesQueryConfiguration(fullConfiguration.action)) { + const eventStream = runTablesQuery({ auth, configuration: fullConfiguration, conversation, @@ -312,11 +310,11 @@ export async function* runAgent( }); for await (const event of eventStream) { switch (event.type) { - case "database_query_params": - case "database_query_output": + case "tables_query_params": + case "tables_query_output": yield event; break; - case "database_query_error": + case "tables_query_error": yield { type: "agent_error", created: event.created, @@ -328,7 +326,7 @@ export async function* runAgent( }, }; return; - case "database_query_success": + case "tables_query_success": yield { type: "agent_action_success", created: event.created, diff --git a/front/lib/api/assistant/configuration.ts b/front/lib/api/assistant/configuration.ts index 88d5398f1bca..0665c1ef631c 100644 --- a/front/lib/api/assistant/configuration.ts +++ b/front/lib/api/assistant/configuration.ts @@ -1,29 +1,27 @@ import type { + AgentActionConfigurationType, + AgentConfigurationScope, + AgentConfigurationType, + AgentGenerationConfigurationType, AgentMention, + AgentsGetViewType, + AgentStatus, AgentUserListStatus, + DataSourceConfiguration, LightAgentConfigurationType, Result, - SupportedModel, -} from "@dust-tt/types"; -import type { DustAppRunConfigurationType } from "@dust-tt/types"; -import type { - AgentsGetViewType, - DataSourceConfiguration, - RetrievalConfigurationType, RetrievalQuery, RetrievalTimeframe, + SupportedModel, } from "@dust-tt/types"; -import type { - AgentActionConfigurationType, - AgentConfigurationScope, - AgentConfigurationType, - AgentGenerationConfigurationType, - AgentStatus, +import { + assertNever, + Err, + isSupportedModel, + isTemplatedQuery, + isTimeFrame, + Ok, } from "@dust-tt/types"; -import type { DatabaseQueryConfigurationType } from "@dust-tt/types"; -import { assertNever, Err, Ok } from "@dust-tt/types"; -import { isTemplatedQuery, isTimeFrame } from "@dust-tt/types"; -import { isSupportedModel } from "@dust-tt/types"; import type { Transaction } from "sequelize"; import { Op, UniqueConstraintError } from "sequelize"; @@ -37,7 +35,6 @@ import type { Authenticator } from "@app/lib/auth"; import { front_sequelize } from "@app/lib/databases"; import { AgentConfiguration, - AgentDatabaseQueryConfiguration, AgentDataSourceConfiguration, AgentDustAppRunConfiguration, AgentGenerationConfiguration, @@ -48,6 +45,10 @@ import { Message, Workspace, } from "@app/lib/models"; +import { + AgentTablesQueryConfiguration, + AgentTablesQueryConfigurationTable, +} from "@app/lib/models/assistant/actions/tables_query"; import { AgentUserRelation } from "@app/lib/models/assistant/agent"; import { generateModelSId } from "@app/lib/utils"; @@ -229,15 +230,15 @@ export async function getAgentConfigurations({ const dustAppRunConfigIds = agentConfigurations .filter((a) => a.dustAppRunConfigurationId !== null) .map((a) => a.dustAppRunConfigurationId as number); - const databaseQueryConfigIds = agentConfigurations - .filter((a) => a.databaseQueryConfigurationId !== null) - .map((a) => a.databaseQueryConfigurationId as number); + const tablesQueryConfigurationsIds = agentConfigurations + .filter((a) => a.tablesQueryConfigurationId !== null) + .map((a) => a.tablesQueryConfigurationId as number); const [ generationConfigs, retrievalConfigs, dustAppRunConfigs, - databaseQueryConfigs, + tablesQueryConfigs, agentUserRelations, ] = await Promise.all([ generationConfigIds.length > 0 @@ -255,11 +256,11 @@ export async function getAgentConfigurations({ where: { id: { [Op.in]: dustAppRunConfigIds } }, }).then(byId) : Promise.resolve({} as Record), - databaseQueryConfigIds.length > 0 && variant === "full" - ? AgentDatabaseQueryConfiguration.findAll({ - where: { id: { [Op.in]: databaseQueryConfigIds } }, + tablesQueryConfigurationsIds.length > 0 && variant === "full" + ? AgentTablesQueryConfiguration.findAll({ + where: { id: { [Op.in]: tablesQueryConfigurationsIds } }, }).then(byId) - : Promise.resolve({} as Record), + : Promise.resolve({} as Record), user && configurationIds.length > 0 ? AgentUserRelation.findAll({ where: { @@ -275,9 +276,9 @@ export async function getAgentConfigurations({ : Promise.resolve({} as Record), ]); - const agentDatasourceConfigurations = ( + const agentDatasourceConfigurationsPromise = ( Object.values(retrievalConfigs).length - ? await AgentDataSourceConfiguration.findAll({ + ? AgentDataSourceConfiguration.findAll({ where: { retrievalConfigurationId: { [Op.in]: Object.values(retrievalConfigs).map((r) => r.id), @@ -296,21 +297,46 @@ export async function getAgentConfigurations({ }, ], }) - : [] - ).reduce((acc, dsConfig) => { - acc[dsConfig.retrievalConfigurationId] = - acc[dsConfig.retrievalConfigurationId] || []; - acc[dsConfig.retrievalConfigurationId].push(dsConfig); - return acc; - }, {} as Record); + : Promise.resolve([]) + ).then((dsConfigs) => + dsConfigs.reduce((acc, dsConfig) => { + acc[dsConfig.retrievalConfigurationId] = + acc[dsConfig.retrievalConfigurationId] || []; + acc[dsConfig.retrievalConfigurationId].push(dsConfig); + return acc; + }, {} as Record) + ); + + const agentTablesConfigurationTablesPromise = ( + Object.values(tablesQueryConfigs).length + ? AgentTablesQueryConfigurationTable.findAll({ + where: { + tablesQueryConfigurationId: { + [Op.in]: Object.values(tablesQueryConfigs).map((r) => r.id), + }, + }, + }) + : Promise.resolve([]) + ).then((tablesConfigurationTables) => + tablesConfigurationTables.reduce((acc, tablesConfigurationTable) => { + acc[tablesConfigurationTable.tablesQueryConfigurationId] = + acc[tablesConfigurationTable.tablesQueryConfigurationId] || []; + acc[tablesConfigurationTable.tablesQueryConfigurationId].push( + tablesConfigurationTable + ); + return acc; + }, {} as Record) + ); + + const [agentDatasourceConfigurations, agentTablesConfigurationTables] = + await Promise.all([ + agentDatasourceConfigurationsPromise, + agentTablesConfigurationTablesPromise, + ]); let agentConfigurationTypes: AgentConfigurationType[] = []; for (const agent of agentConfigurations) { - let action: - | RetrievalConfigurationType - | DustAppRunConfigurationType - | DatabaseQueryConfigurationType - | null = null; + let action: AgentActionConfigurationType | null = null; if (variant === "full") { if (agent.retrievalConfigurationId) { @@ -325,6 +351,7 @@ export async function getAgentConfigurations({ const dataSourcesConfig = agentDatasourceConfigurations[retrievalConfig.id] ?? []; + let topK: number | "auto" = "auto"; if (retrievalConfig.topKMode === "custom") { if (!retrievalConfig.topK) { @@ -378,23 +405,30 @@ export async function getAgentConfigurations({ appWorkspaceId: dustAppRunConfig.appWorkspaceId, appId: dustAppRunConfig.appId, }; - } else if (agent.databaseQueryConfigurationId) { - const databaseQueryConfig = - databaseQueryConfigs[agent.databaseQueryConfigurationId]; + } else if (agent.tablesQueryConfigurationId) { + const tablesQueryConfig = + tablesQueryConfigs[agent.tablesQueryConfigurationId]; - if (!databaseQueryConfig) { + if (!tablesQueryConfig) { throw new Error( - `Couldn't find action configuration for Database configuration ${agent.databaseQueryConfigurationId}}` + `Couldn't find action configuration for Tables configuration ${agent.tablesQueryConfigurationId}}` ); } + const tablesQueryConfigTables = + agentTablesConfigurationTables[tablesQueryConfig.id] ?? []; + action = { - id: databaseQueryConfig.id, - sId: databaseQueryConfig.sId, - type: "database_query_configuration", - dataSourceWorkspaceId: databaseQueryConfig.dataSourceWorkspaceId, - dataSourceId: databaseQueryConfig.dataSourceId, - databaseId: databaseQueryConfig.databaseId, + id: tablesQueryConfig.id, + sId: tablesQueryConfig.sId, + type: "tables_query_configuration", + tables: tablesQueryConfigTables.map((tablesQueryConfigTable) => { + return { + dataSourceId: tablesQueryConfigTable.dataSourceId, + workspaceId: tablesQueryConfigTable.dataSourceWorkspaceId, + tableId: tablesQueryConfigTable.tableId, + }; + }), }; } } @@ -656,10 +690,8 @@ export async function createAgentConfiguration( action?.type === "retrieval_configuration" ? action?.id : null, dustAppRunConfigurationId: action?.type === "dust_app_run_configuration" ? action?.id : null, - databaseQueryConfigurationId: - action?.type === "database_query_configuration" - ? action?.id - : null, + tablesQueryConfigurationId: + action?.type === "tables_query_configuration" ? action?.id : null, }, { transaction: t, @@ -793,10 +825,12 @@ export async function createAgentActionConfiguration( appId: string; } | { - type: "database_query_configuration"; - dataSourceWorkspaceId: string; - dataSourceId: string; - databaseId: string; + type: "tables_query_configuration"; + tables: Array<{ + workspaceId: string; + dataSourceId: string; + tableId: string; + }>; } ): Promise { const owner = auth.workspace(); @@ -857,21 +891,35 @@ export async function createAgentActionConfiguration( appWorkspaceId: action.appWorkspaceId, appId: action.appId, }; - } else if (action.type === "database_query_configuration") { - const databaseQueryConfig = await AgentDatabaseQueryConfiguration.create({ - sId: generateModelSId(), - dataSourceWorkspaceId: action.dataSourceWorkspaceId, - dataSourceId: action.dataSourceId, - databaseId: action.databaseId, + } else if (action.type === "tables_query_configuration") { + return front_sequelize.transaction(async (t) => { + const tablesQueryConfig = await AgentTablesQueryConfiguration.create( + { + sId: generateModelSId(), + }, + { transaction: t } + ); + await Promise.all( + action.tables.map((table) => + AgentTablesQueryConfigurationTable.create( + { + tablesQueryConfigurationId: tablesQueryConfig.id, + dataSourceId: table.dataSourceId, + dataSourceWorkspaceId: table.workspaceId, + tableId: table.tableId, + }, + { transaction: t } + ) + ) + ); + + return { + id: tablesQueryConfig.id, + sId: tablesQueryConfig.sId, + type: "tables_query_configuration", + tables: action.tables, + }; }); - return { - id: databaseQueryConfig.id, - sId: databaseQueryConfig.sId, - type: "database_query_configuration", - dataSourceWorkspaceId: action.dataSourceWorkspaceId, - dataSourceId: action.dataSourceId, - databaseId: action.databaseId, - }; } else { throw new Error("Cannot create AgentActionConfiguration: unknow type"); } diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index ba9261f8758c..5af26befc807 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -1,6 +1,7 @@ import type { AgentMessageNewEvent, ConversationTitleEvent, + DustAppParameters, GenerationTokensEvent, UserMessageErrorEvent, UserMessageNewEvent, @@ -52,7 +53,6 @@ import { renderConversationForModel } from "@app/lib/api/assistant/generation"; import type { Authenticator } from "@app/lib/auth"; import { front_sequelize } from "@app/lib/databases"; import { - AgentDatabaseQueryAction, AgentDustAppRunAction, AgentMessage, Conversation, @@ -62,6 +62,7 @@ import { User, UserMessage, } from "@app/lib/models"; +import { AgentTablesQueryAction } from "@app/lib/models/assistant/actions/tables_query"; import { ContentFragment } from "@app/lib/models/assistant/conversation"; import { updateWorkspacePerMonthlyActiveUsersSubscriptionUsage } from "@app/lib/plans/subscription"; import { generateModelSId } from "@app/lib/utils"; @@ -290,7 +291,7 @@ async function batchRenderAgentMessages( agentConfigurations, agentRetrievalActions, agentDustAppRunActions, - agentDatabaseQueryActions, + agentTablesQueryActions, ] = await Promise.all([ (async () => { const agentConfigurationIds: string[] = agentMessages.reduce( @@ -347,24 +348,21 @@ async function batchRenderAgentMessages( }); })(), (async () => { - const actions = await AgentDatabaseQueryAction.findAll({ + const actions = await AgentTablesQueryAction.findAll({ where: { id: { [Op.in]: agentMessages - .filter((m) => m.agentMessage?.agentDatabaseQueryActionId) - .map((m) => m.agentMessage?.agentDatabaseQueryActionId as number), + .filter((m) => m.agentMessage?.agentTablesQueryActionId) + .map((m) => m.agentMessage?.agentTablesQueryActionId as number), }, }, }); return actions.map((action) => { return { id: action.id, - type: "database_query_action", - dataSourceWorkspaceId: action.dataSourceWorkspaceId, - dataSourceId: action.dataSourceId, - databaseId: action.databaseId, - params: action.params, - output: action.output, + type: "tables_query_action", + params: action.params as DustAppParameters, + output: action.output as Record, }; }); })(), @@ -387,12 +385,11 @@ async function batchRenderAgentMessages( action = agentDustAppRunActions.find( (a) => a.id === agentMessage.agentDustAppRunActionId ); - } else if (agentMessage.agentDatabaseQueryActionId) { - action = agentDatabaseQueryActions.find( - (a) => a.id === agentMessage.agentDatabaseQueryActionId + } else if (agentMessage.agentTablesQueryActionId) { + action = agentTablesQueryActions.find( + (a) => a.id === agentMessage.agentTablesQueryActionId ); } - const agentConfiguration = agentConfigurations.find( (a) => a.sId === message.agentMessage?.agentConfigurationId ); @@ -1941,9 +1938,9 @@ async function* streamRunAgentEvents( await agentMessageRow.update({ agentDustAppRunActionId: event.action.id, }); - } else if (event.action.type === "database_query_action") { + } else if (event.action.type === "tables_query_action") { await agentMessageRow.update({ - agentDatabaseQueryActionId: event.action.id, + agentTablesQueryActionId: event.action.id, }); } else { ((action: never) => { @@ -1987,8 +1984,8 @@ async function* streamRunAgentEvents( case "retrieval_params": case "dust_app_run_params": case "dust_app_run_block": - case "database_query_params": - case "database_query_output": + case "tables_query_params": + case "tables_query_output": yield event; break; case "generation_tokens": diff --git a/front/lib/api/assistant/generation.ts b/front/lib/api/assistant/generation.ts index d1343f39fe1f..953f8b5f345c 100644 --- a/front/lib/api/assistant/generation.ts +++ b/front/lib/api/assistant/generation.ts @@ -1,45 +1,42 @@ import type { AgentConfigurationType, + AgentMessageType, + ConversationType, GenerationCancelEvent, GenerationErrorEvent, GenerationSuccessEvent, GenerationTokensEvent, ModelConversationType, ModelMessageType, -} from "@dust-tt/types"; -import type { - AgentMessageType, - ConversationType, + Result, UserMessageType, } from "@dust-tt/types"; -import type { Result } from "@dust-tt/types"; import { + assertNever, + cloneBaseConfig, + CoreAPI, + DustProdActionRegistry, + Err, GPT_4_32K_MODEL_ID, GPT_4_MODEL_CONFIG, - isDatabaseQueryActionType, + isAgentMessageType, + isContentFragmentType, isDustAppRunActionType, -} from "@dust-tt/types"; -import { isRetrievalActionType, isRetrievalConfiguration, -} from "@dust-tt/types"; -import { - isAgentMessageType, - isContentFragmentType, + isTablesQueryActionType, isUserMessageType, + Ok, } from "@dust-tt/types"; -import { cloneBaseConfig, DustProdActionRegistry } from "@dust-tt/types"; -import { CoreAPI } from "@dust-tt/types"; -import { Err, Ok } from "@dust-tt/types"; import moment from "moment-timezone"; import { runActionStreamed } from "@app/lib/actions/server"; -import { renderDatabaseQueryActionForModel } from "@app/lib/api/assistant/actions/database_query"; import { renderDustAppRunActionForModel } from "@app/lib/api/assistant/actions/dust_app_run"; import { renderRetrievalActionForModel, retrievalMetaPrompt, } from "@app/lib/api/assistant/actions/retrieval"; +import { renderTablesQueryActionForModel } from "@app/lib/api/assistant/actions/tables_query"; import { getAgentConfigurations } from "@app/lib/api/assistant/configuration"; import { getSupportedModelConfig, isLargeModel } from "@app/lib/assistant"; import type { Authenticator } from "@app/lib/auth"; @@ -87,14 +84,10 @@ export async function renderConversationForModel({ } } else if (isDustAppRunActionType(m.action)) { messages.unshift(renderDustAppRunActionForModel(m.action)); - } else if (isDatabaseQueryActionType(m.action)) { - messages.unshift(renderDatabaseQueryActionForModel(m.action)); + } else if (isTablesQueryActionType(m.action)) { + messages.unshift(renderTablesQueryActionForModel(m.action)); } else { - return new Err( - new Error( - "Unsupported action type during conversation model rendering" - ) - ); + assertNever(m.action); } } if (m.content) { diff --git a/front/lib/api/assistant/pubsub.ts b/front/lib/api/assistant/pubsub.ts index dc3c5fbd32cd..b3abaa721604 100644 --- a/front/lib/api/assistant/pubsub.ts +++ b/front/lib/api/assistant/pubsub.ts @@ -161,8 +161,8 @@ async function handleUserMessageEvents( case "retrieval_params": case "dust_app_run_params": case "dust_app_run_block": - case "database_query_params": - case "database_query_output": + case "tables_query_params": + case "tables_query_output": case "agent_error": case "agent_action_success": case "generation_tokens": @@ -286,8 +286,8 @@ export async function retryAgentMessageWithPubSub( case "retrieval_params": case "dust_app_run_params": case "dust_app_run_block": - case "database_query_params": - case "database_query_output": + case "tables_query_params": + case "tables_query_output": case "agent_error": case "agent_action_success": case "generation_tokens": diff --git a/front/lib/api/assistant/user_relation.ts b/front/lib/api/assistant/user_relation.ts index 1e049b179bcd..d3408a41257c 100644 --- a/front/lib/api/assistant/user_relation.ts +++ b/front/lib/api/assistant/user_relation.ts @@ -40,6 +40,7 @@ export async function getAgentUserListStatus({ agentId: string; }): Promise> { const agentConfiguration = await getAgentConfiguration(auth, agentId); + if (!agentConfiguration) return new Err(new Error(`Could not find agent configuration ${agentId}`)); diff --git a/front/lib/models/assistant/actions/database_query.ts b/front/lib/models/assistant/actions/database_query.ts deleted file mode 100644 index e13d67cedd6d..000000000000 --- a/front/lib/models/assistant/actions/database_query.ts +++ /dev/null @@ -1,141 +0,0 @@ -import type { - CreationOptional, - InferAttributes, - InferCreationAttributes, -} from "sequelize"; -import { DataTypes, Model } from "sequelize"; - -import { front_sequelize } from "@app/lib/databases"; - -/** - * Agent Database Query Configuration - */ -export class AgentDatabaseQueryConfiguration extends Model< - InferAttributes, - InferCreationAttributes -> { - declare id: CreationOptional; - declare createdAt: CreationOptional; - declare updatedAt: CreationOptional; - - declare sId: string; - - declare dataSourceWorkspaceId: string; - declare dataSourceId: string; - declare databaseId: string; -} - -AgentDatabaseQueryConfiguration.init( - { - id: { - type: DataTypes.INTEGER, - autoIncrement: true, - primaryKey: true, - }, - createdAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - updatedAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - sId: { - type: DataTypes.STRING, - allowNull: false, - }, - dataSourceWorkspaceId: { - type: DataTypes.STRING, - allowNull: false, - }, - dataSourceId: { - type: DataTypes.STRING, - allowNull: false, - }, - databaseId: { - type: DataTypes.STRING, - allowNull: false, - }, - }, - { - modelName: "agent_database_query_configuration", - indexes: [ - { - unique: true, - fields: ["sId"], - }, - ], - sequelize: front_sequelize, - } -); - -/** - * Agent Database Query Action - */ -export class AgentDatabaseQueryAction extends Model< - InferAttributes, - InferCreationAttributes -> { - declare id: CreationOptional; - declare createdAt: CreationOptional; - declare updatedAt: CreationOptional; - - declare databaseQueryConfigurationId: string; - - declare dataSourceWorkspaceId: string; - declare dataSourceId: string; - declare databaseId: string; - - declare params: unknown | null; - declare output: unknown | null; -} - -AgentDatabaseQueryAction.init( - { - id: { - type: DataTypes.INTEGER, - autoIncrement: true, - primaryKey: true, - }, - createdAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - updatedAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - databaseQueryConfigurationId: { - type: DataTypes.STRING, - allowNull: false, - }, - dataSourceWorkspaceId: { - type: DataTypes.STRING, - allowNull: false, - }, - dataSourceId: { - type: DataTypes.STRING, - allowNull: false, - }, - databaseId: { - type: DataTypes.STRING, - allowNull: false, - }, - params: { - type: DataTypes.JSONB, - allowNull: true, - }, - output: { - type: DataTypes.JSONB, - allowNull: true, - }, - }, - { - modelName: "agent_database_query_action", - sequelize: front_sequelize, - } -); diff --git a/front/lib/models/assistant/actions/tables_query.ts b/front/lib/models/assistant/actions/tables_query.ts new file mode 100644 index 000000000000..d15cc35e13f6 --- /dev/null +++ b/front/lib/models/assistant/actions/tables_query.ts @@ -0,0 +1,182 @@ +import type { + CreationOptional, + ForeignKey, + InferAttributes, + InferCreationAttributes, +} from "sequelize"; +import { DataTypes, Model } from "sequelize"; + +import { front_sequelize } from "@app/lib/databases"; + +export class AgentTablesQueryConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare sId: string; +} + +AgentTablesQueryConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + sId: { + type: DataTypes.STRING, + allowNull: false, + }, + }, + { + modelName: "agent_tables_query_configuration", + indexes: [ + { + unique: true, + fields: ["sId"], + name: "agent_tables_query_configuration_s_id", + }, + ], + sequelize: front_sequelize, + } +); + +export class AgentTablesQueryConfigurationTable extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare dataSourceWorkspaceId: string; + declare dataSourceId: string; + declare tableId: string; + + declare tablesQueryConfigurationId: ForeignKey< + AgentTablesQueryConfiguration["id"] + >; +} + +AgentTablesQueryConfigurationTable.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + + dataSourceWorkspaceId: { + type: DataTypes.STRING, + allowNull: false, + }, + dataSourceId: { + type: DataTypes.STRING, + allowNull: false, + }, + tableId: { + type: DataTypes.STRING, + allowNull: false, + }, + }, + { + modelName: "agent_tables_query_configuration_table", + indexes: [ + { + unique: true, + fields: [ + "dataSourceWorkspaceId", + "dataSourceId", + "tableId", + "tablesQueryConfigurationId", + ], + name: "agent_tables_query_configuration_table_unique", + }, + ], + sequelize: front_sequelize, + } +); + +AgentTablesQueryConfiguration.hasMany(AgentTablesQueryConfigurationTable, { + foreignKey: { name: "tablesQueryConfigurationId", allowNull: false }, + onDelete: "CASCADE", +}); +AgentTablesQueryConfigurationTable.belongsTo(AgentTablesQueryConfiguration, { + foreignKey: { name: "tablesQueryConfigurationId", allowNull: false }, + onDelete: "CASCADE", +}); + +export class AgentTablesQueryAction extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare tablesQueryConfigurationId: string; + + declare params: unknown | null; + declare output: unknown | null; +} + +AgentTablesQueryAction.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + + tablesQueryConfigurationId: { + type: DataTypes.STRING, + allowNull: false, + }, + + params: { + type: DataTypes.JSONB, + allowNull: true, + }, + output: { + type: DataTypes.JSONB, + allowNull: true, + }, + }, + { + modelName: "agent_tables_query_action", + sequelize: front_sequelize, + } +); diff --git a/front/lib/models/assistant/agent.ts b/front/lib/models/assistant/agent.ts index e74bcd7f2f37..2ff496f9587c 100644 --- a/front/lib/models/assistant/agent.ts +++ b/front/lib/models/assistant/agent.ts @@ -17,9 +17,9 @@ import type { import { DataTypes, Model } from "sequelize"; import { front_sequelize } from "@app/lib/databases"; -import { AgentDatabaseQueryConfiguration } from "@app/lib/models/assistant/actions/database_query"; import { AgentDustAppRunConfiguration } from "@app/lib/models/assistant/actions/dust_app_run"; import { AgentRetrievalConfiguration } from "@app/lib/models/assistant/actions/retrieval"; +import { AgentTablesQueryConfiguration } from "@app/lib/models/assistant/actions/tables_query"; import { User } from "@app/lib/models/user"; import { Workspace } from "@app/lib/models/workspace"; @@ -102,6 +102,7 @@ export class AgentConfiguration extends Model< declare workspaceId: ForeignKey; declare authorId: ForeignKey; + declare generationConfigurationId: ForeignKey< AgentGenerationConfiguration["id"] > | null; @@ -112,15 +113,15 @@ export class AgentConfiguration extends Model< AgentDustAppRunConfiguration["id"] > | null; - declare databaseQueryConfigurationId: ForeignKey< - AgentDatabaseQueryConfiguration["id"] + declare tablesQueryConfigurationId: ForeignKey< + AgentTablesQueryConfiguration["id"] > | null; declare author: NonAttribute; declare generationConfiguration: NonAttribute; declare retrievalConfiguration: NonAttribute; declare dustAppRunConfiguration: NonAttribute; - declare databaseQueryConfiguration: NonAttribute; + declare tablesQueryConfiguration: NonAttribute; } AgentConfiguration.init( { @@ -195,14 +196,14 @@ AgentConfiguration.init( const actionsTypes: (keyof AgentConfiguration)[] = [ "retrievalConfigurationId", "dustAppRunConfigurationId", - "databaseQueryConfigurationId", + "tablesQueryConfigurationId", ]; const nonNullActionTypes = actionsTypes.filter( (field) => agentConfiguration[field] != null ); if (nonNullActionTypes.length > 1) { throw new Error( - "Only one of retrievalConfigurationId, dustAppRunConfigurationId, or databaseQueryConfigurationId can be set" + "Only one of retrievalConfigurationId, dustAppRunConfigurationId, tablesQueryConfigurationId can be set" ); } }, @@ -249,14 +250,14 @@ AgentConfiguration.belongsTo(AgentDustAppRunConfiguration, { foreignKey: { name: "dustAppRunConfigurationId", allowNull: true }, // null = no DutsAppRun action set for this Agent }); -// Agent config <> Database config -AgentDatabaseQueryConfiguration.hasOne(AgentConfiguration, { - as: "databaseQueryConfiguration", - foreignKey: { name: "databaseQueryConfigurationId", allowNull: true }, // null = no Database action set for this Agent +// Agent config <> Tables config +AgentTablesQueryConfiguration.hasOne(AgentConfiguration, { + as: "tablesQueryConfiguration", + foreignKey: { name: "tablesQueryConfigurationId", allowNull: true }, // null = no Tables action set for this Agent }); -AgentConfiguration.belongsTo(AgentDatabaseQueryConfiguration, { - as: "databaseQueryConfiguration", - foreignKey: { name: "databaseQueryConfigurationId", allowNull: true }, // null = no Database action set for this Agent +AgentConfiguration.belongsTo(AgentTablesQueryConfiguration, { + as: "tablesQueryConfiguration", + foreignKey: { name: "tablesQueryConfigurationId", allowNull: true }, // null = no Tables action set for this Agent }); // Agent config <> Author diff --git a/front/lib/models/assistant/conversation.ts b/front/lib/models/assistant/conversation.ts index 896efa880efd..e8bd0f58382f 100644 --- a/front/lib/models/assistant/conversation.ts +++ b/front/lib/models/assistant/conversation.ts @@ -15,9 +15,9 @@ import type { import { DataTypes, Model } from "sequelize"; import { front_sequelize } from "@app/lib/databases"; -import { AgentDatabaseQueryAction } from "@app/lib/models/assistant/actions/database_query"; import { AgentDustAppRunAction } from "@app/lib/models/assistant/actions/dust_app_run"; import { AgentRetrievalAction } from "@app/lib/models/assistant/actions/retrieval"; +import { AgentTablesQueryAction } from "@app/lib/models/assistant/actions/tables_query"; import { User } from "@app/lib/models/user"; import { Workspace } from "@app/lib/models/workspace"; @@ -248,8 +248,8 @@ export class AgentMessage extends Model< AgentDustAppRunAction["id"] > | null; - declare agentDatabaseQueryActionId: ForeignKey< - AgentDatabaseQueryAction["id"] + declare agentTablesQueryActionId: ForeignKey< + AgentTablesQueryAction["id"] > | null; // Not a relation as global agents are not in the DB @@ -316,14 +316,14 @@ AgentMessage.init( const actionsTypes: (keyof AgentMessage)[] = [ "agentRetrievalActionId", "agentDustAppRunActionId", - "agentDatabaseQueryActionId", + "agentTablesQueryActionId", ]; const nonNullActionTypes = actionsTypes.filter( (field) => agentMessage[field] != null ); if (nonNullActionTypes.length > 1) { throw new Error( - "Only one of agentRetrievalActionId, agentDustAppRunActionId, or agentDatabaseQueryActionId can be set" + "Only one of agentRetrievalActionId, agentDustAppRunActionId or agentTablesQueryActionId can be set" ); } }, @@ -346,12 +346,12 @@ AgentDustAppRunAction.hasOne(AgentMessage, { AgentMessage.belongsTo(AgentDustAppRunAction, { foreignKey: { name: "agentDustAppRunActionId", allowNull: true }, // null = no DustAppRun action set for this Agent }); -AgentDatabaseQueryAction.hasOne(AgentMessage, { - foreignKey: { name: "agentDatabaseQueryActionId", allowNull: true }, // null = no Database Query action set for this Agent +AgentTablesQueryAction.hasOne(AgentMessage, { + foreignKey: { name: "agentTablesQueryActionId", allowNull: true }, // null = no Tables Query action set for this Agent onDelete: "CASCADE", }); -AgentMessage.belongsTo(AgentDatabaseQueryAction, { - foreignKey: { name: "agentDatabaseQueryActionId", allowNull: true }, // null = no Database Query action set for this Agent +AgentMessage.belongsTo(AgentTablesQueryAction, { + foreignKey: { name: "agentTablesQueryActionId", allowNull: true }, // null = no Tables Query action set for this Agent }); export class ContentFragment extends Model< diff --git a/front/lib/models/index.ts b/front/lib/models/index.ts index 394344c7aeb2..74f305f59bc0 100644 --- a/front/lib/models/index.ts +++ b/front/lib/models/index.ts @@ -1,8 +1,4 @@ import { App, Clone, Dataset, Provider, Run } from "@app/lib/models/apps"; -import { - AgentDatabaseQueryAction, - AgentDatabaseQueryConfiguration, -} from "@app/lib/models/assistant/actions/database_query"; import { AgentDustAppRunAction, AgentDustAppRunConfiguration, @@ -14,6 +10,10 @@ import { RetrievalDocument, RetrievalDocumentChunk, } from "@app/lib/models/assistant/actions/retrieval"; +import { + AgentTablesQueryAction, + AgentTablesQueryConfiguration, +} from "@app/lib/models/assistant/actions/tables_query"; import { AgentConfiguration, AgentGenerationConfiguration, @@ -47,8 +47,6 @@ import { XP1Run, XP1User } from "@app/lib/models/xp1"; export { AgentConfiguration, - AgentDatabaseQueryAction, - AgentDatabaseQueryConfiguration, AgentDataSourceConfiguration, AgentDustAppRunAction, AgentDustAppRunConfiguration, @@ -56,6 +54,8 @@ export { AgentMessage, AgentRetrievalAction, AgentRetrievalConfiguration, + AgentTablesQueryAction, + AgentTablesQueryConfiguration, App, Clone, Conversation, diff --git a/front/lib/swr.ts b/front/lib/swr.ts index a91c0201dca0..ba132d7e920c 100644 --- a/front/lib/swr.ts +++ b/front/lib/swr.ts @@ -1,16 +1,14 @@ import type { AgentConfigurationType, AgentsGetViewType, - DataSourceType, -} from "@dust-tt/types"; -import type { WorkspaceType } from "@dust-tt/types"; -import type { + AppType, + ConnectorPermission, ConversationMessageReactions, ConversationType, + DataSourceType, + RunRunType, + WorkspaceType, } from "@dust-tt/types"; -import type { AppType } from "@dust-tt/types"; -import type { RunRunType } from "@dust-tt/types"; -import type { ConnectorPermission } from "@dust-tt/types"; import type { Fetcher } from "swr"; import useSWR from "swr"; @@ -18,6 +16,7 @@ import type { GetPokePlansResponseBody } from "@app/pages/api/poke/plans"; import type { GetWorkspacesResponseBody } from "@app/pages/api/poke/workspaces"; import type { GetUserMetadataResponseBody } from "@app/pages/api/user/metadata/[key]"; import type { ListTablesResponseBody } from "@app/pages/api/v1/w/[wId]/data_sources/[name]/tables"; +import type { GetTableResponseBody } from "@app/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]"; import type { GetDatasetsResponseBody } from "@app/pages/api/w/[wId]/apps/[aId]/datasets"; import type { GetRunsResponseBody } from "@app/pages/api/w/[wId]/apps/[aId]/runs"; import type { GetRunBlockResponseBody } from "@app/pages/api/w/[wId]/apps/[aId]/runs/[runId]/blocks/[type]/[name]"; @@ -584,6 +583,32 @@ export function useTables({ }; } +export function useTable({ + workspaceId, + dataSourceName, + tableId, +}: { + workspaceId: string; + dataSourceName: string; + tableId: string; +}) { + const tableFetcher: Fetcher = fetcher; + + const { data, error, mutate } = useSWR( + dataSourceName + ? `/api/w/${workspaceId}/data_sources/${dataSourceName}/tables/${tableId}` + : null, + tableFetcher + ); + + return { + table: data ? data.table : null, + isTableLoading: !error && !data, + isTableError: error, + mutateTable: mutate, + }; +} + export function useApp({ workspaceId, appId, diff --git a/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/index.ts b/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/index.ts index f9c616a5f038..bd38ee0c8abd 100644 --- a/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/index.ts +++ b/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/index.ts @@ -8,7 +8,7 @@ import { isActivatedStructuredDB } from "@app/lib/development"; import logger from "@app/logger/logger"; import { apiError, withLogging } from "@app/logger/withlogging"; -type GetTableResponseBody = { +export type GetTableResponseBody = { table: CoreAPITable; }; diff --git a/front/pages/api/w/[wId]/assistant/agent_configurations/index.ts b/front/pages/api/w/[wId]/assistant/agent_configurations/index.ts index 4688de27aac7..9aa89bcecc77 100644 --- a/front/pages/api/w/[wId]/assistant/agent_configurations/index.ts +++ b/front/pages/api/w/[wId]/assistant/agent_configurations/index.ts @@ -268,12 +268,10 @@ export async function createOrUpgradeAgentConfiguration( appId: action.appId, }); } - if (action && action.type === "database_query_configuration") { + if (action && action.type === "tables_query_configuration") { actionConfig = await createAgentActionConfiguration(auth, { - type: "database_query_configuration", - dataSourceWorkspaceId: action.dataSourceWorkspaceId, - dataSourceId: action.dataSourceId, - databaseId: action.databaseId, + type: "tables_query_configuration", + tables: action.tables, }); } diff --git a/front/pages/api/w/[wId]/data_sources/[name]/tables/[tId]/index.ts b/front/pages/api/w/[wId]/data_sources/[name]/tables/[tId]/index.ts new file mode 100644 index 000000000000..c631e1ec6c8e --- /dev/null +++ b/front/pages/api/w/[wId]/data_sources/[name]/tables/[tId]/index.ts @@ -0,0 +1,125 @@ +import type { CoreAPITable } from "@dust-tt/types"; +import { CoreAPI } from "@dust-tt/types"; +import type { NextApiRequest, NextApiResponse } from "next"; + +import { getDataSource } from "@app/lib/api/data_sources"; +import { Authenticator, getSession } from "@app/lib/auth"; +import { isActivatedStructuredDB } from "@app/lib/development"; +import logger from "@app/logger/logger"; +import { apiError, withLogging } from "@app/logger/withlogging"; + +export type GetTableResponseBody = { + table: CoreAPITable; +}; + +async function handler( + req: NextApiRequest, + res: NextApiResponse +): Promise { + const session = await getSession(req, res); + const auth = await Authenticator.fromSession( + session, + req.query.wId as string + ); + + const owner = auth.workspace(); + if (!owner || !auth.isBuilder()) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "data_source_not_found", + message: "The data source you requested was not found.", + }, + }); + } + + const plan = auth.plan(); + if (!owner || !plan) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "workspace_not_found", + message: "The workspace you requested was not found.", + }, + }); + } + + if (!isActivatedStructuredDB(owner)) { + res.status(404).end(); + return; + } + + if (!req.query.name || typeof req.query.name !== "string") { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "data_source_not_found", + message: "The data source you requested was not found.", + }, + }); + } + + const dataSource = await getDataSource(auth, req.query.name); + if (!dataSource) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "data_source_not_found", + message: "The data source you requested was not found.", + }, + }); + } + + const tableId = req.query.tId; + if (!tableId || typeof tableId !== "string") { + return apiError(req, res, { + status_code: 400, + api_error: { + type: "invalid_request_error", + message: "The table id is missing.", + }, + }); + } + + switch (req.method) { + case "GET": + const coreAPI = new CoreAPI(logger); + const tableRes = await coreAPI.getTable({ + projectId: dataSource.dustAPIProjectId, + dataSourceName: dataSource.name, + tableId, + }); + if (tableRes.isErr()) { + logger.error( + { + dataSourcename: dataSource.name, + workspaceId: owner.id, + error: tableRes.error, + }, + "Failed to get table." + ); + return apiError(req, res, { + status_code: 500, + api_error: { + type: "internal_server_error", + message: "Failed to get table.", + }, + }); + } + + const { table } = tableRes.value; + + return res.status(200).json({ table }); + + default: + return apiError(req, res, { + status_code: 405, + api_error: { + type: "method_not_supported_error", + message: "The method passed is not supported, GET is expected.", + }, + }); + } +} + +export default withLogging(handler); diff --git a/front/pages/w/[wId]/builder/assistants/[aId]/index.tsx b/front/pages/w/[wId]/builder/assistants/[aId]/index.tsx index 10391e2b929f..69e64f56829b 100644 --- a/front/pages/w/[wId]/builder/assistants/[aId]/index.tsx +++ b/front/pages/w/[wId]/builder/assistants/[aId]/index.tsx @@ -1,15 +1,18 @@ import type { AgentConfigurationType, + AppType, DataSourceType, + PlanType, + SubscriptionType, UserType, WorkspaceType, } from "@dust-tt/types"; -import type { AppType } from "@dust-tt/types"; -import type { PlanType, SubscriptionType } from "@dust-tt/types"; -import { isDatabaseQueryConfiguration } from "@dust-tt/types"; -import { isDustAppRunConfiguration } from "@dust-tt/types"; -import { isRetrievalConfiguration } from "@dust-tt/types"; -import { ConnectorsAPI } from "@dust-tt/types"; +import { + ConnectorsAPI, + isDustAppRunConfiguration, + isRetrievalConfiguration, + isTablesQueryConfiguration, +} from "@dust-tt/types"; import type { GetServerSideProps, InferGetServerSidePropsType } from "next"; import type { @@ -110,13 +113,30 @@ export async function buildInitialState({ } } - const databaseQueryConfiguration: AssistantBuilderInitialState["databaseQueryConfiguration"] = - null; + let tablesQueryConfiguration: AssistantBuilderInitialState["tablesQueryConfiguration"] = + {}; + + if ( + isTablesQueryConfiguration(config.action) && + config.action.tables.length + ) { + tablesQueryConfiguration = config.action.tables.reduce( + (acc, curr) => ({ + ...acc, + [`${curr.workspaceId}/${curr.dataSourceId}/${curr.tableId}`]: { + workspaceId: curr.workspaceId, + dataSourceId: curr.dataSourceId, + tableId: curr.tableId, + }, + }), + {} as AssistantBuilderInitialState["tablesQueryConfiguration"] + ); + } return { dataSourceConfigurations, dustAppConfiguration, - databaseQueryConfiguration, + tablesQueryConfiguration, }; } @@ -130,7 +150,7 @@ export const getServerSideProps: GetServerSideProps<{ dataSourceConfigurations: Record; dustApps: AppType[]; dustAppConfiguration: AssistantBuilderInitialState["dustAppConfiguration"]; - databaseQueryConfiguration: AssistantBuilderInitialState["databaseQueryConfiguration"]; + tablesQueryConfiguration: AssistantBuilderInitialState["tablesQueryConfiguration"]; agentConfiguration: AgentConfigurationType; flow: BuilderFlow; }> = async (context) => { @@ -190,7 +210,7 @@ export const getServerSideProps: GetServerSideProps<{ const { dataSourceConfigurations, dustAppConfiguration, - databaseQueryConfiguration, + tablesQueryConfiguration, } = await buildInitialState({ config, dataSourceByName, @@ -208,7 +228,7 @@ export const getServerSideProps: GetServerSideProps<{ dataSourceConfigurations, dustApps: allDustApps, dustAppConfiguration, - databaseQueryConfiguration: databaseQueryConfiguration, + tablesQueryConfiguration, agentConfiguration: config, flow, }, @@ -225,7 +245,7 @@ export default function EditAssistant({ dataSourceConfigurations, dustApps, dustAppConfiguration, - databaseQueryConfiguration, + tablesQueryConfiguration, agentConfiguration, flow, }: InferGetServerSidePropsType) { @@ -259,8 +279,8 @@ export default function EditAssistant({ actionMode = "DUST_APP_RUN"; } - if (isDatabaseQueryConfiguration(agentConfiguration.action)) { - actionMode = "DATABASE_QUERY"; + if (isTablesQueryConfiguration(agentConfiguration.action)) { + actionMode = "TABLES_QUERY"; } if (agentConfiguration.scope === "global") { throw new Error("Cannot edit global assistant"); @@ -281,7 +301,7 @@ export default function EditAssistant({ timeFrame, dataSourceConfigurations, dustAppConfiguration, - databaseQueryConfiguration, + tablesQueryConfiguration, scope: agentConfiguration.scope, handle: agentConfiguration.name, description: agentConfiguration.description, diff --git a/front/pages/w/[wId]/builder/assistants/new.tsx b/front/pages/w/[wId]/builder/assistants/new.tsx index 90ef372e61e3..d49b85887b6f 100644 --- a/front/pages/w/[wId]/builder/assistants/new.tsx +++ b/front/pages/w/[wId]/builder/assistants/new.tsx @@ -8,9 +8,9 @@ import type { WorkspaceType, } from "@dust-tt/types"; import { - isDatabaseQueryConfiguration, isDustAppRunConfiguration, isRetrievalConfiguration, + isTablesQueryConfiguration, } from "@dust-tt/types"; import type { GetServerSideProps, InferGetServerSidePropsType } from "next"; @@ -43,7 +43,7 @@ export const getServerSideProps: GetServerSideProps<{ dataSourceConfigurations: Record | null; dustApps: AppType[]; dustAppConfiguration: AssistantBuilderInitialState["dustAppConfiguration"]; - databaseQueryConfiguration: AssistantBuilderInitialState["databaseQueryConfiguration"]; + tablesQueryConfiguration: AssistantBuilderInitialState["tablesQueryConfiguration"]; agentConfiguration: AgentConfigurationType | null; flow: BuilderFlow; }> = async (context) => { @@ -91,7 +91,7 @@ export const getServerSideProps: GetServerSideProps<{ const { dataSourceConfigurations, dustAppConfiguration, - databaseQueryConfiguration, + tablesQueryConfiguration, } = config ? await buildInitialState({ config, @@ -101,7 +101,7 @@ export const getServerSideProps: GetServerSideProps<{ : { dataSourceConfigurations: null, dustAppConfiguration: null, - databaseQueryConfiguration: null, + tablesQueryConfiguration: {}, }; return { @@ -115,7 +115,7 @@ export const getServerSideProps: GetServerSideProps<{ dataSourceConfigurations, dustApps: allDustApps, dustAppConfiguration, - databaseQueryConfiguration: databaseQueryConfiguration, + tablesQueryConfiguration, agentConfiguration: config, flow, }, @@ -132,7 +132,7 @@ export default function CreateAssistant({ dataSourceConfigurations, dustApps, dustAppConfiguration, - databaseQueryConfiguration, + tablesQueryConfiguration, agentConfiguration, flow, }: InferGetServerSidePropsType) { @@ -167,8 +167,8 @@ export default function CreateAssistant({ actionMode = "DUST_APP_RUN"; } - if (isDatabaseQueryConfiguration(agentConfiguration.action)) { - actionMode = "DATABASE_QUERY"; + if (isTablesQueryConfiguration(agentConfiguration.action)) { + actionMode = "TABLES_QUERY"; } if (agentConfiguration.scope === "global") { throw new Error("Cannot edit global assistant"); @@ -192,7 +192,7 @@ export default function CreateAssistant({ timeFrame, dataSourceConfigurations, dustAppConfiguration, - databaseQueryConfiguration, + tablesQueryConfiguration, scope: "private", handle: `${agentConfiguration.name}_Copy`, description: agentConfiguration.description, diff --git a/types/src/front/api_handlers/internal/agent_configuration.ts b/types/src/front/api_handlers/internal/agent_configuration.ts index 28f70d5768f8..7f7f99563556 100644 --- a/types/src/front/api_handlers/internal/agent_configuration.ts +++ b/types/src/front/api_handlers/internal/agent_configuration.ts @@ -91,10 +91,14 @@ export const PostOrPatchAgentConfigurationRequestBodySchema = t.type({ appId: t.string, }), t.type({ - type: t.literal("database_query_configuration"), - dataSourceWorkspaceId: t.string, - dataSourceId: t.string, - databaseId: t.string, + type: t.literal("tables_query_configuration"), + tables: t.array( + t.type({ + workspaceId: t.string, + dataSourceId: t.string, + tableId: t.string, + }) + ), }), ]), generation: t.union([ diff --git a/types/src/front/assistant/actions/database_query.ts b/types/src/front/assistant/actions/database_query.ts deleted file mode 100644 index 62ee208e16d7..000000000000 --- a/types/src/front/assistant/actions/database_query.ts +++ /dev/null @@ -1,41 +0,0 @@ -import { ModelId } from "shared/model_id"; - -import { AgentActionConfigurationType } from "../../../front/assistant/agent"; -import { AgentActionType } from "../../../front/assistant/conversation"; - -export type DatabaseQueryConfigurationType = { - id: ModelId; - sId: string; - type: "database_query_configuration"; - dataSourceWorkspaceId: string; - dataSourceId: string; - databaseId: string; -}; - -export function isDatabaseQueryConfiguration( - arg: AgentActionConfigurationType | null -): arg is DatabaseQueryConfigurationType { - return ( - arg !== null && arg.type && arg.type === "database_query_configuration" - ); -} - -export function isDatabaseQueryActionType( - arg: AgentActionType -): arg is DatabaseQueryActionType { - return arg.type === "database_query_action"; -} - -export type DatabaseQueryActionType = { - id: ModelId; - type: "database_query_action"; - dataSourceWorkspaceId: string; - dataSourceId: string; - databaseId: string; - params: { - [key: string]: string | number | boolean; - }; - output: { - [key: string]: string | number | boolean; - } | null; -}; diff --git a/types/src/front/assistant/actions/dust_app_run.ts b/types/src/front/assistant/actions/dust_app_run.ts index 767367693ed1..fd18c464b05e 100644 --- a/types/src/front/assistant/actions/dust_app_run.ts +++ b/types/src/front/assistant/actions/dust_app_run.ts @@ -1,5 +1,3 @@ -import { AgentActionConfigurationType } from "../../../front/assistant/agent"; -import { AgentActionType } from "../../../front/assistant/conversation"; import { ModelId } from "../../../shared/model_id"; export type DustAppRunConfigurationType = { @@ -12,18 +10,6 @@ export type DustAppRunConfigurationType = { appId: string; }; -export function isDustAppRunConfiguration( - arg: AgentActionConfigurationType | null -): arg is DustAppRunConfigurationType { - return arg !== null && arg.type && arg.type === "dust_app_run_configuration"; -} - -export function isDustAppRunActionType( - arg: AgentActionType -): arg is DustAppRunActionType { - return arg.type === "dust_app_run_action"; -} - export type DustAppParameters = { [key: string]: string | number | boolean; }; diff --git a/types/src/front/assistant/actions/guards.ts b/types/src/front/assistant/actions/guards.ts new file mode 100644 index 000000000000..acf727730539 --- /dev/null +++ b/types/src/front/assistant/actions/guards.ts @@ -0,0 +1,50 @@ +import { + DustAppRunActionType, + DustAppRunConfigurationType, +} from "../../../front/assistant/actions/dust_app_run"; +import { AgentActionConfigurationType } from "../../../front/assistant/agent"; +import { AgentActionType } from "../../../front/assistant/conversation"; +import { + RetrievalActionType, + RetrievalConfigurationType, +} from "../../assistant/actions/retrieval"; +import { + TablesQueryActionType, + TablesQueryConfigurationType, +} from "../../assistant/actions/tables_query"; + +export function isTablesQueryConfiguration( + arg: AgentActionConfigurationType | null +): arg is TablesQueryConfigurationType { + return arg?.type === "tables_query_configuration"; +} + +export function isTablesQueryActionType( + arg: AgentActionType +): arg is TablesQueryActionType { + return arg.type === "tables_query_action"; +} + +export function isDustAppRunConfiguration( + arg: AgentActionConfigurationType | null +): arg is DustAppRunConfigurationType { + return arg !== null && arg.type && arg.type === "dust_app_run_configuration"; +} + +export function isDustAppRunActionType( + arg: AgentActionType +): arg is DustAppRunActionType { + return arg.type === "dust_app_run_action"; +} + +export function isRetrievalConfiguration( + arg: AgentActionConfigurationType | null +): arg is RetrievalConfigurationType { + return arg !== null && arg.type && arg.type === "retrieval_configuration"; +} + +export function isRetrievalActionType( + arg: AgentActionType +): arg is RetrievalActionType { + return arg.type === "retrieval_action"; +} diff --git a/types/src/front/assistant/actions/retrieval.ts b/types/src/front/assistant/actions/retrieval.ts index ec0e9d02db1e..d6cde044a043 100644 --- a/types/src/front/assistant/actions/retrieval.ts +++ b/types/src/front/assistant/actions/retrieval.ts @@ -2,8 +2,6 @@ * Data Source configuration */ -import { AgentActionConfigurationType } from "../../../front/assistant/agent"; -import { AgentActionType } from "../../../front/assistant/conversation"; import { ModelId } from "../../../shared/model_id"; import { ioTsEnum } from "../../../shared/utils/iots_utils"; @@ -74,12 +72,6 @@ export type RetrievalConfigurationType = { // autoSkip: boolean; }; -export function isRetrievalConfiguration( - arg: AgentActionConfigurationType | null -): arg is RetrievalConfigurationType { - return arg !== null && arg.type && arg.type === "retrieval_configuration"; -} - /** * Retrieval action */ @@ -101,12 +93,6 @@ export type RetrievalDocumentType = { }[]; }; -export function isRetrievalActionType( - arg: AgentActionType -): arg is RetrievalActionType { - return arg.type === "retrieval_action"; -} - export type RetrievalActionType = { id: ModelId; // AgentRetrieval. type: "retrieval_action"; diff --git a/types/src/front/assistant/actions/tables_query.ts b/types/src/front/assistant/actions/tables_query.ts new file mode 100644 index 000000000000..4ca06caf81da --- /dev/null +++ b/types/src/front/assistant/actions/tables_query.ts @@ -0,0 +1,22 @@ +import { DustAppParameters } from "front/assistant/actions/dust_app_run"; + +import { ModelId } from "../../../shared/model_id"; + +export type TablesQueryConfigurationType = { + id: ModelId; + sId: string; + type: "tables_query_configuration"; + tables: Array<{ + workspaceId: string; + dataSourceId: string; + tableId: string; + }>; +}; + +export type TablesQueryActionType = { + id: ModelId; + type: "tables_query_action"; + + params: DustAppParameters; + output: Record | null; +}; diff --git a/types/src/front/assistant/agent.ts b/types/src/front/assistant/agent.ts index 13c10a9690d9..d30b054e09c8 100644 --- a/types/src/front/assistant/agent.ts +++ b/types/src/front/assistant/agent.ts @@ -1,6 +1,6 @@ -import { DatabaseQueryConfigurationType } from "../../front/assistant/actions/database_query"; import { DustAppRunConfigurationType } from "../../front/assistant/actions/dust_app_run"; import { RetrievalConfigurationType } from "../../front/assistant/actions/retrieval"; +import { TablesQueryConfigurationType } from "../../front/assistant/actions/tables_query"; import { SupportedModel } from "../../front/lib/assistant"; import { ModelId } from "../../shared/model_id"; @@ -12,9 +12,9 @@ import { ModelId } from "../../shared/model_id"; // - Add the type to the union type below // - Add model rendering support in `renderConversationForModel` export type AgentActionConfigurationType = + | TablesQueryConfigurationType | RetrievalConfigurationType - | DustAppRunConfigurationType - | DatabaseQueryConfigurationType; + | DustAppRunConfigurationType; // Each AgentActionConfigurationType is capable of generating this type at runtime to specify which // inputs should be generated by the model. As an example, to run the retrieval action for which the diff --git a/types/src/front/assistant/conversation.ts b/types/src/front/assistant/conversation.ts index f45772b92305..00a03dcec36c 100644 --- a/types/src/front/assistant/conversation.ts +++ b/types/src/front/assistant/conversation.ts @@ -1,6 +1,6 @@ -import { DatabaseQueryActionType } from "../../front/assistant/actions/database_query"; import { DustAppRunActionType } from "../../front/assistant/actions/dust_app_run"; import { RetrievalActionType } from "../../front/assistant/actions/retrieval"; +import { TablesQueryActionType } from "../../front/assistant/actions/tables_query"; import { LightAgentConfigurationType } from "../../front/assistant/agent"; import { UserType, WorkspaceType } from "../../front/user"; import { ModelId } from "../../shared/model_id"; @@ -86,7 +86,8 @@ export function isUserMessageType( export type AgentActionType = | RetrievalActionType | DustAppRunActionType - | DatabaseQueryActionType; + | TablesQueryActionType; + export type AgentMessageStatus = | "created" | "succeeded" diff --git a/types/src/front/lib/actions/registry.ts b/types/src/front/lib/actions/registry.ts index 83d0508d34ee..ab225aa4a27d 100644 --- a/types/src/front/lib/actions/registry.ts +++ b/types/src/front/lib/actions/registry.ts @@ -133,12 +133,12 @@ export const DustProdActionRegistry = createActionRegistry({ }, }, }, - "assistant-v2-query-database": { + "assistant-v2-query-tables": { app: { workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID, - appId: "e0c5993d65", + appId: "b4f205e453", appHash: - "b83bf8e8fab6901d805e921db31ecf8f23af608972fa8c6372c6d599e063d90c", + "6e94a4b27c7243e1f8cd49617b8b76679a8462001ba2752b575a0f8bc1390d17", }, config: { MODEL: { diff --git a/types/src/front/lib/api/assistant/actions/database_query.ts b/types/src/front/lib/api/assistant/actions/database_query.ts deleted file mode 100644 index 3586f7f7a891..000000000000 --- a/types/src/front/lib/api/assistant/actions/database_query.ts +++ /dev/null @@ -1,36 +0,0 @@ -import { DatabaseQueryActionType } from "front/assistant/actions/database_query"; - -export type DatabaseQueryErrorEvent = { - type: "database_query_error"; - created: number; - configurationId: string; - messageId: string; - error: { - code: string; - message: string; - }; -}; - -export type DatabaseQuerySuccessEvent = { - type: "database_query_success"; - created: number; - configurationId: string; - messageId: string; - action: DatabaseQueryActionType; -}; - -export type DatabaseQueryParamsEvent = { - type: "database_query_params"; - created: number; - configurationId: string; - messageId: string; - action: DatabaseQueryActionType; -}; - -export type DatabaseQueryOutputEvent = { - type: "database_query_output"; - created: number; - configurationId: string; - messageId: string; - action: DatabaseQueryActionType; -}; diff --git a/types/src/front/lib/api/assistant/actions/tables_query.ts b/types/src/front/lib/api/assistant/actions/tables_query.ts new file mode 100644 index 000000000000..4a201aa9afda --- /dev/null +++ b/types/src/front/lib/api/assistant/actions/tables_query.ts @@ -0,0 +1,36 @@ +import { TablesQueryActionType } from "../../../../../front/assistant/actions/tables_query"; + +export type TablesQueryErrorEvent = { + type: "tables_query_error"; + created: number; + configurationId: string; + messageId: string; + error: { + code: string; + message: string; + }; +}; + +export type TablesQuerySuccessEvent = { + type: "tables_query_success"; + created: number; + configurationId: string; + messageId: string; + action: TablesQueryActionType; +}; + +export type TablesQueryParamsEvent = { + type: "tables_query_params"; + created: number; + configurationId: string; + messageId: string; + action: TablesQueryActionType; +}; + +export type TablesQueryOutputEvent = { + type: "tables_query_output"; + created: number; + configurationId: string; + messageId: string; + action: TablesQueryActionType; +}; diff --git a/types/src/front/lib/api/assistant/agent.ts b/types/src/front/lib/api/assistant/agent.ts index 93528860765b..6c4a05a361af 100644 --- a/types/src/front/lib/api/assistant/agent.ts +++ b/types/src/front/lib/api/assistant/agent.ts @@ -6,15 +6,15 @@ import { AgentActionType, AgentMessageType, } from "../../../../front/assistant/conversation"; -import { - DatabaseQueryOutputEvent, - DatabaseQueryParamsEvent, -} from "../../../../front/lib/api/assistant/actions/database_query"; import { DustAppRunBlockEvent, DustAppRunParamsEvent, } from "../../../../front/lib/api/assistant/actions/dust_app_run"; import { RetrievalParamsEvent } from "../../../../front/lib/api/assistant/actions/retrieval"; +import { + TablesQueryOutputEvent, + TablesQueryParamsEvent, +} from "../../../../front/lib/api/assistant/actions/tables_query"; // Event sent when an agent error occured before we have a agent message in the database. export type AgentMessageErrorEvent = { @@ -44,8 +44,8 @@ export type AgentActionEvent = | RetrievalParamsEvent | DustAppRunParamsEvent | DustAppRunBlockEvent - | DatabaseQueryParamsEvent - | DatabaseQueryOutputEvent; + | TablesQueryParamsEvent + | TablesQueryOutputEvent; // Event sent once the action is completed, we're moving to generating a message if applicable. export type AgentActionSuccessEvent = { diff --git a/types/src/index.ts b/types/src/index.ts index 5a6cc2a179c0..bfd4f529bedc 100644 --- a/types/src/index.ts +++ b/types/src/index.ts @@ -5,9 +5,10 @@ export * from "./front/api_handlers/internal/assistant"; export * from "./front/api_handlers/public/assistant"; export * from "./front/api_handlers/public/data_sources"; export * from "./front/app"; -export * from "./front/assistant/actions/database_query"; export * from "./front/assistant/actions/dust_app_run"; +export * from "./front/assistant/actions/guards"; export * from "./front/assistant/actions/retrieval"; +export * from "./front/assistant/actions/tables_query"; export * from "./front/assistant/agent"; export * from "./front/assistant/conversation"; export * from "./front/data_source"; @@ -17,9 +18,9 @@ export * from "./front/extract"; export * from "./front/key"; export * from "./front/lib/actions/registry"; export * from "./front/lib/actions/types"; -export * from "./front/lib/api/assistant/actions/database_query"; export * from "./front/lib/api/assistant/actions/dust_app_run"; export * from "./front/lib/api/assistant/actions/retrieval"; +export * from "./front/lib/api/assistant/actions/tables_query"; export * from "./front/lib/api/assistant/agent"; export * from "./front/lib/api/assistant/conversation"; export * from "./front/lib/api/assistant/generation";