Skip to content

Commit

Permalink
Add htttp handler to facilitate custom headers
Browse files Browse the repository at this point in the history
  • Loading branch information
calebkiage committed Jul 18, 2023
1 parent 12cf784 commit 20c9848
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 9 deletions.
41 changes: 41 additions & 0 deletions src/Microsoft.Graph.Cli.Core.Tests/Http/HeaderStoreTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using System.Linq;
using Microsoft.Graph.Cli.Core.Http;
using Xunit;

namespace Microsoft.Graph.Cli.Core.Tests.Http;

public class HeaderStoreTests
{
[Fact]
public void Stores_Single_Header()
{
var header = new [] {"sample=header"};
HeadersStore.Instance.SetHeadersFromStrings(header);

Assert.NotEmpty(HeadersStore.Instance.Headers);
Assert.Equal(HeadersStore.Instance.Headers.Count(), 1);
Assert.Equal(HeadersStore.Instance.Headers.First().Value.Count, 1);
}

[Fact]
public void Stores_Multiple_Headers()
{
var header = new [] {"sample=header", "sample2=header2", };
HeadersStore.Instance.SetHeadersFromStrings(header);

Assert.NotEmpty(HeadersStore.Instance.Headers);
Assert.Equal(HeadersStore.Instance.Headers.Count(), 2);
Assert.Equal(HeadersStore.Instance.Headers.First().Value.Count, 1);
}

[Fact]
public void Stores_Multiple_Headers_With_Matching_Key()
{
var header = new [] {"sample=header", "sample=header2", };
HeadersStore.Instance.SetHeadersFromStrings(header);

Assert.NotEmpty(HeadersStore.Instance.Headers);
Assert.Equal(HeadersStore.Instance.Headers.Count(), 1);
Assert.Equal(HeadersStore.Instance.Headers.First().Value.Count, 2);
}
}
68 changes: 68 additions & 0 deletions src/Microsoft.Graph.Cli.Core/Http/HttpHeadersHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Graph.Cli.Core.Http;

public class HttpHeadersHandler : DelegatingHandler
{
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
foreach (var headerItem in HeadersStore.Instance.Headers)
{
request.Headers.Add(headerItem.Key, headerItem.Value);
}
return base.SendAsync(request, cancellationToken);
}
}
public sealed class HeadersStore
{
private readonly Dictionary<string, List<string>> headers = new();
private static readonly Lazy<HeadersStore> lazy =
new Lazy<HeadersStore>(() => new HeadersStore());

public static HeadersStore Instance { get { return lazy.Value; } }

public IEnumerable<KeyValuePair<string, List<string>>> Headers { get { return headers; } }

private HeadersStore()
{
}

public void SetHeadersFromStrings(string[] headers)
{
this.headers.Clear();

var mapped = headers
.Select<string, (string, string)?>(static h =>
{
var split = h.Split('=', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
if (split.Length < 1)
{
return null;
}
var k = split[0];
var v = string.Empty;
if (split.Length > 1)
{
v = split[1];
}
return (k, v);
});
foreach (var kv in mapped)
{
if (kv is { } nonNull)
{
if (!this.headers.TryAdd(nonNull.Item1, new List<string> { nonNull.Item2 }))
{
this.headers[nonNull.Item1].Add(nonNull.Item2);
}
}
}
}
}
46 changes: 37 additions & 9 deletions src/sample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.IO;
using System.Net.Http;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using ApiSdk;
using Azure.Core.Diagnostics;
Expand Down Expand Up @@ -40,6 +41,8 @@ class Program

static async Task<int> Main(string[] args)
{
Console.InputEncoding = Encoding.Unicode;
Console.OutputEncoding = Encoding.Unicode;
var builder = BuildCommandLine()
.UseDefaults()
.UseHost(CreateHostBuilder)
Expand Down Expand Up @@ -102,18 +105,39 @@ static CommandLineBuilder BuildCommandLine()

var builder = new CommandLineBuilder(rootCommand);
var debugOption = new Option<bool>("--debug", "Enable debug output");
var headersOption = new Option<string[]>("--headers", "Add custom request headers to the request. Can be used multiple times to add many headers. e.g. --headers key1=value1 --headers key2=value2");
headersOption.Arity = ArgumentArity.ZeroOrMore;

builder.AddMiddleware(async (ic, next) =>
{
debugEnabled = ic.ParseResult.GetValueForOption<bool>(debugOption);
if (debugEnabled)
static void AddOptionToCommandIf(ref Command command, in Option option, Func<Command, bool> predicate) {
if (predicate(command)) {
command.AddOption(option);
}

foreach (var sym in command.Children)
{
listener = AzureEventSourceListener.CreateConsoleLogger(EventLevel.LogAlways);
if (sym is Command cmd)
{
AddOptionToCommandIf(ref cmd, option, predicate);
}
}
else
}

AddOptionToCommandIf(ref rootCommand, headersOption, cmd => cmd.Handler is not null);

builder.AddMiddleware(async (ic, next) =>
{
if (ic.ParseResult.GetValueForOption(headersOption) is { } options)
{
listener = AzureEventSourceListener.CreateConsoleLogger(EventLevel.Warning);
HeadersStore.Instance.SetHeadersFromStrings(options);
}
await next(ic);
});

builder.AddMiddleware(async (ic, next) =>
{
debugEnabled = ic.ParseResult.GetValueForOption<bool>(debugOption);
listener = AzureEventSourceListener.CreateConsoleLogger(debugEnabled ? EventLevel.LogAlways : EventLevel.Warning);
await next(ic);
});

Expand All @@ -136,6 +160,7 @@ static IHostBuilder CreateHostBuilder(string[] args) =>
var authSection = ctx.Configuration.GetSection(nameof(AuthenticationOptions));
services.Configure<AuthenticationOptions>(authSection);
services.AddTransient<LoggingHandler>();
services.AddTransient<HttpHeadersHandler>();
services.AddSingleton<HttpClient>(p =>
{
var assemblyVersion = Assembly.GetExecutingAssembly().GetName().Version;
Expand All @@ -145,8 +170,11 @@ static IHostBuilder CreateHostBuilder(string[] args) =>
GraphServiceLibraryClientVersion = $"{assemblyVersion?.Major ?? 0}.{assemblyVersion?.Minor ?? 0}.{assemblyVersion?.Build ?? 0}",
GraphServiceTargetVersion = "1.0"
};
return GraphCliClientFactory.GetDefaultClient(options, loggingHandler: p.GetRequiredService<LoggingHandler>());
return GraphCliClientFactory.GetDefaultClient(
options,
loggingHandler: p.GetRequiredService<LoggingHandler>(),
middlewares: new[] { p.GetRequiredService<HttpHeadersHandler>() }
);
});
services.AddSingleton<IAuthenticationProvider>(p =>
{
Expand Down

0 comments on commit 20c9848

Please sign in to comment.