diff --git a/front/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_retrieval.ts b/front/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_retrieval.ts index 3b25db694ab25..526e43c30c256 100644 --- a/front/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_retrieval.ts +++ b/front/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_retrieval.ts @@ -69,6 +69,7 @@ const DocTrackerRetrievalActionValueSchema = t.array( created: t.Integer, document_id: t.string, timestamp: t.Integer, + title: t.union([t.string, t.null]), tags: t.array(t.string), parents: t.array(t.string), source_url: t.union([t.string, t.null]), @@ -77,12 +78,17 @@ const DocTrackerRetrievalActionValueSchema = t.array( text: t.union([t.string, t.null, t.undefined]), chunk_count: t.Integer, chunks: t.array( - t.type({ - text: t.string, - hash: t.string, - offset: t.Integer, - score: t.number, - }) + t.intersection([ + t.type({ + text: t.string, + hash: t.string, + offset: t.Integer, + score: t.number, + }), + t.partial({ + expanded_offsets: t.array(t.Integer), + }), + ]) ), token_count: t.Integer, }) diff --git a/front/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_score_docs.ts b/front/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_score_docs.ts new file mode 100644 index 0000000000000..23b1b47af0303 --- /dev/null +++ b/front/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_score_docs.ts @@ -0,0 +1,62 @@ +import * as t from "io-ts"; + +import { callAction } from "@app/lib/actions/helpers"; +import type { Authenticator } from "@app/lib/auth"; +import { cloneBaseConfig, DustProdActionRegistry } from "@app/lib/registry"; + +export async function callDocTrackerScoreDocsAction( + auth: Authenticator, + { + watchedDocDiff, + maintainedDocuments, + providerId, + modelId, + }: { + watchedDocDiff: string; + maintainedDocuments: Array<{ + content: string; + title: string | null; + sourceUrl: string | null; + dataSourceId: string; + documentId: string; + }>; + providerId: string; + modelId: string; + } +): Promise { + const action = DustProdActionRegistry["doc-tracker-score-docs"]; + + const config = cloneBaseConfig(action.config); + config.SUGGEST_CHANGES.provider_id = providerId; + config.SUGGEST_CHANGES.model_id = modelId; + + const res = await callAction(auth, { + action, + config, + input: { + watched_diff: watchedDocDiff, + maintained_documents: maintainedDocuments, + }, + responseValueSchema: DocTrackerScoreDocsActionResultSchema, + }); + + if (res.isErr()) { + throw res.error; + } + + return res.value; +} + +const DocTrackerScoreDocsActionResultSchema = t.array( + t.type({ + documentId: t.string, + dataSourceId: t.string, + score: t.number, + title: t.union([t.string, t.null, t.undefined]), + sourceUrl: t.union([t.string, t.null, t.undefined]), + }) +); + +type DocTrackerScoreDocsActionResult = t.TypeOf< + typeof DocTrackerScoreDocsActionResultSchema +>; diff --git a/front/lib/registry.ts b/front/lib/registry.ts index b51ae750152c3..ed569a66bb42c 100644 --- a/front/lib/registry.ts +++ b/front/lib/registry.ts @@ -136,6 +136,20 @@ export const DustProdActionRegistry = createActionRegistry({ }, }, }, + "doc-tracker-score-docs": { + app: { + workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID, + appId: "N0RrhyTXfq", + appHash: + "af4ab848b1e4f13afffdf2a9672bb2613e278ed5ee55a3a1c67a013e7036daee", + appSpaceId: PRODUCTION_DUST_APPS_SPACE_ID, + }, + config: { + MODEL: { + use_cache: false, + }, + }, + }, "doc-tracker-suggest-changes": { app: { workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID, diff --git a/front/temporal/tracker/activities.ts b/front/temporal/tracker/activities.ts index 16850c9683c5d..e5a8e79a6d80d 100644 --- a/front/temporal/tracker/activities.ts +++ b/front/temporal/tracker/activities.ts @@ -13,6 +13,7 @@ import { processTrackerNotification } from "@app/lib/api/tracker"; import { Authenticator } from "@app/lib/auth"; import { getDocumentDiff } from "@app/lib/document_upsert_hooks/hooks/data_source_helpers"; import { callDocTrackerRetrievalAction } from "@app/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_retrieval"; +import { callDocTrackerScoreDocsAction } from "@app/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_score_docs"; import { callDocTrackerSuggestChangesAction } from "@app/lib/document_upsert_hooks/hooks/tracker/actions/doc_tracker_suggest_changes"; import { Workspace } from "@app/lib/models/workspace"; import { DataSourceResource } from "@app/lib/resources/data_source_resource"; @@ -28,10 +29,12 @@ const TRACKER_WATCHED_DOCUMENT_MINIMUM_DIFF_LINE_LENGTH = 4; const TRACKER_WATCHED_DOCUMENT_MAX_DIFF_TOKENS = 4096; // The total number of tokens to show to the model (watched doc diff + maintained scope retrieved tokens) const TRACKER_TOTAL_TARGET_TOKENS = 8192; -// The topK used for the semantic search against the maintained scope. -// TODO(DOC_TRACKER): Decide how we handle this. If the top doc has less than $targetDocumentTokens, -// we could include content from the next doc in the maintained scope. -const TRACKER_MAINTAINED_DOCUMENT_TOP_K = 1; +// The maximum number of chunks to retrieve from the maintained scope. +const TRACKER_MAINTAINED_SCOPE_MAX_TOP_K = 8; + +// The size of the chunks in our data sources. +// TODO(@fontanierh): find a way to ensure this remains true. +const CHUNK_SIZE = 512; export async function getDebounceMsActivity( dataSourceConnectorProvider: ConnectorProvider | null @@ -185,6 +188,20 @@ export async function trackersGenerationActivity( const targetMaintainedScopeTokens = TRACKER_TOTAL_TARGET_TOKENS - tokensInDiffCount; + // We don't want to retrieve more than targetMaintainedScopeTokens / CHUNK_SIZE chunks, + // in case all retrieved chunks are from the same document (in which case, we'd have + // more than targetMaintainedScopeTokens tokens for that document). + const maintainedScopeTopK = Math.min( + TRACKER_MAINTAINED_SCOPE_MAX_TOP_K, + Math.floor(targetMaintainedScopeTokens / CHUNK_SIZE) + ); + + if (maintainedScopeTopK === 0) { + throw new Error( + "Unreachable: targetMaintainedScopeTokens is less than CHUNK_SIZE." + ); + } + for (const tracker of trackers) { const trackerLogger = localLogger.child({ trackerId: tracker.sId, @@ -216,60 +233,149 @@ export async function trackersGenerationActivity( const maintainedScopeRetrieval = await callDocTrackerRetrievalAction(auth, { inputText: diffString, targetDocumentTokens: targetMaintainedScopeTokens, - topK: TRACKER_MAINTAINED_DOCUMENT_TOP_K, + topK: maintainedScopeTopK, maintainedScope, parentsInMap, }); - // TODO(DOC_TRACKER): Right now we only handle the top match. - // We may want to support topK > 1 and process more than 1 doc if the top doc has less than - // $targetDocumentTokens. if (maintainedScopeRetrieval.length === 0) { trackerLogger.info("No content retrieved from maintained scope."); continue; } - const content = maintainedScopeRetrieval[0].chunks - .map((c) => c.text) - .join("\n"); - if (!content) { - trackerLogger.info("No content retrieved from maintained scope."); - continue; - } - - const suggestChangesResult = await callDocTrackerSuggestChangesAction( - auth, - { - watchedDocDiff: diffString, - maintainedDocContent: content, - prompt: tracker.prompt, - providerId: tracker.providerId, - modelId: tracker.modelId, + const maintainedDocuments: { + content: string; + sourceUrl: string | null; + title: string | null; + dataSourceId: string; + documentId: string; + }[] = []; + + for (const retrievalDoc of maintainedScopeRetrieval) { + let docContent: string = ""; + const sortedChunks = _.sortBy(retrievalDoc.chunks, (c) => c.offset); + + for (const [i, chunk] of sortedChunks.entries()) { + if (i === 0) { + // If we are at index 0 (i.e the first retrieved chunk), we check whether our chunk includes + // the beginning of the document. If it doesn't, we add a "[...]"" separator. + const allOffsetsInChunk = [ + chunk.offset, + ...(chunk.expanded_offsets ?? []), + ]; + const isBeginningOfDocument = allOffsetsInChunk.includes(0); + if (!isBeginningOfDocument) { + docContent += "[...]\n"; + } + } else { + // If we are not at index 0, we check whether the current chunk is a direct continuation of the previous chunk. + // We do this by checking that the first offset of the current chunk is the last offset of the previous chunk + 1. + const previousChunk = sortedChunks[i - 1]; + const allOffsetsInCurrentChunk = [ + chunk.offset, + ...(chunk.expanded_offsets ?? []), + ]; + const firstOffsetInCurrentChunk = _.min(allOffsetsInCurrentChunk)!; + const allOffsetsInPreviousChunk = [ + previousChunk.offset, + ...(previousChunk.expanded_offsets ?? []), + ]; + const lastOffsetInPreviousChunk = _.max(allOffsetsInPreviousChunk)!; + const hasGap = + firstOffsetInCurrentChunk !== lastOffsetInPreviousChunk + 1; + + if (hasGap) { + docContent += "[...]\n"; + } + } + + // Add the chunk text to the document. + docContent += chunk.text + "\n"; + + if (i === sortedChunks.length - 1) { + // If we are at the last chunk, we check if we have the last offset of the doc. + // If not, we add a "[...]" separator. + const lastChunk = sortedChunks[sortedChunks.length - 1]; + if (lastChunk.offset !== retrievalDoc.chunk_count - 1) { + docContent += "[...]\n"; + } + } } - ); - if (!suggestChangesResult.suggestion) { - trackerLogger.info("No changes suggested."); - continue; + maintainedDocuments.push({ + content: docContent, + sourceUrl: retrievalDoc.source_url, + title: retrievalDoc.title, + dataSourceId: retrievalDoc.data_source_id, + documentId: retrievalDoc.document_id, + }); } - const suggestedChanges = suggestChangesResult.suggestion; - const thinking = suggestChangesResult.thinking; - const confidenceScore = suggestChangesResult.confidence_score; - - trackerLogger.info( - { - confidenceScore, - }, - "Changes suggested." + const contentByDocumentIdentifier = _.mapValues( + _.keyBy( + maintainedDocuments, + (doc) => `${doc.dataSourceId}__${doc.documentId}` + ), + (doc) => doc.content ); - await tracker.addGeneration({ - generation: suggestedChanges, - thinking: thinking ?? null, - dataSourceId, - documentId, + const scoreDocsResult = await callDocTrackerScoreDocsAction(auth, { + watchedDocDiff: diffString, + maintainedDocuments, + providerId: tracker.providerId, + modelId: tracker.modelId, }); + + for (const { documentId, dataSourceId, score } of scoreDocsResult) { + logger.info( + { + documentId, + dataSourceId, + score, + }, + "Running document tracker suggest changes." + ); + + const content = + contentByDocumentIdentifier[`${dataSourceId}__${documentId}`]; + if (!content) { + continue; + } + + const suggestChangesResult = await callDocTrackerSuggestChangesAction( + auth, + { + watchedDocDiff: diffString, + maintainedDocContent: content, + prompt: tracker.prompt, + providerId: tracker.providerId, + modelId: tracker.modelId, + } + ); + + if (!suggestChangesResult.suggestion) { + trackerLogger.info("No changes suggested."); + continue; + } + + const suggestedChanges = suggestChangesResult.suggestion; + const thinking = suggestChangesResult.thinking; + const confidenceScore = suggestChangesResult.confidence_score; + + trackerLogger.info( + { + confidenceScore, + }, + "Changes suggested." + ); + + await tracker.addGeneration({ + generation: suggestedChanges, + thinking: thinking ?? null, + dataSourceId, + documentId, + }); + } } }