From b5356e1f6d6a9182d856dc49d2721d5d44f4001a Mon Sep 17 00:00:00 2001 From: Mollie Munoz Date: Fri, 17 Nov 2023 15:47:48 -0800 Subject: [PATCH] Changes to complete revert --- .../CopilotChatMemoryPipeline.csproj | 4 +- shared/CopilotChatShared.csproj | 2 +- webapi/Controllers/ChatController.cs | 684 +++++++++--------- webapi/CopilotChatWebApi.csproj | 28 +- webapi/Services/SemanticKernelProvider.cs | 310 ++++---- 5 files changed, 514 insertions(+), 514 deletions(-) diff --git a/memorypipeline/CopilotChatMemoryPipeline.csproj b/memorypipeline/CopilotChatMemoryPipeline.csproj index 8f9e4ffa0..fc2cf32dd 100644 --- a/memorypipeline/CopilotChatMemoryPipeline.csproj +++ b/memorypipeline/CopilotChatMemoryPipeline.csproj @@ -15,8 +15,8 @@ - - + + diff --git a/shared/CopilotChatShared.csproj b/shared/CopilotChatShared.csproj index 6a46f5134..a675325d0 100644 --- a/shared/CopilotChatShared.csproj +++ b/shared/CopilotChatShared.csproj @@ -9,7 +9,7 @@ - + diff --git a/webapi/Controllers/ChatController.cs b/webapi/Controllers/ChatController.cs index a01c48e3e..ced7da03d 100644 --- a/webapi/Controllers/ChatController.cs +++ b/webapi/Controllers/ChatController.cs @@ -43,386 +43,386 @@ namespace CopilotChat.WebApi.Controllers; [ApiController] public class ChatController : ControllerBase, IDisposable { - private readonly ILogger _logger; - private readonly IHttpClientFactory _httpClientFactory; - private readonly List _disposables; - private readonly ITelemetryService _telemetryService; - private readonly ServiceOptions _serviceOptions; - private readonly PlannerOptions _plannerOptions; - private readonly IDictionary _plugins; - - private const string ChatPluginName = nameof(ChatPlugin); - private const string ChatFunctionName = "Chat"; - private const string ProcessPlanFunctionName = "ProcessPlan"; - private const string GeneratingResponseClientCall = "ReceiveBotResponseStatus"; - - public ChatController( - ILogger logger, - IHttpClientFactory httpClientFactory, - ITelemetryService telemetryService, - IOptions serviceOptions, - IOptions plannerOptions, - IDictionary plugins) + private readonly ILogger _logger; + private readonly IHttpClientFactory _httpClientFactory; + private readonly List _disposables; + private readonly ITelemetryService _telemetryService; + private readonly ServiceOptions _serviceOptions; + private readonly PlannerOptions _plannerOptions; + private readonly IDictionary _plugins; + + private const string ChatPluginName = nameof(ChatPlugin); + private const string ChatFunctionName = "Chat"; + private const string ProcessPlanFunctionName = "ProcessPlan"; + private const string GeneratingResponseClientCall = "ReceiveBotResponseStatus"; + + public ChatController( + ILogger logger, + IHttpClientFactory httpClientFactory, + ITelemetryService telemetryService, + IOptions serviceOptions, + IOptions plannerOptions, + IDictionary plugins) + { + this._logger = logger; + this._httpClientFactory = httpClientFactory; + this._telemetryService = telemetryService; + this._disposables = new List(); + this._serviceOptions = serviceOptions.Value; + this._plannerOptions = plannerOptions.Value; + this._plugins = plugins; + } + + /// + /// Invokes the chat function to get a response from the bot. + /// + /// Semantic kernel obtained through dependency injection. + /// Message Hub that performs the real time relay service. + /// Planner to use to create function sequences. + /// Converter to use for converting Asks. + /// Repository of chat sessions. + /// Repository of chat participants. + /// Auth info for the current request. + /// Prompt along with its parameters. + /// Chat ID. + /// Results containing the response from the model. + [Route("chats/{chatId:guid}/messages")] + [HttpPost] + [ProducesResponseType(StatusCodes.Status200OK)] + [ProducesResponseType(StatusCodes.Status400BadRequest)] + [ProducesResponseType(StatusCodes.Status403Forbidden)] + [ProducesResponseType(StatusCodes.Status404NotFound)] + [ProducesResponseType(StatusCodes.Status504GatewayTimeout)] + public async Task ChatAsync( + [FromServices] IKernel kernel, + [FromServices] IHubContext messageRelayHubContext, + [FromServices] CopilotChatPlanner planner, + [FromServices] AskConverter askConverter, + [FromServices] ChatSessionRepository chatSessionRepository, + [FromServices] ChatParticipantRepository chatParticipantRepository, + [FromServices] IAuthInfo authInfo, + [FromBody] Ask ask, + [FromRoute] Guid chatId) + { + this._logger.LogDebug("Chat message received."); + + return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString()); + } + + /// + /// Invokes the chat function to process and/or execute plan. + /// + /// Semantic kernel obtained through dependency injection. + /// Message Hub that performs the real time relay service. + /// Planner to use to create function sequences. + /// Converter to use for converting Asks. + /// Repository of chat sessions. + /// Repository of chat participants. + /// Auth info for the current request. + /// Prompt along with its parameters. + /// Chat ID. + /// Results containing the response from the model. + [Route("chats/{chatId:guid}/plan")] + [HttpPost] + [ProducesResponseType(StatusCodes.Status200OK)] + [ProducesResponseType(StatusCodes.Status400BadRequest)] + [ProducesResponseType(StatusCodes.Status403Forbidden)] + [ProducesResponseType(StatusCodes.Status404NotFound)] + [ProducesResponseType(StatusCodes.Status504GatewayTimeout)] + public async Task ProcessPlanAsync( + [FromServices] IKernel kernel, + [FromServices] IHubContext messageRelayHubContext, + [FromServices] CopilotChatPlanner planner, + [FromServices] AskConverter askConverter, + [FromServices] ChatSessionRepository chatSessionRepository, + [FromServices] ChatParticipantRepository chatParticipantRepository, + [FromServices] IAuthInfo authInfo, + [FromBody] ExecutePlanParameters ask, + [FromRoute] Guid chatId) + { + this._logger.LogDebug("plan request received."); + + return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString()); + } + + /// + /// Invokes given function of ChatPlugin. + /// + /// Name of the ChatPlugin function to invoke. + /// Semantic kernel obtained through dependency injection. + /// Message Hub that performs the real time relay service. + /// Planner to use to create function sequences. + /// Converter to use for converting Asks. + /// Repository of chat sessions. + /// Repository of chat participants. + /// Auth info for the current request. + /// Prompt along with its parameters. + /// + /// Results containing the response from the model. + private async Task HandleRequest( + string functionName, + IKernel kernel, + IHubContext messageRelayHubContext, + CopilotChatPlanner planner, + AskConverter askConverter, + ChatSessionRepository chatSessionRepository, + ChatParticipantRepository chatParticipantRepository, + IAuthInfo authInfo, + Ask ask, + string chatId) + { + // Put ask's variables in the context we will use. + var contextVariables = askConverter.GetContextVariables(ask); + + // Verify that the chat exists and that the user has access to it. + ChatSession? chat = null; + if (!(await chatSessionRepository.TryFindByIdAsync(chatId, callback: c => chat = c))) { - this._logger = logger; - this._httpClientFactory = httpClientFactory; - this._telemetryService = telemetryService; - this._disposables = new List(); - this._serviceOptions = serviceOptions.Value; - this._plannerOptions = plannerOptions.Value; - this._plugins = plugins; + return this.NotFound("Failed to find chat session for the chatId specified in variables."); } - /// - /// Invokes the chat function to get a response from the bot. - /// - /// Semantic kernel obtained through dependency injection. - /// Message Hub that performs the real time relay service. - /// Planner to use to create function sequences. - /// Converter to use for converting Asks. - /// Repository of chat sessions. - /// Repository of chat participants. - /// Auth info for the current request. - /// Prompt along with its parameters. - /// Chat ID. - /// Results containing the response from the model. - [Route("chats/{chatId:guid}/messages")] - [HttpPost] - [ProducesResponseType(StatusCodes.Status200OK)] - [ProducesResponseType(StatusCodes.Status400BadRequest)] - [ProducesResponseType(StatusCodes.Status403Forbidden)] - [ProducesResponseType(StatusCodes.Status404NotFound)] - [ProducesResponseType(StatusCodes.Status504GatewayTimeout)] - public async Task ChatAsync( - [FromServices] IKernel kernel, - [FromServices] IHubContext messageRelayHubContext, - [FromServices] CopilotChatPlanner planner, - [FromServices] AskConverter askConverter, - [FromServices] ChatSessionRepository chatSessionRepository, - [FromServices] ChatParticipantRepository chatParticipantRepository, - [FromServices] IAuthInfo authInfo, - [FromBody] Ask ask, - [FromRoute] Guid chatId) + if (!(await chatParticipantRepository.IsUserInChatAsync(authInfo.UserId, chatId))) { - this._logger.LogDebug("Chat message received."); + return this.Forbid("User does not have access to the chatId specified in variables."); + } + + // Register plugins that have been enabled + var openApiPluginAuthHeaders = this.GetPluginAuthHeaders(this.HttpContext.Request.Headers); + await this.RegisterPlannerFunctionsAsync(planner, openApiPluginAuthHeaders, contextVariables); + + // Register hosted plugins that have been enabled + await this.RegisterPlannerHostedFunctionsUsedAsync(planner, chat!.EnabledPlugins); - return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString()); + // Get the function to invoke + ISKFunction? function = null; + try + { + function = kernel.Functions.GetFunction(ChatPluginName, functionName); + } + catch (SKException ex) + { + this._logger.LogError("Failed to find {PluginName}/{FunctionName} on server: {Exception}", ChatPluginName, functionName, ex); + return this.NotFound($"Failed to find {ChatPluginName}/{functionName} on server"); } - /// - /// Invokes the chat function to process and/or execute plan. - /// - /// Semantic kernel obtained through dependency injection. - /// Message Hub that performs the real time relay service. - /// Planner to use to create function sequences. - /// Converter to use for converting Asks. - /// Repository of chat sessions. - /// Repository of chat participants. - /// Auth info for the current request. - /// Prompt along with its parameters. - /// Chat ID. - /// Results containing the response from the model. - [Route("chats/{chatId:guid}/plan")] - [HttpPost] - [ProducesResponseType(StatusCodes.Status200OK)] - [ProducesResponseType(StatusCodes.Status400BadRequest)] - [ProducesResponseType(StatusCodes.Status403Forbidden)] - [ProducesResponseType(StatusCodes.Status404NotFound)] - [ProducesResponseType(StatusCodes.Status504GatewayTimeout)] - public async Task ProcessPlanAsync( - [FromServices] IKernel kernel, - [FromServices] IHubContext messageRelayHubContext, - [FromServices] CopilotChatPlanner planner, - [FromServices] AskConverter askConverter, - [FromServices] ChatSessionRepository chatSessionRepository, - [FromServices] ChatParticipantRepository chatParticipantRepository, - [FromServices] IAuthInfo authInfo, - [FromBody] ExecutePlanParameters ask, - [FromRoute] Guid chatId) + // Run the function. + KernelResult? result = null; + try { - this._logger.LogDebug("plan request received."); + using CancellationTokenSource? cts = this._serviceOptions.TimeoutLimitInS is not null + // Create a cancellation token source with the timeout if specified + ? new CancellationTokenSource(TimeSpan.FromSeconds((double)this._serviceOptions.TimeoutLimitInS)) + : null; - return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString()); + result = await kernel.RunAsync(function!, contextVariables, cts?.Token ?? default); + this._telemetryService.TrackPluginFunction(ChatPluginName, functionName, true); } - - /// - /// Invokes given function of ChatPlugin. - /// - /// Name of the ChatPlugin function to invoke. - /// Semantic kernel obtained through dependency injection. - /// Message Hub that performs the real time relay service. - /// Planner to use to create function sequences. - /// Converter to use for converting Asks. - /// Repository of chat sessions. - /// Repository of chat participants. - /// Auth info for the current request. - /// Prompt along with its parameters. - /// - /// Results containing the response from the model. - private async Task HandleRequest( - string functionName, - IKernel kernel, - IHubContext messageRelayHubContext, - CopilotChatPlanner planner, - AskConverter askConverter, - ChatSessionRepository chatSessionRepository, - ChatParticipantRepository chatParticipantRepository, - IAuthInfo authInfo, - Ask ask, - string chatId) + catch (Exception ex) { - // Put ask's variables in the context we will use. - var contextVariables = askConverter.GetContextVariables(ask); + if (ex is OperationCanceledException || ex.InnerException is OperationCanceledException) + { + // Log the timeout and return a 504 response + this._logger.LogError("The {FunctionName} operation timed out.", functionName); + return this.StatusCode(StatusCodes.Status504GatewayTimeout, $"The chat {functionName} timed out."); + } + + this._telemetryService.TrackPluginFunction(ChatPluginName, functionName, false); + throw ex; + } - // Verify that the chat exists and that the user has access to it. - ChatSession? chat = null; - if (!(await chatSessionRepository.TryFindByIdAsync(chatId, callback: c => chat = c))) - { - return this.NotFound("Failed to find chat session for the chatId specified in variables."); - } + AskResult chatAskResult = new() + { + Value = result.GetValue() ?? string.Empty, + Variables = contextVariables.Select(v => new KeyValuePair(v.Key, v.Value)) + }; - if (!(await chatParticipantRepository.IsUserInChatAsync(authInfo.UserId, chatId))) - { - return this.Forbid("User does not have access to the chatId specified in variables."); - } + // Broadcast AskResult to all users + await messageRelayHubContext.Clients.Group(chatId).SendAsync(GeneratingResponseClientCall, chatId, null); - // Register plugins that have been enabled - var openApiPluginAuthHeaders = this.GetPluginAuthHeaders(this.HttpContext.Request.Headers); - await this.RegisterPlannerFunctionsAsync(planner, openApiPluginAuthHeaders, contextVariables); + return this.Ok(chatAskResult); + } - // Register hosted plugins that have been enabled - await this.RegisterPlannerHostedFunctionsUsedAsync(planner, chat!.EnabledPlugins); + /// + /// Parse plugin auth values from request headers. + /// + private Dictionary GetPluginAuthHeaders(IHeaderDictionary headers) + { + // Create a regex to match the headers + var regex = new Regex("x-sk-copilot-(.*)-auth", RegexOptions.IgnoreCase); - // Get the function to invoke - ISKFunction? function = null; - try - { - function = kernel.Functions.GetFunction(ChatPluginName, functionName); - } - catch (SKException ex) - { - this._logger.LogError("Failed to find {PluginName}/{FunctionName} on server: {Exception}", ChatPluginName, functionName, ex); - return this.NotFound($"Failed to find {ChatPluginName}/{functionName} on server"); - } + // Create a dictionary to store the matched headers and values + var authHeaders = new Dictionary(); - // Run the function. - KernelResult? result = null; - try - { - using CancellationTokenSource? cts = this._serviceOptions.TimeoutLimitInS is not null - // Create a cancellation token source with the timeout if specified - ? new CancellationTokenSource(TimeSpan.FromSeconds((double)this._serviceOptions.TimeoutLimitInS)) - : null; - - result = await kernel.RunAsync(function!, contextVariables, cts?.Token ?? default); - this._telemetryService.TrackPluginFunction(ChatPluginName, functionName, true); - } - catch (Exception ex) - { - if (ex is OperationCanceledException || ex.InnerException is OperationCanceledException) - { - // Log the timeout and return a 504 response - this._logger.LogError("The {FunctionName} operation timed out.", functionName); - return this.StatusCode(StatusCodes.Status504GatewayTimeout, $"The chat {functionName} timed out."); - } - - this._telemetryService.TrackPluginFunction(ChatPluginName, functionName, false); - throw ex; - } + // Loop through the request headers and add the matched ones to the dictionary + foreach (var header in headers) + { + var match = regex.Match(header.Key); + if (match.Success) + { + // Use the first capture group as the key and the header value as the value + authHeaders.Add(match.Groups[1].Value.ToUpperInvariant(), header.Value!); + } + } - AskResult chatAskResult = new() - { - Value = result.GetValue() ?? string.Empty, - Variables = contextVariables.Select(v => new KeyValuePair(v.Key, v.Value)) - }; + return authHeaders; + } - // Broadcast AskResult to all users - await messageRelayHubContext.Clients.Group(chatId).SendAsync(GeneratingResponseClientCall, chatId, null); + /// + /// Register functions with the planner's kernel. + /// + private async Task RegisterPlannerFunctionsAsync(CopilotChatPlanner planner, Dictionary authHeaders, ContextVariables variables) + { + // Register authenticated functions with the planner's kernel only if the request includes an auth header for the plugin. - return this.Ok(chatAskResult); + // GitHub + if (authHeaders.TryGetValue("GITHUB", out string? GithubAuthHeader)) + { + this._logger.LogInformation("Enabling GitHub plugin."); + BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GithubAuthHeader)); + await planner.Kernel.ImportPluginFunctionsAsync( + pluginName: "GitHubPlugin", + filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", "OpenApi/GitHubPlugin/openapi.json"), + new OpenApiFunctionExecutionParameters + { + AuthCallback = authenticationProvider.AuthenticateRequestAsync, + }); } - /// - /// Parse plugin auth values from request headers. - /// - private Dictionary GetPluginAuthHeaders(IHeaderDictionary headers) + // Jira + if (authHeaders.TryGetValue("JIRA", out string? JiraAuthHeader)) { - // Create a regex to match the headers - var regex = new Regex("x-sk-copilot-(.*)-auth", RegexOptions.IgnoreCase); - - // Create a dictionary to store the matched headers and values - var authHeaders = new Dictionary(); + this._logger.LogInformation("Registering Jira plugin"); + var authenticationProvider = new BasicAuthenticationProvider(() => { return Task.FromResult(JiraAuthHeader); }); + var hasServerUrlOverride = variables.TryGetValue("jira-server-url", out string? serverUrlOverride); + + await planner.Kernel.ImportPluginFunctionsAsync( + pluginName: "JiraPlugin", + filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", "OpenApi/JiraPlugin/openapi.json"), + new OpenApiFunctionExecutionParameters + { + AuthCallback = authenticationProvider.AuthenticateRequestAsync, + ServerUrlOverride = hasServerUrlOverride ? new Uri(serverUrlOverride!) : null, + }); + } - // Loop through the request headers and add the matched ones to the dictionary - foreach (var header in headers) - { - var match = regex.Match(header.Key); - if (match.Success) - { - // Use the first capture group as the key and the header value as the value - authHeaders.Add(match.Groups[1].Value.ToUpperInvariant(), header.Value!); - } - } + // Microsoft Graph + if (authHeaders.TryGetValue("GRAPH", out string? GraphAuthHeader)) + { + this._logger.LogInformation("Enabling Microsoft Graph plugin(s)."); + BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GraphAuthHeader)); + GraphServiceClient graphServiceClient = this.CreateGraphServiceClient(authenticationProvider.AuthenticateRequestAsync); - return authHeaders; + planner.Kernel.ImportFunctions(new TaskListPlugin(new MicrosoftToDoConnector(graphServiceClient)), "todo"); + planner.Kernel.ImportFunctions(new CalendarPlugin(new OutlookCalendarConnector(graphServiceClient)), "calendar"); + planner.Kernel.ImportFunctions(new EmailPlugin(new OutlookMailConnector(graphServiceClient)), "email"); } - /// - /// Register functions with the planner's kernel. - /// - private async Task RegisterPlannerFunctionsAsync(CopilotChatPlanner planner, Dictionary authHeaders, ContextVariables variables) + if (variables.TryGetValue("customPlugins", out string? customPluginsString)) { - // Register authenticated functions with the planner's kernel only if the request includes an auth header for the plugin. + CustomPlugin[]? customPlugins = JsonSerializer.Deserialize(customPluginsString); - // GitHub - if (authHeaders.TryGetValue("GITHUB", out string? GithubAuthHeader)) + if (customPlugins != null) + { + foreach (CustomPlugin plugin in customPlugins) { - this._logger.LogInformation("Enabling GitHub plugin."); - BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GithubAuthHeader)); - await planner.Kernel.ImportOpenApiPluginFunctionsAsync( - pluginName: "GitHubPlugin", - filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", "OpenApi/GitHubPlugin/openapi.json"), + if (authHeaders.TryGetValue(plugin.AuthHeaderTag.ToUpperInvariant(), out string? PluginAuthValue)) + { + // Register the ChatGPT plugin with the planner's kernel. + this._logger.LogInformation("Enabling {0} plugin.", plugin.NameForHuman); + + // TODO: [Issue #44] Support other forms of auth. Currently, we only support user PAT or no auth. + var requiresAuth = !plugin.AuthType.Equals("none", StringComparison.OrdinalIgnoreCase); + BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(PluginAuthValue)); + + await planner.Kernel.ImportPluginFunctionsAsync( + $"{plugin.NameForModel}Plugin", + PluginUtils.GetPluginManifestUri(plugin.ManifestDomain), new OpenApiFunctionExecutionParameters { - AuthCallback = authenticationProvider.AuthenticateRequestAsync, + HttpClient = this._httpClientFactory.CreateClient("Plugin"), + IgnoreNonCompliantErrors = true, + AuthCallback = requiresAuth ? authenticationProvider.AuthenticateRequestAsync : null }); + } } - - // Jira - if (authHeaders.TryGetValue("JIRA", out string? JiraAuthHeader)) - { - this._logger.LogInformation("Registering Jira plugin"); - var authenticationProvider = new BasicAuthenticationProvider(() => { return Task.FromResult(JiraAuthHeader); }); - var hasServerUrlOverride = variables.TryGetValue("jira-server-url", out string? serverUrlOverride); - - await planner.Kernel.ImportOpenApiPluginFunctionsAsync( - pluginName: "JiraPlugin", - filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", "OpenApi/JiraPlugin/openapi.json"), - new OpenApiFunctionExecutionParameters - { - AuthCallback = authenticationProvider.AuthenticateRequestAsync, - ServerUrlOverride = hasServerUrlOverride ? new Uri(serverUrlOverride!) : null, - }); - } - - // Microsoft Graph - if (authHeaders.TryGetValue("GRAPH", out string? GraphAuthHeader)) - { - this._logger.LogInformation("Enabling Microsoft Graph plugin(s)."); - BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GraphAuthHeader)); - GraphServiceClient graphServiceClient = this.CreateGraphServiceClient(authenticationProvider.AuthenticateRequestAsync); - - planner.Kernel.ImportFunctions(new TaskListPlugin(new MicrosoftToDoConnector(graphServiceClient)), "todo"); - planner.Kernel.ImportFunctions(new CalendarPlugin(new OutlookCalendarConnector(graphServiceClient)), "calendar"); - planner.Kernel.ImportFunctions(new EmailPlugin(new OutlookMailConnector(graphServiceClient)), "email"); - } - - if (variables.TryGetValue("customPlugins", out string? customPluginsString)) - { - CustomPlugin[]? customPlugins = JsonSerializer.Deserialize(customPluginsString); - - if (customPlugins != null) - { - foreach (CustomPlugin plugin in customPlugins) - { - if (authHeaders.TryGetValue(plugin.AuthHeaderTag.ToUpperInvariant(), out string? PluginAuthValue)) - { - // Register the ChatGPT plugin with the planner's kernel. - this._logger.LogInformation("Enabling {0} plugin.", plugin.NameForHuman); - - // TODO: [Issue #44] Support other forms of auth. Currently, we only support user PAT or no auth. - var requiresAuth = !plugin.AuthType.Equals("none", StringComparison.OrdinalIgnoreCase); - BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(PluginAuthValue)); - - await planner.Kernel.ImportOpenApiPluginFunctionsAsync( - $"{plugin.NameForModel}Plugin", - PluginUtils.GetPluginManifestUri(plugin.ManifestDomain), - new OpenApiFunctionExecutionParameters - { - HttpClient = this._httpClientFactory.CreateClient("Plugin"), - IgnoreNonCompliantErrors = true, - AuthCallback = requiresAuth ? authenticationProvider.AuthenticateRequestAsync : null - }); - } - } - } - else - { - this._logger.LogDebug("Failed to deserialize custom plugin details: {0}", customPluginsString); - } - } + } + else + { + this._logger.LogDebug("Failed to deserialize custom plugin details: {0}", customPluginsString); + } } - - /// - /// Create a Microsoft Graph service client. - /// - /// The delegate to authenticate the request. - private GraphServiceClient CreateGraphServiceClient(AuthenticateRequestAsyncDelegate authenticateRequestAsyncDelegate) + } + + /// + /// Create a Microsoft Graph service client. + /// + /// The delegate to authenticate the request. + private GraphServiceClient CreateGraphServiceClient(AuthenticateRequestAsyncDelegate authenticateRequestAsyncDelegate) + { + MsGraphClientLoggingHandler graphLoggingHandler = new(this._logger); + this._disposables.Add(graphLoggingHandler); + + IList graphMiddlewareHandlers = + GraphClientFactory.CreateDefaultHandlers(new DelegateAuthenticationProvider(authenticateRequestAsyncDelegate)); + graphMiddlewareHandlers.Add(graphLoggingHandler); + + HttpClient graphHttpClient = GraphClientFactory.Create(graphMiddlewareHandlers); + this._disposables.Add(graphHttpClient); + + GraphServiceClient graphServiceClient = new(graphHttpClient); + return graphServiceClient; + } + + private async Task RegisterPlannerHostedFunctionsUsedAsync(CopilotChatPlanner planner, HashSet enabledPlugins) + { + foreach (string enabledPlugin in enabledPlugins) { - MsGraphClientLoggingHandler graphLoggingHandler = new(this._logger); - this._disposables.Add(graphLoggingHandler); - - IList graphMiddlewareHandlers = - GraphClientFactory.CreateDefaultHandlers(new DelegateAuthenticationProvider(authenticateRequestAsyncDelegate)); - graphMiddlewareHandlers.Add(graphLoggingHandler); - - HttpClient graphHttpClient = GraphClientFactory.Create(graphMiddlewareHandlers); - this._disposables.Add(graphHttpClient); - - GraphServiceClient graphServiceClient = new(graphHttpClient); - return graphServiceClient; - } - - private async Task RegisterPlannerHostedFunctionsUsedAsync(CopilotChatPlanner planner, HashSet enabledPlugins) - { - foreach (string enabledPlugin in enabledPlugins) - { - if (this._plugins.TryGetValue(enabledPlugin, out Plugin? plugin)) - { - this._logger.LogDebug("Enabling hosted plugin {0}.", plugin.Name); - - CustomAuthenticationProvider authenticationProvider = new( - () => Task.FromResult("X-Functions-Key"), - () => Task.FromResult(plugin.Key)); - - // Register the ChatGPT plugin with the planner's kernel. - await planner.Kernel.ImportOpenApiPluginFunctionsAsync( - PluginUtils.SanitizePluginName(plugin.Name), - PluginUtils.GetPluginManifestUri(plugin.ManifestDomain), - new OpenApiFunctionExecutionParameters - { - HttpClient = this._httpClientFactory.CreateClient("Plugin"), - IgnoreNonCompliantErrors = true, - AuthCallback = authenticationProvider.AuthenticateRequestAsync - }); - } - else + if (this._plugins.TryGetValue(enabledPlugin, out Plugin? plugin)) + { + this._logger.LogDebug("Enabling hosted plugin {0}.", plugin.Name); + + CustomAuthenticationProvider authenticationProvider = new( + () => Task.FromResult("X-Functions-Key"), + () => Task.FromResult(plugin.Key)); + + // Register the ChatGPT plugin with the planner's kernel. + await planner.Kernel.ImportPluginFunctionsAsync( + PluginUtils.SanitizePluginName(plugin.Name), + PluginUtils.GetPluginManifestUri(plugin.ManifestDomain), + new OpenApiFunctionExecutionParameters { - this._logger.LogWarning("Failed to find plugin {0}.", enabledPlugin); - } - } - return; - } - - /// - /// Dispose of the object. - /// - protected virtual void Dispose(bool disposing) - { - if (disposing) - { - foreach (IDisposable disposable in this._disposables) - { - disposable.Dispose(); - } - } + HttpClient = this._httpClientFactory.CreateClient("Plugin"), + IgnoreNonCompliantErrors = true, + AuthCallback = authenticationProvider.AuthenticateRequestAsync + }); + } + else + { + this._logger.LogWarning("Failed to find plugin {0}.", enabledPlugin); + } } - - /// - public void Dispose() + return; + } + + /// + /// Dispose of the object. + /// + protected virtual void Dispose(bool disposing) + { + if (disposing) { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - this.Dispose(disposing: true); - GC.SuppressFinalize(this); + foreach (IDisposable disposable in this._disposables) + { + disposable.Dispose(); + } } + } + + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } } diff --git a/webapi/CopilotChatWebApi.csproj b/webapi/CopilotChatWebApi.csproj index b2ad3f4b3..5c55b59ae 100644 --- a/webapi/CopilotChatWebApi.csproj +++ b/webapi/CopilotChatWebApi.csproj @@ -21,18 +21,18 @@ - - - - - - - - - - - - + + + + + + + + + + + + @@ -57,7 +57,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive @@ -67,7 +67,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/webapi/Services/SemanticKernelProvider.cs b/webapi/Services/SemanticKernelProvider.cs index 746227d2d..157928d81 100644 --- a/webapi/Services/SemanticKernelProvider.cs +++ b/webapi/Services/SemanticKernelProvider.cs @@ -24,194 +24,194 @@ namespace CopilotChat.WebApi.Services; /// public sealed class SemanticKernelProvider { - private static IMemoryStore? _volatileMemoryStore; - - private readonly KernelBuilder _builderChat; - private readonly KernelBuilder _builderPlanner; - private readonly MemoryBuilder _builderMemory; - - public SemanticKernelProvider(IServiceProvider serviceProvider, IConfiguration configuration, IHttpClientFactory httpClientFactory) - { - this._builderChat = InitializeCompletionKernel(serviceProvider, configuration, httpClientFactory); - this._builderPlanner = InitializePlannerKernel(serviceProvider, configuration, httpClientFactory); - this._builderMemory = InitializeMigrationMemory(serviceProvider, configuration, httpClientFactory); - } - - /// - /// Produce semantic-kernel with only completion services for chat. - /// - public IKernel GetCompletionKernel() => this._builderChat.Build(); - - /// - /// Produce semantic-kernel with only completion services for planner. - /// - public IKernel GetPlannerKernel() => this._builderPlanner.Build(); - - /// - /// Produce semantic-kernel with kernel memory. - /// - public ISemanticTextMemory GetMigrationMemory() => this._builderMemory.Build(); - - private static KernelBuilder InitializeCompletionKernel( - IServiceProvider serviceProvider, - IConfiguration configuration, - IHttpClientFactory httpClientFactory) + private static IMemoryStore? _volatileMemoryStore; + + private readonly KernelBuilder _builderChat; + private readonly KernelBuilder _builderPlanner; + private readonly MemoryBuilder _builderMemory; + + public SemanticKernelProvider(IServiceProvider serviceProvider, IConfiguration configuration, IHttpClientFactory httpClientFactory) + { + this._builderChat = InitializeCompletionKernel(serviceProvider, configuration, httpClientFactory); + this._builderPlanner = InitializePlannerKernel(serviceProvider, configuration, httpClientFactory); + this._builderMemory = InitializeMigrationMemory(serviceProvider, configuration, httpClientFactory); + } + + /// + /// Produce semantic-kernel with only completion services for chat. + /// + public IKernel GetCompletionKernel() => this._builderChat.Build(); + + /// + /// Produce semantic-kernel with only completion services for planner. + /// + public IKernel GetPlannerKernel() => this._builderPlanner.Build(); + + /// + /// Produce semantic-kernel with kernel memory. + /// + public ISemanticTextMemory GetMigrationMemory() => this._builderMemory.Build(); + + private static KernelBuilder InitializeCompletionKernel( + IServiceProvider serviceProvider, + IConfiguration configuration, + IHttpClientFactory httpClientFactory) + { + var builder = new KernelBuilder(); + + builder.WithLoggerFactory(serviceProvider.GetRequiredService()); + + var memoryOptions = serviceProvider.GetRequiredService>().Value; + + switch (memoryOptions.TextGeneratorType) { - var builder = new KernelBuilder(); - - builder.WithLoggerFactory(serviceProvider.GetRequiredService()); - - var memoryOptions = serviceProvider.GetRequiredService>().Value; - - switch (memoryOptions.TextGeneratorType) - { - case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): - case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): - var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIText"); + case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): + case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): + var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIText"); #pragma warning disable CA2000 // No need to dispose of HttpClient instances from IHttpClientFactory - builder.WithAzureOpenAIChatCompletionService( - azureAIOptions.Deployment, - azureAIOptions.Endpoint, - azureAIOptions.APIKey, - httpClient: httpClientFactory.CreateClient()); + builder.WithAzureChatCompletionService( + azureAIOptions.Deployment, + azureAIOptions.Endpoint, + azureAIOptions.APIKey, + httpClient: httpClientFactory.CreateClient()); #pragma warning restore CA2000 - break; + break; - case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): - var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); + case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): + var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); #pragma warning disable CA2000 // No need to dispose of HttpClient instances from IHttpClientFactory - builder.WithOpenAIChatCompletionService( - openAIOptions.TextModel, - openAIOptions.APIKey, - httpClient: httpClientFactory.CreateClient()); + builder.WithOpenAIChatCompletionService( + openAIOptions.TextModel, + openAIOptions.APIKey, + httpClient: httpClientFactory.CreateClient()); #pragma warning restore CA2000 - break; - - default: - throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'KernelMemory' settings."); - } + break; - return builder; + default: + throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'KernelMemory' settings."); } - private static KernelBuilder InitializePlannerKernel( - IServiceProvider serviceProvider, - IConfiguration configuration, - IHttpClientFactory httpClientFactory) - { - var builder = new KernelBuilder(); + return builder; + } + + private static KernelBuilder InitializePlannerKernel( + IServiceProvider serviceProvider, + IConfiguration configuration, + IHttpClientFactory httpClientFactory) + { + var builder = new KernelBuilder(); - builder.WithLoggerFactory(serviceProvider.GetRequiredService()); + builder.WithLoggerFactory(serviceProvider.GetRequiredService()); - var memoryOptions = serviceProvider.GetRequiredService>().Value; - var plannerOptions = serviceProvider.GetRequiredService>().Value; + var memoryOptions = serviceProvider.GetRequiredService>().Value; + var plannerOptions = serviceProvider.GetRequiredService>().Value; - switch (memoryOptions.TextGeneratorType) - { - case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): - case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): - var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIText"); + switch (memoryOptions.TextGeneratorType) + { + case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): + case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): + var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIText"); #pragma warning disable CA2000 // No need to dispose of HttpClient instances from IHttpClientFactory - builder.WithAzureOpenAIChatCompletionService( - plannerOptions.Model, - azureAIOptions.Endpoint, - azureAIOptions.APIKey, - httpClient: httpClientFactory.CreateClient()); + builder.WithAzureChatCompletionService( + plannerOptions.Model, + azureAIOptions.Endpoint, + azureAIOptions.APIKey, + httpClient: httpClientFactory.CreateClient()); #pragma warning restore CA2000 - break; + break; - case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): - var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); + case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): + var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); #pragma warning disable CA2000 // No need to dispose of HttpClient instances from IHttpClientFactory - builder.WithOpenAIChatCompletionService( - plannerOptions.Model, - openAIOptions.APIKey, - httpClient: httpClientFactory.CreateClient()); + builder.WithOpenAIChatCompletionService( + plannerOptions.Model, + openAIOptions.APIKey, + httpClient: httpClientFactory.CreateClient()); #pragma warning restore CA2000 - break; + break; - default: - throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'KernelMemory' settings."); - } - - return builder; + default: + throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'KernelMemory' settings."); } - private static MemoryBuilder InitializeMigrationMemory( - IServiceProvider serviceProvider, - IConfiguration configuration, - IHttpClientFactory httpClientFactory) - { - var memoryOptions = serviceProvider.GetRequiredService>().Value; + return builder; + } - var builder = new MemoryBuilder(); + private static MemoryBuilder InitializeMigrationMemory( + IServiceProvider serviceProvider, + IConfiguration configuration, + IHttpClientFactory httpClientFactory) + { + var memoryOptions = serviceProvider.GetRequiredService>().Value; - builder.WithLoggerFactory(serviceProvider.GetRequiredService()); - builder.WithMemoryStore(CreateMemoryStore()); + var builder = new MemoryBuilder(); - switch (memoryOptions.Retrieval.EmbeddingGeneratorType) - { - case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): - case string y when y.Equals("AzureOpenAIEmbedding", StringComparison.OrdinalIgnoreCase): - var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIEmbedding"); + builder.WithLoggerFactory(serviceProvider.GetRequiredService()); + builder.WithMemoryStore(CreateMemoryStore()); + + switch (memoryOptions.Retrieval.EmbeddingGeneratorType) + { + case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): + case string y when y.Equals("AzureOpenAIEmbedding", StringComparison.OrdinalIgnoreCase): + var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIEmbedding"); #pragma warning disable CA2000 // No need to dispose of HttpClient instances from IHttpClientFactory - builder.WithAzureOpenAITextEmbeddingGenerationService( - azureAIOptions.Deployment, - azureAIOptions.Endpoint, - azureAIOptions.APIKey, - httpClient: httpClientFactory.CreateClient()); + builder.WithAzureTextEmbeddingGenerationService( + azureAIOptions.Deployment, + azureAIOptions.Endpoint, + azureAIOptions.APIKey, + httpClient: httpClientFactory.CreateClient()); #pragma warning restore CA2000 - break; + break; - case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): - var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); + case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): + var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); #pragma warning disable CA2000 // No need to dispose of HttpClient instances from IHttpClientFactory - builder.WithOpenAITextEmbeddingGenerationService( - openAIOptions.EmbeddingModel, - openAIOptions.APIKey, - httpClient: httpClientFactory.CreateClient()); + builder.WithOpenAITextEmbeddingGenerationService( + openAIOptions.EmbeddingModel, + openAIOptions.APIKey, + httpClient: httpClientFactory.CreateClient()); #pragma warning restore CA2000 - break; + break; - default: - throw new ArgumentException($"Invalid {nameof(memoryOptions.Retrieval.EmbeddingGeneratorType)} value in 'KernelMemory' settings."); - } - return builder; + default: + throw new ArgumentException($"Invalid {nameof(memoryOptions.Retrieval.EmbeddingGeneratorType)} value in 'KernelMemory' settings."); + } + return builder; - IMemoryStore CreateMemoryStore() - { - switch (memoryOptions.Retrieval.VectorDbType) - { - case string x when x.Equals("SimpleVectorDb", StringComparison.OrdinalIgnoreCase): - // Maintain single instance of volatile memory. - Interlocked.CompareExchange(ref _volatileMemoryStore, new VolatileMemoryStore(), null); - return _volatileMemoryStore; + IMemoryStore CreateMemoryStore() + { + switch (memoryOptions.Retrieval.VectorDbType) + { + case string x when x.Equals("SimpleVectorDb", StringComparison.OrdinalIgnoreCase): + // Maintain single instance of volatile memory. + Interlocked.CompareExchange(ref _volatileMemoryStore, new VolatileMemoryStore(), null); + return _volatileMemoryStore; - case string x when x.Equals("Qdrant", StringComparison.OrdinalIgnoreCase): - var qdrantConfig = memoryOptions.GetServiceConfig(configuration, "Qdrant"); + case string x when x.Equals("Qdrant", StringComparison.OrdinalIgnoreCase): + var qdrantConfig = memoryOptions.GetServiceConfig(configuration, "Qdrant"); #pragma warning disable CA2000 // Ownership passed to QdrantMemoryStore - HttpClient httpClient = new(new HttpClientHandler { CheckCertificateRevocationList = true }); + HttpClient httpClient = new(new HttpClientHandler { CheckCertificateRevocationList = true }); #pragma warning restore CA2000 // Ownership passed to QdrantMemoryStore - if (!string.IsNullOrWhiteSpace(qdrantConfig.APIKey)) - { - httpClient.DefaultRequestHeaders.Add("api-key", qdrantConfig.APIKey); - } - - return - new QdrantMemoryStore( - httpClient: httpClient, - 1536, - qdrantConfig.Endpoint, - loggerFactory: serviceProvider.GetRequiredService()); - - case string x when x.Equals("AzureCognitiveSearch", StringComparison.OrdinalIgnoreCase): - var acsConfig = memoryOptions.GetServiceConfig(configuration, "AzureCognitiveSearch"); - return new AzureCognitiveSearchMemoryStore(acsConfig.Endpoint, acsConfig.APIKey); - - default: - throw new InvalidOperationException($"Invalid 'VectorDbType' type '{memoryOptions.Retrieval.VectorDbType}'."); - } - } + if (!string.IsNullOrWhiteSpace(qdrantConfig.APIKey)) + { + httpClient.DefaultRequestHeaders.Add("api-key", qdrantConfig.APIKey); + } + + return + new QdrantMemoryStore( + httpClient: httpClient, + 1536, + qdrantConfig.Endpoint, + loggerFactory: serviceProvider.GetRequiredService()); + + case string x when x.Equals("AzureCognitiveSearch", StringComparison.OrdinalIgnoreCase): + var acsConfig = memoryOptions.GetServiceConfig(configuration, "AzureCognitiveSearch"); + return new AzureCognitiveSearchMemoryStore(acsConfig.Endpoint, acsConfig.APIKey); + + default: + throw new InvalidOperationException($"Invalid 'VectorDbType' type '{memoryOptions.Retrieval.VectorDbType}'."); + } } + } }