From e5e47ccc9477ca4593010c777cd50925457c4fb6 Mon Sep 17 00:00:00 2001 From: nenoNaninu Date: Wed, 25 Sep 2024 12:24:34 +0900 Subject: [PATCH] support CancellationToken in the last parameter of a receive method --- src/TypedSignalR.Client/SourceGenerator.cs | 2 +- .../HubConnectionExtensionsBinderTemplate.cs | 6 ++- .../HubConnectionExtensionsTemplate.cs | 47 +++++++++++++++++++ .../Templates/MethodMetadataExtensions.cs | 40 ++++++++++++---- 4 files changed, 82 insertions(+), 13 deletions(-) diff --git a/src/TypedSignalR.Client/SourceGenerator.cs b/src/TypedSignalR.Client/SourceGenerator.cs index 014c4f6..fc90e63 100644 --- a/src/TypedSignalR.Client/SourceGenerator.cs +++ b/src/TypedSignalR.Client/SourceGenerator.cs @@ -208,7 +208,7 @@ private static void GenerateBinderSource(SourceProductionContext context, (Immut { var receiverTypes = ExtractReceiverTypesFromRegisterMethods(context, sourceSymbols, specialSymbols); - var template = new HubConnectionExtensionsBinderTemplate(receiverTypes); + var template = new HubConnectionExtensionsBinderTemplate(receiverTypes, specialSymbols); var source = NormalizeNewLines(template.TransformText()); diff --git a/src/TypedSignalR.Client/Templates/HubConnectionExtensionsBinderTemplate.cs b/src/TypedSignalR.Client/Templates/HubConnectionExtensionsBinderTemplate.cs index c7ba6b4..c037610 100644 --- a/src/TypedSignalR.Client/Templates/HubConnectionExtensionsBinderTemplate.cs +++ b/src/TypedSignalR.Client/Templates/HubConnectionExtensionsBinderTemplate.cs @@ -7,10 +7,12 @@ namespace TypedSignalR.Client.Templates; public sealed class HubConnectionExtensionsBinderTemplate { private readonly IReadOnlyList _receiverTypes; + private readonly SpecialSymbols _specialSymbols; - public HubConnectionExtensionsBinderTemplate(IReadOnlyList receiverTypes) + public HubConnectionExtensionsBinderTemplate(IReadOnlyList receiverTypes, SpecialSymbols specialSymbols) { _receiverTypes = receiverTypes; + _specialSymbols = specialSymbols; } public string TransformText() @@ -104,7 +106,7 @@ private string CreateRegistrationString(TypeMetadata receiverType) private string CreateRegistrationStringCore(MethodMetadata method) { return $$""" - compositeDisposable.Add(global::Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(connection, nameof(receiver.{{method.MethodName}}), {{method.CreateParameterTypeArrayString()}}, HandlerConverter.Convert{{method.CreateTypeArgumentsStringForHandlerConverter()}}(receiver.{{method.MethodName}}))); + compositeDisposable.Add(global::Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(connection, nameof(receiver.{{method.MethodName}}), {{method.CreateParameterTypeArrayString(_specialSymbols)}}, HandlerConverter.Convert{{method.CreateTypeArgumentsStringForHandlerConverter(_specialSymbols)}}(receiver.{{method.MethodName}}))); """; } } diff --git a/src/TypedSignalR.Client/Templates/HubConnectionExtensionsTemplate.cs b/src/TypedSignalR.Client/Templates/HubConnectionExtensionsTemplate.cs index 5b6ceaa..e72cf60 100644 --- a/src/TypedSignalR.Client/Templates/HubConnectionExtensionsTemplate.cs +++ b/src/TypedSignalR.Client/Templates/HubConnectionExtensionsTemplate.cs @@ -261,6 +261,25 @@ private static class HandlerConverter """); } + stringBuilder.AppendLine(""" + + public static global::System.Func Convert(global::System.Func handler) + { + return args => handler(default); + } +"""); + + for (int i = 1; i <= 15; i++) + { + stringBuilder.AppendLine($$""" + + public static global::System.Func Convert<{{CreateTypeParametersString(i)}}>(global::System.Func<{{CreateTypeParametersString(i)}}, global::System.Threading.CancellationToken, global::System.Threading.Tasks.Task> handler) + { + return args => handler({{CreateHandlerArgumentsString(i)}}, default); + } +"""); + } + stringBuilder.AppendLine(""" public static global::System.Func> Convert(global::System.Func> handler) @@ -272,6 +291,7 @@ private static class HandlerConverter }; } """); + for (int i = 1; i <= 16; i++) { stringBuilder.AppendLine($$""" @@ -284,6 +304,33 @@ private static class HandlerConverter return result; }; } +"""); + } + + stringBuilder.AppendLine(""" + + public static global::System.Func> Convert(global::System.Func> handler) + { + return async args => + { + var result = await handler(default).ConfigureAwait(false); + return result; + }; + } +"""); + + for (int i = 1; i <= 15; i++) + { + stringBuilder.AppendLine($$""" + + public static global::System.Func> Convert<{{CreateTypeParametersString(i)}}, TResult>(global::System.Func<{{CreateTypeParametersString(i)}}, global::System.Threading.CancellationToken, global::System.Threading.Tasks.Task> handler) + { + return async args => + { + var result = await handler({{CreateHandlerArgumentsString(i)}}, default).ConfigureAwait(false); + return result; + }; + } """); } } diff --git a/src/TypedSignalR.Client/Templates/MethodMetadataExtensions.cs b/src/TypedSignalR.Client/Templates/MethodMetadataExtensions.cs index 634ec05..404fc02 100644 --- a/src/TypedSignalR.Client/Templates/MethodMetadataExtensions.cs +++ b/src/TypedSignalR.Client/Templates/MethodMetadataExtensions.cs @@ -86,9 +86,13 @@ public static string CreateArgumentsStringExceptCancellationToken(this MethodMet return sb.ToString(); } - public static string CreateTypeArgumentsStringForHandlerConverter(this MethodMetadata methodMetadata) + public static string CreateTypeArgumentsStringForHandlerConverter(this MethodMetadata methodMetadata, SpecialSymbols specialSymbols) { - if (methodMetadata.Parameters.Count == 0) + var parameters = methodMetadata.LastParameterEqual(specialSymbols.CancellationTokenSymbol) + ? methodMetadata.Parameters.Take(methodMetadata.Parameters.Count - 1).ToArray() + : methodMetadata.Parameters; + + if (parameters.Count == 0) { if (methodMetadata.IsGenericReturnType) { @@ -98,25 +102,25 @@ public static string CreateTypeArgumentsStringForHandlerConverter(this MethodMet return string.Empty; } - if (methodMetadata.Parameters.Count == 1) + if (parameters.Count == 1) { if (methodMetadata.IsGenericReturnType) { - return $"<{methodMetadata.Parameters[0].TypeName}, {methodMetadata.GenericReturnTypeArgument}>"; + return $"<{parameters[0].TypeName}, {methodMetadata.GenericReturnTypeArgument}>"; } - return $"<{methodMetadata.Parameters[0].TypeName}>"; + return $"<{parameters[0].TypeName}>"; } var sb = new StringBuilder(); sb.Append('<'); - sb.Append(methodMetadata.Parameters[0].TypeName); + sb.Append(parameters[0].TypeName); - for (int i = 1; i < methodMetadata.Parameters.Count; i++) + for (int i = 1; i < parameters.Count; i++) { sb.Append(", "); - sb.Append(methodMetadata.Parameters[i].TypeName); + sb.Append(parameters[i].TypeName); } if (methodMetadata.IsGenericReturnType) @@ -129,9 +133,10 @@ public static string CreateTypeArgumentsStringForHandlerConverter(this MethodMet return sb.ToString(); } - public static string CreateParameterTypeArrayString(this MethodMetadata methodMetadata) + public static string CreateParameterTypeArrayString(this MethodMetadata methodMetadata, SpecialSymbols specialSymbols) { - if (methodMetadata.Parameters.Count == 0) + if (methodMetadata.Parameters.Count == 0 + || (methodMetadata.Parameters.Count == 1 && methodMetadata.LastParameterEqual(specialSymbols.CancellationTokenSymbol))) { return "global::System.Type.EmptyTypes"; } @@ -148,6 +153,13 @@ public static string CreateParameterTypeArrayString(this MethodMetadata methodMe for (int i = 1; i < methodMetadata.Parameters.Count; i++) { + if (IsLast(i, methodMetadata.Parameters.Count) + && methodMetadata.LastParameterEqual(specialSymbols.CancellationTokenSymbol)) + { + // break if the last parameter is CancellationToken. + break; + } + sb.Append($", typeof({methodMetadata.Parameters[i].FullyQualifiedTypeName})"); } @@ -189,6 +201,14 @@ public static string CreateMethodString(this MethodMetadata methodMetadata, Spec }; } + private static bool IsLast(int index, int length) => index == length - 1; + + private static bool LastParameterEqual(this MethodMetadata methodMetadata, ITypeSymbol typeSymbol) + { + return methodMetadata.Parameters.Count > 0 + && SymbolEqualityComparer.Default.Equals(methodMetadata.Parameters[methodMetadata.Parameters.Count - 1].Type, typeSymbol); + } + private static HubMethodType GetHubMethodType(this MethodMetadata methodMetadata, SpecialSymbols specialSymbols) { var returnTypeSymbol = methodMetadata.ReturnTypeSymbol as INamedTypeSymbol;