diff --git a/webapi/Controllers/ChatController.cs b/webapi/Controllers/ChatController.cs index a008418d5..58851bec8 100644 --- a/webapi/Controllers/ChatController.cs +++ b/webapi/Controllers/ChatController.cs @@ -45,6 +45,7 @@ namespace CopilotChat.WebApi.Controllers; public class ChatController : ControllerBase, IDisposable { private readonly ILogger _logger; + private readonly IHttpClientFactory _httpClientFactory; private readonly List _disposables; private readonly ITelemetryService _telemetryService; private readonly ServiceOptions _serviceOptions; @@ -58,12 +59,14 @@ public class ChatController : ControllerBase, IDisposable public ChatController( ILogger logger, + IHttpClientFactory httpClientFactory, ITelemetryService telemetryService, IOptions serviceOptions, IOptions plannerOptions, IDictionary plugins) { this._logger = logger; + this._httpClientFactory = httpClientFactory; this._telemetryService = telemetryService; this._disposables = new List(); this._serviceOptions = serviceOptions.Value; @@ -335,15 +338,12 @@ await planner.Kernel.ImportAIPluginAsync( var requiresAuth = !plugin.AuthType.Equals("none", StringComparison.OrdinalIgnoreCase); BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(PluginAuthValue)); - HttpClient httpClient = new(); - httpClient.Timeout = TimeSpan.FromSeconds(this._plannerOptions.PluginTimeoutLimitInS); - await planner.Kernel.ImportAIPluginAsync( $"{plugin.NameForModel}Plugin", PluginUtils.GetPluginManifestUri(plugin.ManifestDomain), new OpenApiSkillExecutionParameters { - HttpClient = httpClient, + HttpClient = this._httpClientFactory.CreateClient("Plugin"), IgnoreNonCompliantErrors = true, AuthCallback = requiresAuth ? authenticationProvider.AuthenticateRequestAsync : null }); @@ -395,7 +395,7 @@ await planner.Kernel.ImportAIPluginAsync( PluginUtils.GetPluginManifestUri(plugin.ManifestDomain), new OpenApiSkillExecutionParameters { - HttpClient = new HttpClient(), + HttpClient = this._httpClientFactory.CreateClient("Plugin"), IgnoreNonCompliantErrors = true, AuthCallback = authenticationProvider.AuthenticateRequestAsync }); diff --git a/webapi/Controllers/PluginController.cs b/webapi/Controllers/PluginController.cs index 18dd77765..9c5d3a802 100644 --- a/webapi/Controllers/PluginController.cs +++ b/webapi/Controllers/PluginController.cs @@ -27,15 +27,18 @@ public class PluginController : ControllerBase { private const string PluginStateChanged = "PluginStateChanged"; private readonly ILogger _logger; + private readonly IHttpClientFactory _httpClientFactory; private readonly IDictionary _availablePlugins; private readonly ChatSessionRepository _sessionRepository; public PluginController( ILogger logger, + IHttpClientFactory httpClientFactory, IDictionary availablePlugins, ChatSessionRepository sessionRepository) { this._logger = logger; + this._httpClientFactory = httpClientFactory; this._availablePlugins = availablePlugins; this._sessionRepository = sessionRepository; } @@ -54,7 +57,7 @@ public async Task GetPluginManifest([FromQuery] Uri manifestDomai // Need to set the user agent to avoid 403s from some sites. request.Headers.Add("User-Agent", Telemetry.HttpUserAgent); - using HttpClient client = new(); + using HttpClient client = this._httpClientFactory.CreateClient("Plugin"); var response = await client.SendAsync(request); if (!response.IsSuccessStatusCode) { diff --git a/webapi/Controllers/SpeechTokenController.cs b/webapi/Controllers/SpeechTokenController.cs index 0ad82ef75..787f8a572 100644 --- a/webapi/Controllers/SpeechTokenController.cs +++ b/webapi/Controllers/SpeechTokenController.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Net; using System.Net.Http; using System.Threading.Tasks; @@ -23,11 +22,15 @@ private sealed class TokenResult } private readonly ILogger _logger; + private readonly IHttpClientFactory _httpClientFactory; private readonly AzureSpeechOptions _options; - public SpeechTokenController(IOptions options, ILogger logger) + public SpeechTokenController(IOptions options, + ILogger logger, + IHttpClientFactory httpClientFactory) { this._logger = logger; + this._httpClientFactory = httpClientFactory; this._options = options.Value; } @@ -55,16 +58,16 @@ public async Task> GetAsync() private async Task FetchTokenAsync(string fetchUri, string subscriptionKey) { - // TODO: get the HttpClient from the DI container - using var client = new HttpClient(); - client.DefaultRequestHeaders.Add("Ocp-Apim-Subscription-Key", subscriptionKey); - UriBuilder uriBuilder = new(fetchUri); + using var client = this._httpClientFactory.CreateClient(); - var result = await client.PostAsync(uriBuilder.Uri, null); + using var request = new HttpRequestMessage(HttpMethod.Post, fetchUri); + request.Headers.Add("Ocp-Apim-Subscription-Key", subscriptionKey); + + var result = await client.SendAsync(request); if (result.IsSuccessStatusCode) { var response = result.EnsureSuccessStatusCode(); - this._logger.LogDebug("Token Uri: {0}", uriBuilder.Uri.AbsoluteUri); + this._logger.LogDebug("Token Uri: {0}", fetchUri); string token = await result.Content.ReadAsStringAsync(); return new TokenResult { Token = token, ResponseCode = response.StatusCode }; } diff --git a/webapi/Extensions/SemanticKernelExtensions.cs b/webapi/Extensions/SemanticKernelExtensions.cs index 63c2e2d7d..4bcc61f8c 100644 --- a/webapi/Extensions/SemanticKernelExtensions.cs +++ b/webapi/Extensions/SemanticKernelExtensions.cs @@ -3,6 +3,7 @@ using System; using System.IO; using System.Linq; +using System.Net.Http; using System.Reflection; using System.Threading.Tasks; using CopilotChat.WebApi.Hubs; @@ -169,7 +170,7 @@ public static IKernel RegisterChatSkill(this IKernel kernel, IServiceProvider sp private static void InitializeKernelProvider(this WebApplicationBuilder builder) { - builder.Services.AddSingleton(sp => new SemanticKernelProvider(sp, builder.Configuration)); + builder.Services.AddSingleton(sp => new SemanticKernelProvider(sp, builder.Configuration, sp.GetRequiredService())); } /// diff --git a/webapi/Program.cs b/webapi/Program.cs index 0d2406ca7..ff8d001aa 100644 --- a/webapi/Program.cs +++ b/webapi/Program.cs @@ -2,6 +2,7 @@ using System; using System.Diagnostics; +using System.Globalization; using System.Linq; using System.Text.Json; using System.Threading.Tasks; @@ -68,6 +69,14 @@ public static async Task Main(string[] args) TelemetryDebugWriter.IsTracingDisabled = Debugger.IsAttached; + // Add named HTTP clients for IHttpClientFactory + builder.Services.AddHttpClient(); + builder.Services.AddHttpClient("Plugin", httpClient => + { + int timeout = int.Parse(builder.Configuration["Planner:PluginTimeoutLimitInS"] ?? "100", CultureInfo.InvariantCulture); + httpClient.Timeout = TimeSpan.FromSeconds(timeout); + }); + // Add in the rest of the services. builder.Services .AddMaintenanceServices() diff --git a/webapi/Services/SemanticKernelProvider.cs b/webapi/Services/SemanticKernelProvider.cs index 0bbb39737..aafaa9e3b 100644 --- a/webapi/Services/SemanticKernelProvider.cs +++ b/webapi/Services/SemanticKernelProvider.cs @@ -27,11 +27,13 @@ public sealed class SemanticKernelProvider private readonly IServiceProvider _serviceProvider; private readonly IConfiguration _configuration; + private readonly IHttpClientFactory _httpClientFactory; - public SemanticKernelProvider(IServiceProvider serviceProvider, IConfiguration configuration) + public SemanticKernelProvider(IServiceProvider serviceProvider, IConfiguration configuration, IHttpClientFactory httpClientFactory) { this._serviceProvider = serviceProvider; this._configuration = configuration; + this._httpClientFactory = httpClientFactory; } /// @@ -83,11 +85,15 @@ private KernelBuilder WithCompletionBackend(KernelBuilder kernelBuilder) case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): var azureAIOptions = memoryOptions.GetServiceConfig(this._configuration, "AzureOpenAIText"); - return kernelBuilder.WithAzureChatCompletionService(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey); +#pragma warning disable CA2000 // Dispose objects before losing scope - No need to dispose of HttpClient instances from IHttpClientFactory + return kernelBuilder.WithAzureChatCompletionService(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey, + httpClient: this._httpClientFactory.CreateClient()); case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): var openAIOptions = memoryOptions.GetServiceConfig(this._configuration, "OpenAI"); - return kernelBuilder.WithOpenAIChatCompletionService(openAIOptions.TextModel, openAIOptions.APIKey); + return kernelBuilder.WithOpenAIChatCompletionService(openAIOptions.TextModel, openAIOptions.APIKey, + httpClient: this._httpClientFactory.CreateClient()); +#pragma warning restore CA2000 // Dispose objects before losing scope default: throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'SemanticMemory' settings."); @@ -107,12 +113,15 @@ private KernelBuilder WithPlannerBackend(KernelBuilder kernelBuilder) case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): var azureAIOptions = memoryOptions.GetServiceConfig(this._configuration, "AzureOpenAIText"); - return kernelBuilder.WithAzureChatCompletionService(plannerOptions.Model, azureAIOptions.Endpoint, azureAIOptions.APIKey); +#pragma warning disable CA2000 // Dispose objects before losing scope - No need to dispose of HttpClient instances from IHttpClientFactory + return kernelBuilder.WithAzureChatCompletionService(plannerOptions.Model, azureAIOptions.Endpoint, azureAIOptions.APIKey, + httpClient: this._httpClientFactory.CreateClient()); case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): var openAIOptions = memoryOptions.GetServiceConfig(this._configuration, "OpenAI"); - return kernelBuilder.WithOpenAIChatCompletionService(plannerOptions.Model, openAIOptions.APIKey); - + return kernelBuilder.WithOpenAIChatCompletionService(plannerOptions.Model, openAIOptions.APIKey, + httpClient: this._httpClientFactory.CreateClient()); +#pragma warning restore CA2000 // Dispose objects before losing scope default: throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'SemanticMemory' settings."); } @@ -130,12 +139,15 @@ private KernelBuilder WithEmbeddingBackend(KernelBuilder kernelBuilder) case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): case string y when y.Equals("AzureOpenAIEmbedding", StringComparison.OrdinalIgnoreCase): var azureAIOptions = memoryOptions.GetServiceConfig(this._configuration, "AzureOpenAIEmbedding"); - return kernelBuilder.WithAzureTextEmbeddingGenerationService(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey); +#pragma warning disable CA2000 // Dispose objects before losing scope - No need to dispose of HttpClient instances from IHttpClientFactory + return kernelBuilder.WithAzureTextEmbeddingGenerationService(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey, + httpClient: this._httpClientFactory.CreateClient()); case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): var openAIOptions = memoryOptions.GetServiceConfig(this._configuration, "OpenAI"); - return kernelBuilder.WithOpenAITextEmbeddingGenerationService(openAIOptions.EmbeddingModel, openAIOptions.APIKey); - + return kernelBuilder.WithOpenAITextEmbeddingGenerationService(openAIOptions.EmbeddingModel, openAIOptions.APIKey, + httpClient: this._httpClientFactory.CreateClient()); +#pragma warning restore CA2000 // Dispose objects before losing scope default: throw new ArgumentException($"Invalid {nameof(memoryOptions.Retrieval.EmbeddingGeneratorType)} value in 'SemanticMemory' settings."); }