From 20c9848e91cf51f348117d4fe5d8279cd774e78b Mon Sep 17 00:00:00 2001 From: Caleb Kiage <747955+calebkiage@users.noreply.github.com> Date: Tue, 18 Jul 2023 09:17:34 +0300 Subject: [PATCH] Add htttp handler to facilitate custom headers --- .../Http/HeaderStoreTests.cs | 41 +++++++++++ .../Http/HttpHeadersHandler.cs | 68 +++++++++++++++++++ src/sample/Program.cs | 46 ++++++++++--- 3 files changed, 146 insertions(+), 9 deletions(-) create mode 100644 src/Microsoft.Graph.Cli.Core.Tests/Http/HeaderStoreTests.cs create mode 100644 src/Microsoft.Graph.Cli.Core/Http/HttpHeadersHandler.cs diff --git a/src/Microsoft.Graph.Cli.Core.Tests/Http/HeaderStoreTests.cs b/src/Microsoft.Graph.Cli.Core.Tests/Http/HeaderStoreTests.cs new file mode 100644 index 00000000..9ef52d6c --- /dev/null +++ b/src/Microsoft.Graph.Cli.Core.Tests/Http/HeaderStoreTests.cs @@ -0,0 +1,41 @@ +using System.Linq; +using Microsoft.Graph.Cli.Core.Http; +using Xunit; + +namespace Microsoft.Graph.Cli.Core.Tests.Http; + +public class HeaderStoreTests +{ + [Fact] + public void Stores_Single_Header() + { + var header = new [] {"sample=header"}; + HeadersStore.Instance.SetHeadersFromStrings(header); + + Assert.NotEmpty(HeadersStore.Instance.Headers); + Assert.Equal(HeadersStore.Instance.Headers.Count(), 1); + Assert.Equal(HeadersStore.Instance.Headers.First().Value.Count, 1); + } + + [Fact] + public void Stores_Multiple_Headers() + { + var header = new [] {"sample=header", "sample2=header2", }; + HeadersStore.Instance.SetHeadersFromStrings(header); + + Assert.NotEmpty(HeadersStore.Instance.Headers); + Assert.Equal(HeadersStore.Instance.Headers.Count(), 2); + Assert.Equal(HeadersStore.Instance.Headers.First().Value.Count, 1); + } + + [Fact] + public void Stores_Multiple_Headers_With_Matching_Key() + { + var header = new [] {"sample=header", "sample=header2", }; + HeadersStore.Instance.SetHeadersFromStrings(header); + + Assert.NotEmpty(HeadersStore.Instance.Headers); + Assert.Equal(HeadersStore.Instance.Headers.Count(), 1); + Assert.Equal(HeadersStore.Instance.Headers.First().Value.Count, 2); + } +} diff --git a/src/Microsoft.Graph.Cli.Core/Http/HttpHeadersHandler.cs b/src/Microsoft.Graph.Cli.Core/Http/HttpHeadersHandler.cs new file mode 100644 index 00000000..600ef5fa --- /dev/null +++ b/src/Microsoft.Graph.Cli.Core/Http/HttpHeadersHandler.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Graph.Cli.Core.Http; + +public class HttpHeadersHandler : DelegatingHandler +{ + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + foreach (var headerItem in HeadersStore.Instance.Headers) + { + request.Headers.Add(headerItem.Key, headerItem.Value); + } + return base.SendAsync(request, cancellationToken); + } +} +public sealed class HeadersStore +{ + private readonly Dictionary> headers = new(); + private static readonly Lazy lazy = + new Lazy(() => new HeadersStore()); + + public static HeadersStore Instance { get { return lazy.Value; } } + + public IEnumerable>> Headers { get { return headers; } } + + private HeadersStore() + { + } + + public void SetHeadersFromStrings(string[] headers) + { + this.headers.Clear(); + + var mapped = headers + .Select(static h => + { + var split = h.Split('=', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + if (split.Length < 1) + { + return null; + } + + var k = split[0]; + var v = string.Empty; + if (split.Length > 1) + { + v = split[1]; + } + + return (k, v); + }); + foreach (var kv in mapped) + { + if (kv is { } nonNull) + { + if (!this.headers.TryAdd(nonNull.Item1, new List { nonNull.Item2 })) + { + this.headers[nonNull.Item1].Add(nonNull.Item2); + } + } + } + } +} diff --git a/src/sample/Program.cs b/src/sample/Program.cs index b642d6e9..8870da02 100644 --- a/src/sample/Program.cs +++ b/src/sample/Program.cs @@ -8,6 +8,7 @@ using System.IO; using System.Net.Http; using System.Reflection; +using System.Text; using System.Threading.Tasks; using ApiSdk; using Azure.Core.Diagnostics; @@ -40,6 +41,8 @@ class Program static async Task Main(string[] args) { + Console.InputEncoding = Encoding.Unicode; + Console.OutputEncoding = Encoding.Unicode; var builder = BuildCommandLine() .UseDefaults() .UseHost(CreateHostBuilder) @@ -102,18 +105,39 @@ static CommandLineBuilder BuildCommandLine() var builder = new CommandLineBuilder(rootCommand); var debugOption = new Option("--debug", "Enable debug output"); + var headersOption = new Option("--headers", "Add custom request headers to the request. Can be used multiple times to add many headers. e.g. --headers key1=value1 --headers key2=value2"); + headersOption.Arity = ArgumentArity.ZeroOrMore; - builder.AddMiddleware(async (ic, next) => - { - debugEnabled = ic.ParseResult.GetValueForOption(debugOption); - if (debugEnabled) + static void AddOptionToCommandIf(ref Command command, in Option option, Func predicate) { + if (predicate(command)) { + command.AddOption(option); + } + + foreach (var sym in command.Children) { - listener = AzureEventSourceListener.CreateConsoleLogger(EventLevel.LogAlways); + if (sym is Command cmd) + { + AddOptionToCommandIf(ref cmd, option, predicate); + } } - else + } + + AddOptionToCommandIf(ref rootCommand, headersOption, cmd => cmd.Handler is not null); + + builder.AddMiddleware(async (ic, next) => + { + if (ic.ParseResult.GetValueForOption(headersOption) is { } options) { - listener = AzureEventSourceListener.CreateConsoleLogger(EventLevel.Warning); + HeadersStore.Instance.SetHeadersFromStrings(options); } + + await next(ic); + }); + + builder.AddMiddleware(async (ic, next) => + { + debugEnabled = ic.ParseResult.GetValueForOption(debugOption); + listener = AzureEventSourceListener.CreateConsoleLogger(debugEnabled ? EventLevel.LogAlways : EventLevel.Warning); await next(ic); }); @@ -136,6 +160,7 @@ static IHostBuilder CreateHostBuilder(string[] args) => var authSection = ctx.Configuration.GetSection(nameof(AuthenticationOptions)); services.Configure(authSection); services.AddTransient(); + services.AddTransient(); services.AddSingleton(p => { var assemblyVersion = Assembly.GetExecutingAssembly().GetName().Version; @@ -145,8 +170,11 @@ static IHostBuilder CreateHostBuilder(string[] args) => GraphServiceLibraryClientVersion = $"{assemblyVersion?.Major ?? 0}.{assemblyVersion?.Minor ?? 0}.{assemblyVersion?.Build ?? 0}", GraphServiceTargetVersion = "1.0" }; - - return GraphCliClientFactory.GetDefaultClient(options, loggingHandler: p.GetRequiredService()); + return GraphCliClientFactory.GetDefaultClient( + options, + loggingHandler: p.GetRequiredService(), + middlewares: new[] { p.GetRequiredService() } + ); }); services.AddSingleton(p => {