support customized agent
Signed-off-by: Hailong Cui <ihailong@amazon.com>

add missing conversation id

Signed-off-by: Hailong Cui <ihailong@amazon.com>

support customized agent

Signed-off-by: Hailong Cui <ihailong@amazon.com>
Hailong-am committed Aug 19, 2024
1 parent 32888dd commit 4c9b581
Showing 7 changed files with 144 additions and 38 deletions.
2 changes: 2 additions & 0 deletions common/types/chat_saved_object_attributes.ts
@@ -29,6 +29,7 @@ export interface IConversation {
createdTimeMs: number;
updatedTimeMs: number;
messages: IMessage[];
additionalInfo: { [key: string]: unknown };
interactions: Interaction[];
}

@@ -45,6 +46,7 @@ export interface IInput {
appId?: string;
content?: string;
datasourceId?: string;
agentRole?: string;
};
messageId?: string;
promptPrefix?: string;
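
For reference, a chat input that exercises the new agentRole context field could look like the sketch below. Only the field names shown in the hunk above (appId, content, datasourceId, agentRole) come from the commit; the import path and all values are illustrative.

import { IInput } from './chat_saved_object_attributes';

// Minimal sketch: an input message that requests a customized agent via agentRole
const exampleInput: IInput = {
  type: 'input',
  contentType: 'text',
  content: 'Why did this alert fire?',
  context: {
    appId: 'alerting',                  // illustrative application id
    content: 'raw alert document ...',  // contextual content forwarded to the agent
    datasourceId: 'ds-1',               // illustrative data source id
    agentRole: 'alerts',                // selects a customized agent role (see server/types.ts)
  },
};
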
6 changes: 5 additions & 1 deletion public/chat_header_button.tsx
@@ -186,7 +186,11 @@ export const HeaderChatButton = (props: HeaderChatButtonProps) => {
type: 'input',
contentType: 'text',
content: event.suggestion,
context: { appId, content: event.contextContent, datasourceId: event.datasourceId },
context: {
appId,
content: event.contextContent,
datasourceId: event.datasourceId,
},
});
};
registry.on('onSuggestion', handleSuggestion);
8 changes: 5 additions & 3 deletions public/components/incontext_insight/generate_popover_body.tsx
@@ -18,7 +18,6 @@ import { IncontextInsight as IncontextInsightInput } from '../../types';
import { getConfigSchema, getIncontextInsightRegistry, getNotifications } from '../../services';
import { HttpSetup } from '../../../../../src/core/public';
import { ASSISTANT_API } from '../../../common/constants/llm';
import { getAssistantRole } from '../../utils/constants';

export const GeneratePopoverBody: React.FC<{
incontextInsight: IncontextInsightInput;
@@ -60,8 +59,11 @@ export const GeneratePopoverBody: React.FC<{
type: 'input',
content: summarizationQuestion,
contentType: 'text',
context: { content: contextContent, dataSourceId: incontextInsight.datasourceId },
promptPrefix: getAssistantRole(incontextInsightType),
context: {
content: contextContent,
agentRole: incontextInsightType,
dataSourceId: incontextInsight.datasourceId,
},
},
}),
})
12 changes: 1 addition & 11 deletions server/routes/chat_routes.ts
@@ -31,6 +31,7 @@ const llmRequestRoute = {
appId: schema.maybe(schema.string()),
content: schema.maybe(schema.string()),
datasourceId: schema.maybe(schema.string()),
agentRole: schema.maybe(schema.string()),
}),
content: schema.string(),
contentType: schema.literal('text'),
@@ -235,17 +236,6 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
: [];
}

resultPayload.messages
.filter((message) => message.type === 'input')
.forEach((msg) => {
// hide the additional context that describes how the question was generated
const index = msg.content.indexOf('answer question:');
const len = 'answer question:'.length;
if (index !== -1) {
msg.content = msg.content.substring(index + len);
}
});

return response.ok({
body: resultPayload,
});
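
For context, an input.context object that passes the updated llmRequestRoute validation might look like this; the field names come from the schema in the hunk above, while the values are made up.

// Illustrative context payload; agentRole is the new optional field added in this commit
const exampleContext = {
  appId: 'alerting',
  content: 'Based on this alert document ...',
  datasourceId: 'ds-1',
  agentRole: 'alerts',
};
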
15 changes: 15 additions & 0 deletions server/routes/get_agent.ts
@@ -26,3 +26,18 @@ export const getAgent = async (id: string, client: OpenSearchClient['transport']
throw new Error(`get agent ${id} failed, reason: ${errorMessage}`);
}
};

export const getAgentDetail = async (agentId: string, client: OpenSearchClient['transport']) => {
try {
const path = `${ML_COMMONS_BASE_API}/agents/${agentId}`;
const response = await client.request({
method: 'GET',
path,
});

return response.body;
} catch (error) {
const errorMessage = JSON.stringify(error.meta?.body) || error;
throw new Error(`get agent ${agentId} failed, reason: ${errorMessage}`);
}
};
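
Inside an async context, a caller could combine the existing getAgent helper with the new getAgentDetail like the sketch below; the client variable is assumed to be an OpenSearchClient['transport'] obtained elsewhere, and 'alert_analysis' is the agent config id defined in server/types.ts.

// Sketch: resolve an agent id from its config id, then fetch the agent detail
const agentId = await getAgent('alert_analysis', client);
const agentDetail = await getAgentDetail(agentId, client);
// app_type is what olly_chat_service.ts passes as application_type when creating a new memory
console.log(agentDetail.app_type);
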
130 changes: 107 additions & 23 deletions server/services/chat/olly_chat_service.ts
@@ -8,14 +8,17 @@ import { OpenSearchClient } from '../../../../../src/core/server';
import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes';
import { ChatService } from './chat_service';
import { ML_COMMONS_BASE_API, ROOT_AGENT_CONFIG_ID } from '../../utils/constants';
import { getAgent } from '../../routes/get_agent';
import { getAgent, getAgentDetail } from '../../routes/get_agent';
import { AgentRoles } from '../../../server/types';

interface AgentRunPayload {
question?: string;
verbose?: boolean;
memory_id?: string;
regenerate_interaction_id?: string;
'prompt.prefix'?: string;
agentRole?: string;
context?: string;
}

const MEMORY_ID_FIELD = 'memory_id';
@@ -30,21 +33,114 @@ export class OllyChatService implements ChatService {
return await getAgent(ROOT_AGENT_CONFIG_ID, this.opensearchClientTransport);
}

/**
* @param conversationId conversation/memory Id
* @returns additional information associated with the conversation/memory
*/
private async getAdditionalInfoForConversation(
conversationId: string
): Promise<Record<string, string> | undefined> {
try {
const response = await this.opensearchClientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/${conversationId}`,
});

return response?.body?.additional_info;
} catch (error) {
return undefined;
}
}

private async createNewConversation(
title?: string,
applicationType?: string,
additionalInfo?: Record<string, string>
): Promise<string | undefined> {
try {
const response = (await this.opensearchClientTransport.request({
method: 'POST',
path: `${ML_COMMONS_BASE_API}/memory`,
body: {
name: title,
application_type: applicationType,
additional_info: {
...additionalInfo,
},
},
})) as ApiResponse<{
memory_id: string;
}>;

return response.body.memory_id;
} catch (error) {
return undefined;
}
}

private async requestAgentRun(payload: AgentRunPayload) {
if (payload.memory_id) {
OllyChatService.abortControllers.set(payload.memory_id, new AbortController());
}

const rootAgentId = await this.getRootAgent();
return await this.callExecuteAgentAPI(payload, rootAgentId);
let memoryId = payload.memory_id;
let agentConfigId = ROOT_AGENT_CONFIG_ID;
let agentId;
let promptPrefix;

// follow up questions
if (memoryId) {
const additionalInfo = await this.getAdditionalInfoForConversation(memoryId);
if (additionalInfo) {
agentConfigId = additionalInfo.agent_config_id;
payload.agentRole = payload.agentRole || additionalInfo.agentRole;
payload.context = additionalInfo.context;
}
}

if (payload.agentRole) {
const agentRole = AgentRoles.find((role) => role.id === payload.agentRole);
if (agentRole) {
agentConfigId = agentRole?.agentConfigId;
promptPrefix = agentRole.description;

if (promptPrefix && promptPrefix.length) {
payload['prompt.prefix'] = promptPrefix;
}

agentId = await getAgent(agentConfigId, this.opensearchClientTransport);

// start with a new conversation
if (!memoryId) {
const agentDetail = await getAgentDetail(agentId, this.opensearchClientTransport);

memoryId = await this.createNewConversation(payload.question, agentDetail.app_type, {
agent_config_id: agentRole?.agentConfigId || ROOT_AGENT_CONFIG_ID,
...(agentRole ? { agentRole: agentRole.id } : {}),
...(payload.context ? { context: payload.context } : {}),
});
// set memory id
payload.memory_id = memoryId;
}
}
}

if (!agentId) {
agentId = await getAgent(
agentConfigId || ROOT_AGENT_CONFIG_ID,
this.opensearchClientTransport
);
}

return await this.callExecuteAgentAPI(payload, agentId);
}

private async callExecuteAgentAPI(payload: AgentRunPayload, rootAgentId: string) {
private async callExecuteAgentAPI(payload: AgentRunPayload, agentId: string) {
try {
const agentFrameworkResponse = (await this.opensearchClientTransport.request(
{
method: 'POST',
path: `${ML_COMMONS_BASE_API}/agents/${rootAgentId}/_execute`,
path: `${ML_COMMONS_BASE_API}/agents/${agentId}/_execute`,
body: {
parameters: payload,
},
@@ -95,28 +191,16 @@ export class OllyChatService implements ChatService {
conversationId: string;
interactionId: string;
}> {
const { input, conversationId } = payload;
const { input } = payload;

let llmInput = input.content;
if (input.context?.content) {
llmInput = `Based on the context: ${input.context?.content}, answer question: ${input.content}`;
}
const parametersPayload: Pick<
AgentRunPayload,
'question' | 'verbose' | 'memory_id' | 'prompt.prefix'
> = {
question: llmInput,
const parametersPayload: AgentRunPayload = {
question: input.content,
context: input.context?.content,
verbose: false,
agentRole: input.context?.agentRole,
memory_id: payload.conversationId,
};

if (input.promptPrefix) {
parametersPayload['prompt.prefix'] = input.promptPrefix;
}

if (conversationId) {
parametersPayload.memory_id = conversationId;
}

return await this.requestAgentRun(parametersPayload);
}

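
In short, the additional_info stored when a conversation is created is what lets follow-up questions keep using the same customized agent. A rough trace of the round trip, annotated from the code above:

// First question with context.agentRole === 'alerts':
//   - the matching AgentRoles entry supplies agentConfigId ('alert_analysis') and prompt.prefix
//   - getAgent() resolves the agent id, getAgentDetail() supplies app_type
//   - createNewConversation() persists { agent_config_id, agentRole, context } as additional_info
// Follow-up question that carries an existing memory_id:
//   - getAdditionalInfoForConversation() restores agent_config_id, agentRole and context
//   - the same customized agent is executed again via callExecuteAgentAPI()
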
9 changes: 9 additions & 0 deletions server/types.ts
@@ -50,3 +50,12 @@ declare module '../../../src/core/server' {
};
}
}

export const AgentRoles = [
{
id: 'alerts',
agentConfigId: 'alert_analysis',
description:
'Assistant is an advanced alert summarization and analysis agent. For each alert, we will provide comprehensive details of the alert, including relevant information. Here is the detail of alert ${parameters.context}',
},
];
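
Adding another customized agent would mean registering its agent config in ml-commons, appending an entry to AgentRoles, and sending that entry's id from the client as input.context.agentRole. A hypothetical entry (the id, agentConfigId, and description below are invented for illustration):

// Hypothetical entry to append to AgentRoles; agentConfigId must match an agent config registered in ml-commons
const anomalyRole = {
  id: 'anomalies',
  agentConfigId: 'anomaly_analysis',
  description:
    'Assistant is an advanced anomaly analysis agent. Here is the detail of the anomaly ${parameters.context}',
};
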
