diff --git a/package.json b/package.json index 4ed04d2ae..c385b4e1d 100644 --- a/package.json +++ b/package.json @@ -87,9 +87,11 @@ "id": "mongodb.participant", "name": "MongoDB", "description": "Ask anything about MongoDB, from writing queries to questions about your cluster.", + "isSticky": true, "commands": [ { "name": "query", + "isSticky": true, "description": "Ask how to write MongoDB queries or pipelines. For example, you can ask: \"Show me all the documents where the address contains the word street\"." } ] @@ -161,13 +163,25 @@ } ], "commands": [ + { + "command": "mdb.selectDatabaseWithParticipant", + "title": "MongoDB: Select Database with Participant" + }, + { + "command": "mdb.selectCollectionWithParticipant", + "title": "MongoDB: Select Collection with Participant" + }, + { + "command": "mdb.connectWithParticipant", + "title": "MongoDB: Change Active Connection with Participant" + }, { "command": "mdb.runParticipantQuery", - "title": "Run Content Generated by the Chat Participant" + "title": "Run Content Generated by Participant" }, { "command": "mdb.openParticipantQueryInPlayground", - "title": "Open Generated by the Chat Participant Content In Playground" + "title": "Open Generated by Participant Content In Playground" }, { "command": "mdb.connect", @@ -719,6 +733,22 @@ } ], "commandPalette": [ + { + "command": "mdb.selectDatabaseWithParticipant", + "when": "false" + }, + { + "command": "mdb.selectCollectionWithParticipant", + "when": "false" + }, + { + "command": "mdb.connectWithParticipant", + "when": "false" + }, + { + "command": "mdb.runParticipantQuery", + "when": "false" + }, { "command": "mdb.openParticipantQueryInPlayground", "when": "false" diff --git a/src/commands/index.ts b/src/commands/index.ts index ba4c32cde..2a11bf67e 100644 --- a/src/commands/index.ts +++ b/src/commands/index.ts @@ -79,6 +79,9 @@ enum EXTENSION_COMMANDS { // Chat participant. OPEN_PARTICIPANT_QUERY_IN_PLAYGROUND = 'mdb.openParticipantQueryInPlayground', RUN_PARTICIPANT_QUERY = 'mdb.runParticipantQuery', + CONNECT_WITH_PARTICIPANT = 'mdb.connectWithParticipant', + SELECT_DATABASE_WITH_PARTICIPANT = 'mdb.selectDatabaseWithParticipant', + SELECT_COLLECTION_WITH_PARTICIPANT = 'mdb.selectCollectionWithParticipant', } export default EXTENSION_COMMANDS; diff --git a/src/connectionController.ts b/src/connectionController.ts index 4b4515b8a..b25da9201 100644 --- a/src/connectionController.ts +++ b/src/connectionController.ts @@ -419,7 +419,14 @@ export default class ConnectionController { } log.info('Successfully connected', { connectionId }); - void vscode.window.showInformationMessage('MongoDB connection successful.'); + + const message = 'MongoDB connection successful.'; + this._statusView.showMessage(message); + setTimeout(() => { + if (this._statusView._statusBarItem.text === message) { + this._statusView.hideMessage(); + } + }, 5000); dataService.addReauthenticationHandler( this._reauthenticationHandler.bind(this) @@ -603,10 +610,16 @@ export default class ConnectionController { 'mdb.isAtlasStreams', false ); - void vscode.window.showInformationMessage('MongoDB disconnected.'); this._disconnecting = false; - this._statusView.hideMessage(); + + const message = 'MongoDB disconnected.'; + this._statusView.showMessage(message); + setTimeout(() => { + if (this._statusView._statusBarItem.text === message) { + this._statusView.hideMessage(); + } + }, 5000); return true; } diff --git a/src/editors/mongoDBDocumentService.ts b/src/editors/mongoDBDocumentService.ts index 928d02de0..1806cddda 100644 --- a/src/editors/mongoDBDocumentService.ts +++ b/src/editors/mongoDBDocumentService.ts @@ -98,13 +98,11 @@ export default class MongoDBDocumentService { returnDocument: 'after', } ); - - this._statusView.hideMessage(); this._telemetryService.trackDocumentUpdated(source, true); } catch (error) { - this._statusView.hideMessage(); - return this._saveDocumentFailed(formatError(error).message); + } finally { + this._statusView.hideMessage(); } } @@ -141,17 +139,15 @@ export default class MongoDBDocumentService { { limit: 1 } ); - this._statusView.hideMessage(); - if (!documents || documents.length === 0) { return; } return getEJSON(documents[0]); } catch (error) { - this._statusView.hideMessage(); - return this._fetchDocumentFailed(formatError(error).message); + } finally { + this._statusView.hideMessage(); } } } diff --git a/src/mdbExtensionController.ts b/src/mdbExtensionController.ts index cc1380f5b..008d46002 100644 --- a/src/mdbExtensionController.ts +++ b/src/mdbExtensionController.ts @@ -292,7 +292,8 @@ export default class MDBExtensionController implements vscode.Disposable { () => { return this._playgroundController.createPlaygroundFromParticipantQuery({ text: - this._participantController._chatResult.metadata.queryContent || '', + this._participantController._chatResult?.metadata + ?.responseContent || '', }); } ); @@ -301,10 +302,25 @@ export default class MDBExtensionController implements vscode.Disposable { () => { return this._playgroundController.evaluateParticipantQuery({ text: - this._participantController._chatResult.metadata.queryContent || '', + this._participantController._chatResult?.metadata + ?.responseContent || '', }); } ); + this.registerCommand( + EXTENSION_COMMANDS.CONNECT_WITH_PARTICIPANT, + (id: string) => this._participantController.connectWithParticipant(id) + ); + this.registerCommand( + EXTENSION_COMMANDS.SELECT_DATABASE_WITH_PARTICIPANT, + (name: string) => + this._participantController.selectDatabaseWithParticipant(name) + ); + this.registerCommand( + EXTENSION_COMMANDS.SELECT_COLLECTION_WITH_PARTICIPANT, + (name: string) => + this._participantController.selectCollectionWithParticipant(name) + ); }; registerParticipantCommand = ( diff --git a/src/participant/participant.ts b/src/participant/participant.ts index ded24c582..d89ade7a5 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -1,48 +1,71 @@ import * as vscode from 'vscode'; +import type { DataService } from 'mongodb-data-service'; import { createLogger } from '../logging'; import type ConnectionController from '../connectionController'; +import type { LoadedConnection } from '../storage/connectionStorage'; import EXTENSION_COMMANDS from '../commands'; import type { StorageController } from '../storage'; import { StorageVariables } from '../storage'; import { GenericPrompt } from './prompts/generic'; import { CHAT_PARTICIPANT_ID } from './constants'; import { QueryPrompt } from './prompts/query'; +import { NamespacePrompt } from './prompts/namespace'; const log = createLogger('participant'); +enum QUERY_GENERATION_STATE { + DEFAULT = 'DEFAULT', + ASK_TO_CONNECT = 'ASK_TO_CONNECT', + ASK_FOR_DATABASE_NAME = 'ASK_FOR_DATABASE_NAME', + ASK_FOR_COLLECTION_NAME = 'ASK_FOR_COLLECTION_NAME', + READY_TO_GENERATE_QUERY = 'READY_TO_GENERATE_QUERY', +} + interface ChatResult extends vscode.ChatResult { metadata: { - command: string; - databaseName?: string; - collectionName?: string; - queryContent?: string; - description?: string; + responseContent?: string; }; - stream?: vscode.ChatResponseStream; } export const CHAT_PARTICIPANT_MODEL = 'gpt-4o'; -export function getRunnableContentFromString(responseContent: string) { - const matchedJSQueryContent = responseContent.match( - /```javascript((.|\n)*)```/ - ); - log.info('matchedJSQueryContent', matchedJSQueryContent); +const DB_NAME_ID = 'DATABASE_NAME'; +const DB_NAME_REGEX = `${DB_NAME_ID}: (.*)\n`; + +const COL_NAME_ID = 'COLLECTION_NAME'; +const COL_NAME_REGEX = `${COL_NAME_ID}: (.*)`; + +function parseForDatabaseAndCollectionName(text: string): { + databaseName?: string; + collectionName?: string; +} { + const databaseName = text.match(DB_NAME_REGEX)?.[1]; + const collectionName = text.match(COL_NAME_REGEX)?.[1]; + + return { databaseName, collectionName }; +} + +export function getRunnableContentFromString(text: string) { + const matchedJSresponseContent = text.match(/```javascript((.|\n)*)```/); + log.info('matchedJSresponseContent', matchedJSresponseContent); - const queryContent = - matchedJSQueryContent && matchedJSQueryContent.length > 1 - ? matchedJSQueryContent[1] + const responseContent = + matchedJSresponseContent && matchedJSresponseContent.length > 1 + ? matchedJSresponseContent[1] : ''; - log.info('queryContent', queryContent); - return queryContent; + log.info('responseContent', responseContent); + return responseContent; } export class ParticipantController { _participant?: vscode.ChatParticipant; - _chatResult: ChatResult; _connectionController: ConnectionController; _storageController: StorageController; + _queryGenerationState?: QUERY_GENERATION_STATE; + _chatResult?: ChatResult; + _databaseName?: string; + _collectionName?: string; constructor({ connectionController, @@ -51,7 +74,6 @@ export class ParticipantController { connectionController: ConnectionController; storageController: StorageController; }) { - this._chatResult = { metadata: { command: '' } }; this._connectionController = connectionController; this._storageController = storageController; } @@ -79,16 +101,27 @@ export class ParticipantController { return this._participant || this.createParticipant(context); } - handleEmptyQueryRequest(): ChatResult { - return { - metadata: { - command: '', - }, - errorDetails: { - message: - 'Please specify a question when using this command. Usage: @MongoDB /query find documents where "name" contains "database".', - }, - }; + handleEmptyQueryRequest(stream: vscode.ChatResponseStream): undefined { + let message; + switch (this._queryGenerationState) { + case QUERY_GENERATION_STATE.ASK_TO_CONNECT: + message = + 'Please select a cluster to connect by clicking on an item in the connections list.'; + break; + case QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME: + message = + 'Please select a database by either clicking on an item in the list or typing the name manually in the chat.'; + break; + case QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME: + message = + 'Please select a collection by either clicking on an item in the list or typing the name manually in the chat.'; + break; + default: + message = + 'Please specify a question when using this command. Usage: @MongoDB /query find documents where "name" contains "database".'; + } + stream.markdown(vscode.l10n.t(`${message}\n\n`)); + return; } handleError(err: any, stream: vscode.ChatResponseStream): void { @@ -133,9 +166,7 @@ export class ParticipantController { const chatResponse = await model.sendRequest(messages, {}, token); for await (const fragment of chatResponse.text) { responseContent += fragment; - stream.markdown(fragment); } - stream.markdown('\n\n'); } } catch (err) { this.handleError(err, stream); @@ -145,17 +176,12 @@ export class ParticipantController { } // @MongoDB what is mongodb? - async handleGenericRequest({ - request, - context, - stream, - token, - }: { - request: vscode.ChatRequest; - context: vscode.ChatContext; - stream: vscode.ChatResponseStream; - token: vscode.CancellationToken; - }) { + async handleGenericRequest( + request: vscode.ChatRequest, + context: vscode.ChatContext, + stream: vscode.ChatResponseStream, + token: vscode.CancellationToken + ): Promise { const messages = GenericPrompt.buildMessages({ request, context, @@ -171,10 +197,10 @@ export class ParticipantController { stream, token, }); + stream.markdown(responseContent); - const queryContent = getRunnableContentFromString(responseContent); - - if (queryContent && queryContent.trim().length) { + const runnableContent = getRunnableContentFromString(responseContent); + if (runnableContent && runnableContent.trim().length) { stream.button({ command: EXTENSION_COMMANDS.RUN_PARTICIPANT_QUERY, title: vscode.l10n.t('▶️ Run'), @@ -185,58 +211,305 @@ export class ParticipantController { title: vscode.l10n.t('Open in playground'), }); - return { - metadata: { - command: '', - stream, - queryContent, - }, - }; + return { metadata: { responseContent: runnableContent } }; } - return { metadata: { command: '' } }; + return { metadata: {} }; } - // @MongoDB /query find all documents where the "address" has the word Broadway in it. - async handleQueryRequest({ - request, - context, - stream, - token, + async connectWithParticipant(id: string): Promise { + if (!id) { + await this._connectionController.connectWithURI(); + } else { + await this._connectionController.connectWithConnectionId(id); + } + + const connectionName = this._connectionController.getActiveConnectionName(); + if (connectionName) { + this._queryGenerationState = QUERY_GENERATION_STATE.DEFAULT; + } + + return vscode.commands.executeCommand('workbench.action.chat.open', { + query: `@MongoDB /query ${connectionName}`, + }); + } + + _createMarkdownLink({ + commandId, + query, + name, }: { - request: vscode.ChatRequest; - context: vscode.ChatContext; - stream: vscode.ChatResponseStream; - token: vscode.CancellationToken; - }) { - if (!request.prompt || request.prompt.trim().length === 0) { - return this.handleEmptyQueryRequest(); + commandId: string; + query?: string; + name: string; + }): vscode.MarkdownString { + const commandQuery = query ? `?%5B%22${query}%22%5D` : ''; + const connName = new vscode.MarkdownString( + `- ${name}\n` + ); + connName.supportHtml = true; + connName.isTrusted = { enabledCommands: [commandId] }; + return connName; + } + + // TODO (VSCODE-589): Evaluate the usability of displaying all existing connections in the list. + // Consider introducing a "recent connections" feature to display only a limited number of recent connections, + // with a "Show more" link that opens the Command Palette for access to the full list. + // If we implement this, the "Add new connection" link may become redundant, + // as this option is already available in the Command Palette dropdown. + getConnectionsTree(): vscode.MarkdownString[] { + return [ + this._createMarkdownLink({ + commandId: 'mdb.connectWithParticipant', + name: 'Add new connection', + }), + ...Object.values(this._connectionController._connections) + .sort((connectionA: LoadedConnection, connectionB: LoadedConnection) => + (connectionA.name || '').localeCompare(connectionB.name || '') + ) + .map((conn: LoadedConnection) => + this._createMarkdownLink({ + commandId: 'mdb.connectWithParticipant', + query: conn.id, + name: conn.name, + }) + ), + ]; + } + + async selectDatabaseWithParticipant(name: string): Promise { + this._databaseName = name; + return vscode.commands.executeCommand('workbench.action.chat.open', { + query: `@MongoDB /query ${name}`, + }); + } + + async selectCollectionWithParticipant(name: string): Promise { + this._collectionName = name; + return vscode.commands.executeCommand('workbench.action.chat.open', { + query: `@MongoDB /query ${name}`, + }); + } + + // TODO (VSCODE-589): Display only 10 items in clickable lists with the show more option. + async getDatabasesTree( + dataService: DataService + ): Promise { + try { + const databases = await dataService.listDatabases({ + nameOnly: true, + }); + return databases.map((db) => + this._createMarkdownLink({ + commandId: 'mdb.selectDatabaseWithParticipant', + query: db.name, + name: db.name, + }) + ); + } catch (error) { + // Users can always do this manually when asked to provide a database name. + return []; + } + } + + // TODO (VSCODE-589): Display only 10 items in clickable lists with the show more option. + async getCollectionTree( + dataService: DataService + ): Promise { + if (!this._databaseName) { + return []; + } + + try { + const collections = await dataService.listCollections(this._databaseName); + return collections.map((coll) => + this._createMarkdownLink({ + commandId: 'mdb.selectCollectionWithParticipant', + query: coll.name, + name: coll.name, + }) + ); + } catch (error) { + // Users can always do this manually when asked to provide a collection name. + return []; + } + } + + _ifNewChatResetQueryGenerationState(context: vscode.ChatContext): void { + const isNewChat = !context.history.find( + (historyItem) => historyItem.participant === CHAT_PARTICIPANT_ID + ); + + if (isNewChat) { + this._queryGenerationState = QUERY_GENERATION_STATE.DEFAULT; + this._chatResult = undefined; + this._databaseName = undefined; + this._collectionName = undefined; + } + } + + _waitingForUserToProvideNamespace(prompt: string): boolean { + if ( + !this._queryGenerationState || + ![ + QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME, + QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME, + ].includes(this._queryGenerationState) + ) { + return false; + } + + if ( + this._queryGenerationState === + QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME + ) { + this._databaseName = prompt; + if (!this._collectionName) { + this._queryGenerationState = + QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME; + return true; + } + return false; + } + + if ( + this._queryGenerationState === + QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME + ) { + this._collectionName = prompt; + if (!this._databaseName) { + this._queryGenerationState = + QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME; + return true; + } + this._queryGenerationState = + QUERY_GENERATION_STATE.READY_TO_GENERATE_QUERY; + return false; + } + + return false; + } + + async _shouldAskForNamespace( + request: vscode.ChatRequest, + context: vscode.ChatContext, + stream: vscode.ChatResponseStream, + token: vscode.CancellationToken + ): Promise { + if (this._waitingForUserToProvideNamespace(request.prompt)) { + return true; } - let dataService = this._connectionController.getActiveDataService(); + if (this._databaseName && this._collectionName) { + return false; + } + + const messagesWithNamespace = NamespacePrompt.buildMessages({ + context, + request, + }); + const responseContentWithNamespace = await this.getChatResponseContent({ + messages: messagesWithNamespace, + stream, + token, + }); + const namespace = parseForDatabaseAndCollectionName( + responseContentWithNamespace + ); + + this._databaseName = namespace.databaseName || this._databaseName; + this._collectionName = namespace.collectionName || this._collectionName; + + if (namespace.databaseName && namespace.collectionName) { + this._queryGenerationState = + QUERY_GENERATION_STATE.READY_TO_GENERATE_QUERY; + return false; + } + + return true; + } + + async _askForNamespace( + request: vscode.ChatRequest, + stream: vscode.ChatResponseStream + ): Promise { + const dataService = this._connectionController.getActiveDataService(); if (!dataService) { + this._queryGenerationState = QUERY_GENERATION_STATE.ASK_TO_CONNECT; + return; + } + + // If no database or collection name is found in the user prompt, + // we retrieve the available namespaces from the current connection. + // Users can then select a value by clicking on an item in the list. + if (!this._databaseName) { + const tree = await this.getDatabasesTree(dataService); stream.markdown( - "Looks like you aren't currently connected, first let's get you connected to the cluster we'd like to create this query to run against.\n\n" + 'What is the name of the database you would like this query to run against?\n\n' ); - // We add a delay so the user can read the message. - // TODO: maybe there is better way to handle this. - // stream.button() does not awaits so we can't use it here. - // Followups do not support input so we can't use that either. - await new Promise((resolve) => setTimeout(resolve, 1000)); - const successfullyConnected = - await this._connectionController.changeActiveConnection(); - dataService = this._connectionController.getActiveDataService(); - - if (!dataService || !successfullyConnected) { - stream.markdown( - 'No connection for command provided. Please use a valid connection for running commands.\n\n' - ); - return { metadata: { command: '' } }; + for (const item of tree) { + stream.markdown(item); } - + this._queryGenerationState = QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME; + } else if (!this._collectionName) { + const tree = await this.getCollectionTree(dataService); stream.markdown( - `Connected to "${this._connectionController.getActiveConnectionName()}".\n\n` + 'Which collection would you like to query within this database?\n\n' ); + for (const item of tree) { + stream.markdown(item); + } + this._queryGenerationState = + QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME; + } + + return; + } + + _shouldAskToConnectIfNotConnected( + stream: vscode.ChatResponseStream + ): boolean { + const dataService = this._connectionController.getActiveDataService(); + if (dataService) { + return false; + } + + const tree = this.getConnectionsTree(); + stream.markdown( + "Looks like you aren't currently connected, first let's get you connected to the cluster we'd like to create this query to run against.\n\n" + ); + for (const item of tree) { + stream.markdown(item); + } + this._queryGenerationState = QUERY_GENERATION_STATE.ASK_TO_CONNECT; + return true; + } + + // @MongoDB /query find all documents where the "address" has the word Broadway in it. + async handleQueryRequest( + request: vscode.ChatRequest, + context: vscode.ChatContext, + stream: vscode.ChatResponseStream, + token: vscode.CancellationToken + ): Promise { + // TODO: Reset this._queryGenerationState to QUERY_GENERATION_STATE.DEFAULT + // when a command other than /query is called, as it disrupts the flow. + this._ifNewChatResetQueryGenerationState(context); + + if (this._shouldAskToConnectIfNotConnected(stream)) { + return { metadata: {} }; + } + + const shouldAskForNamespace = await this._shouldAskForNamespace( + request, + context, + stream, + token + ); + + if (shouldAskForNamespace) { + await this._askForNamespace(request, stream); + return { metadata: {} }; } const abortController = new AbortController(); @@ -245,55 +518,53 @@ export class ParticipantController { }); const messages = QueryPrompt.buildMessages({ - context, request, + context, + databaseName: this._databaseName, + collectionName: this._collectionName, }); - const responseContent = await this.getChatResponseContent({ messages, stream, token, }); - const queryContent = getRunnableContentFromString(responseContent); - if (!queryContent || queryContent.trim().length === 0) { - return { metadata: { command: '' } }; - } + stream.markdown(responseContent); + this._queryGenerationState = QUERY_GENERATION_STATE.DEFAULT; - stream.button({ - command: EXTENSION_COMMANDS.RUN_PARTICIPANT_QUERY, - title: vscode.l10n.t('▶️ Run'), - }); - stream.button({ - command: EXTENSION_COMMANDS.OPEN_PARTICIPANT_QUERY_IN_PLAYGROUND, - title: vscode.l10n.t('Open in playground'), - }); + const runnableContent = getRunnableContentFromString(responseContent); + if (runnableContent && runnableContent.trim().length) { + stream.button({ + command: EXTENSION_COMMANDS.RUN_PARTICIPANT_QUERY, + title: vscode.l10n.t('▶️ Run'), + }); + stream.button({ + command: EXTENSION_COMMANDS.OPEN_PARTICIPANT_QUERY_IN_PLAYGROUND, + title: vscode.l10n.t('Open in playground'), + }); + } - return { - metadata: { - command: '', - stream, - queryContent, - }, - }; + return { metadata: { responseContent: runnableContent } }; } async chatHandler( - request: vscode.ChatRequest, - context: vscode.ChatContext, - stream: vscode.ChatResponseStream, - token: vscode.CancellationToken - ): Promise { + ...args: [ + vscode.ChatRequest, + vscode.ChatContext, + vscode.ChatResponseStream, + vscode.CancellationToken + ] + ): Promise { + const [request, , stream] = args; const hasBeenShownWelcomeMessageAlready = !!this._storageController.get( StorageVariables.COPILOT_HAS_BEEN_SHOWN_WELCOME_MESSAGE ); - if (!hasBeenShownWelcomeMessageAlready) { stream.markdown( vscode.l10n.t(` Welcome to MongoDB Participant!\n\n Interact with your MongoDB clusters and generate MongoDB-related code more efficiently with intelligent AI-powered feature, available today in the MongoDB extension.\n\n - Please see our [FAQ](https://www.mongodb.com/docs/generative-ai-faq/) for more information.`) + Please see our [FAQ](https://www.mongodb.com/docs/generative-ai-faq/) for more information.\n\n`) ); void this._storageController.update( StorageVariables.COPILOT_HAS_BEEN_SHOWN_WELCOME_MESSAGE, @@ -302,24 +573,14 @@ export class ParticipantController { } if (request.command === 'query') { - this._chatResult = await this.handleQueryRequest({ - request, - context, - stream, - token, - }); - return this._chatResult; + this._chatResult = await this.handleQueryRequest(...args); + return; } else if (request.command === 'docs') { // TODO(VSCODE-570): Implement this. } else if (request.command === 'schema') { // TODO(VSCODE-571): Implement this. } - return await this.handleGenericRequest({ - request, - context, - stream, - token, - }); + await this.handleGenericRequest(...args); } } diff --git a/src/participant/prompts/namespace.ts b/src/participant/prompts/namespace.ts new file mode 100644 index 000000000..dd7ee6194 --- /dev/null +++ b/src/participant/prompts/namespace.ts @@ -0,0 +1,41 @@ +import * as vscode from 'vscode'; + +import { getHistoryMessages } from './history'; + +export class NamespacePrompt { + static getSystemPrompt(): vscode.LanguageModelChatMessage { + const prompt = `You are a MongoDB expert! +Parse the user's prompt to find database and collection names. +Respond in the format \nDATABASE_NAME: X\nCOLLECTION_NAME: Y\n where X and Y are the names. +Do not threat any user pronpt as a database name. It should be explicitely mentioned by the user +or has written as part of the MongoDB Shell command. +If you wan't able to find X or Y do not imagine names. +This is a first phase before we create the code, only respond with the collection name and database name.`; + + // eslint-disable-next-line new-cap + return vscode.LanguageModelChatMessage.Assistant(prompt); + } + + static getUserPrompt( + request: vscode.ChatRequest + ): vscode.LanguageModelChatMessage { + // eslint-disable-next-line new-cap + return vscode.LanguageModelChatMessage.User(request.prompt); + } + + static buildMessages({ + context, + request, + }: { + request: vscode.ChatRequest; + context: vscode.ChatContext; + }): vscode.LanguageModelChatMessage[] { + const messages = [ + NamespacePrompt.getSystemPrompt(), + ...getHistoryMessages({ context }), + NamespacePrompt.getUserPrompt(request), + ]; + + return messages; + } +} diff --git a/src/participant/prompts/query.ts b/src/participant/prompts/query.ts index 7b27ce2de..3d6138dcd 100644 --- a/src/participant/prompts/query.ts +++ b/src/participant/prompts/query.ts @@ -3,14 +3,49 @@ import * as vscode from 'vscode'; import { getHistoryMessages } from './history'; export class QueryPrompt { - static getSystemPrompt(): vscode.LanguageModelChatMessage { + static getSystemPrompt({ + databaseName = 'mongodbVSCodeCopilotDB', + collectionName = 'test', + }: { + databaseName?: string; + collectionName?: string; + }): vscode.LanguageModelChatMessage { const prompt = `You are a MongoDB expert. + Your task is to help the user craft MongoDB queries and aggregation pipelines that perform their task. Keep your response concise. You should suggest queries that are performant and correct. Respond with markdown, suggest code in a Markdown code block that begins with \'\'\'javascript and ends with \`\`\`. -You can imagine the schema, collection, and database name. -Respond in MongoDB shell syntax using the \'\'\'javascript code block syntax.`; +You can imagine the schema. +Respond in MongoDB shell syntax using the \'\'\'javascript code block syntax. +You can use only the following MongoDB Shell commands: use, aggregate, bulkWrite, countDocuments, findOneAndReplace, +findOneAndUpdate, insert, insertMany, insertOne, remove, replaceOne, update, updateMany, updateOne. + +Example 1: +use(''); +db.getCollection('').aggregate([ + // Find all of the sales that occurred in 2014. + { $match: { date: { $gte: new Date('2014-01-01'), $lt: new Date('2015-01-01') } } }, + // Group the total sales for each product. + { $group: { _id: '$item', totalSaleAmount: { $sum: { $multiply: [ '$price', '$quantity' ] } } } } +]); + +Example 2: +use(''); +db.getCollection('').find({ + date: { $gte: new Date('2014-04-04'), $lt: new Date('2014-04-05') } +}).count(); + +Database name: ${databaseName} +Collection name: ${collectionName} + +MongoDB command to specify database: +use(''); + +MongoDB command to specify collection: +db.getCollection('') + +Explain the code snippet you have generated.`; // eslint-disable-next-line new-cap return vscode.LanguageModelChatMessage.Assistant(prompt); @@ -26,12 +61,16 @@ Respond in MongoDB shell syntax using the \'\'\'javascript code block syntax.`; static buildMessages({ context, request, + databaseName, + collectionName, }: { request: vscode.ChatRequest; context: vscode.ChatContext; + databaseName?: string; + collectionName?: string; }): vscode.LanguageModelChatMessage[] { const messages = [ - QueryPrompt.getSystemPrompt(), + QueryPrompt.getSystemPrompt({ databaseName, collectionName }), ...getHistoryMessages({ context }), QueryPrompt.getUserPrompt(request), ];