Skip to content

Commit

Permalink
support CancellationToken in the last parameter of a receive method
Browse files Browse the repository at this point in the history
  • Loading branch information
nenoNaninu committed Sep 25, 2024
1 parent 039ea72 commit e5e47cc
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/TypedSignalR.Client/SourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ private static void GenerateBinderSource(SourceProductionContext context, (Immut
{
var receiverTypes = ExtractReceiverTypesFromRegisterMethods(context, sourceSymbols, specialSymbols);

var template = new HubConnectionExtensionsBinderTemplate(receiverTypes);
var template = new HubConnectionExtensionsBinderTemplate(receiverTypes, specialSymbols);

var source = NormalizeNewLines(template.TransformText());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ namespace TypedSignalR.Client.Templates;
public sealed class HubConnectionExtensionsBinderTemplate
{
private readonly IReadOnlyList<TypeMetadata> _receiverTypes;
private readonly SpecialSymbols _specialSymbols;

public HubConnectionExtensionsBinderTemplate(IReadOnlyList<TypeMetadata> receiverTypes)
public HubConnectionExtensionsBinderTemplate(IReadOnlyList<TypeMetadata> receiverTypes, SpecialSymbols specialSymbols)
{
_receiverTypes = receiverTypes;
_specialSymbols = specialSymbols;
}

public string TransformText()
Expand Down Expand Up @@ -104,7 +106,7 @@ private string CreateRegistrationString(TypeMetadata receiverType)
private string CreateRegistrationStringCore(MethodMetadata method)
{
return $$"""
compositeDisposable.Add(global::Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(connection, nameof(receiver.{{method.MethodName}}), {{method.CreateParameterTypeArrayString()}}, HandlerConverter.Convert{{method.CreateTypeArgumentsStringForHandlerConverter()}}(receiver.{{method.MethodName}})));
compositeDisposable.Add(global::Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(connection, nameof(receiver.{{method.MethodName}}), {{method.CreateParameterTypeArrayString(_specialSymbols)}}, HandlerConverter.Convert{{method.CreateTypeArgumentsStringForHandlerConverter(_specialSymbols)}}(receiver.{{method.MethodName}})));
""";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,25 @@ private static class HandlerConverter
""");
}

stringBuilder.AppendLine("""

public static global::System.Func<object?[], global::System.Threading.Tasks.Task> Convert(global::System.Func<global::System.Threading.CancellationToken, global::System.Threading.Tasks.Task> handler)
{
return args => handler(default);
}
""");

for (int i = 1; i <= 15; i++)
{
stringBuilder.AppendLine($$"""

public static global::System.Func<object?[], global::System.Threading.Tasks.Task> Convert<{{CreateTypeParametersString(i)}}>(global::System.Func<{{CreateTypeParametersString(i)}}, global::System.Threading.CancellationToken, global::System.Threading.Tasks.Task> handler)
{
return args => handler({{CreateHandlerArgumentsString(i)}}, default);
}
""");
}

stringBuilder.AppendLine("""

public static global::System.Func<object?[], global::System.Threading.Tasks.Task<TResult>> Convert<TResult>(global::System.Func<global::System.Threading.Tasks.Task<TResult>> handler)
Expand All @@ -272,6 +291,7 @@ private static class HandlerConverter
};
}
""");

for (int i = 1; i <= 16; i++)
{
stringBuilder.AppendLine($$"""
Expand All @@ -284,6 +304,33 @@ private static class HandlerConverter
return result;
};
}
""");
}

stringBuilder.AppendLine("""

public static global::System.Func<object?[], global::System.Threading.Tasks.Task<TResult>> Convert<TResult>(global::System.Func<global::System.Threading.CancellationToken, global::System.Threading.Tasks.Task<TResult>> handler)
{
return async args =>
{
var result = await handler(default).ConfigureAwait(false);
return result;
};
}
""");

for (int i = 1; i <= 15; i++)
{
stringBuilder.AppendLine($$"""

public static global::System.Func<object?[], global::System.Threading.Tasks.Task<TResult>> Convert<{{CreateTypeParametersString(i)}}, TResult>(global::System.Func<{{CreateTypeParametersString(i)}}, global::System.Threading.CancellationToken, global::System.Threading.Tasks.Task<TResult>> handler)
{
return async args =>
{
var result = await handler({{CreateHandlerArgumentsString(i)}}, default).ConfigureAwait(false);
return result;
};
}
""");
}
}
Expand Down
40 changes: 30 additions & 10 deletions src/TypedSignalR.Client/Templates/MethodMetadataExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,13 @@ public static string CreateArgumentsStringExceptCancellationToken(this MethodMet
return sb.ToString();
}

public static string CreateTypeArgumentsStringForHandlerConverter(this MethodMetadata methodMetadata)
public static string CreateTypeArgumentsStringForHandlerConverter(this MethodMetadata methodMetadata, SpecialSymbols specialSymbols)
{
if (methodMetadata.Parameters.Count == 0)
var parameters = methodMetadata.LastParameterEqual(specialSymbols.CancellationTokenSymbol)
? methodMetadata.Parameters.Take(methodMetadata.Parameters.Count - 1).ToArray()
: methodMetadata.Parameters;

if (parameters.Count == 0)
{
if (methodMetadata.IsGenericReturnType)
{
Expand All @@ -98,25 +102,25 @@ public static string CreateTypeArgumentsStringForHandlerConverter(this MethodMet
return string.Empty;
}

if (methodMetadata.Parameters.Count == 1)
if (parameters.Count == 1)
{
if (methodMetadata.IsGenericReturnType)
{
return $"<{methodMetadata.Parameters[0].TypeName}, {methodMetadata.GenericReturnTypeArgument}>";
return $"<{parameters[0].TypeName}, {methodMetadata.GenericReturnTypeArgument}>";
}

return $"<{methodMetadata.Parameters[0].TypeName}>";
return $"<{parameters[0].TypeName}>";
}

var sb = new StringBuilder();

sb.Append('<');
sb.Append(methodMetadata.Parameters[0].TypeName);
sb.Append(parameters[0].TypeName);

for (int i = 1; i < methodMetadata.Parameters.Count; i++)
for (int i = 1; i < parameters.Count; i++)
{
sb.Append(", ");
sb.Append(methodMetadata.Parameters[i].TypeName);
sb.Append(parameters[i].TypeName);
}

if (methodMetadata.IsGenericReturnType)
Expand All @@ -129,9 +133,10 @@ public static string CreateTypeArgumentsStringForHandlerConverter(this MethodMet
return sb.ToString();
}

public static string CreateParameterTypeArrayString(this MethodMetadata methodMetadata)
public static string CreateParameterTypeArrayString(this MethodMetadata methodMetadata, SpecialSymbols specialSymbols)
{
if (methodMetadata.Parameters.Count == 0)
if (methodMetadata.Parameters.Count == 0
|| (methodMetadata.Parameters.Count == 1 && methodMetadata.LastParameterEqual(specialSymbols.CancellationTokenSymbol)))
{
return "global::System.Type.EmptyTypes";
}
Expand All @@ -148,6 +153,13 @@ public static string CreateParameterTypeArrayString(this MethodMetadata methodMe

for (int i = 1; i < methodMetadata.Parameters.Count; i++)
{
if (IsLast(i, methodMetadata.Parameters.Count)
&& methodMetadata.LastParameterEqual(specialSymbols.CancellationTokenSymbol))
{
// break if the last parameter is CancellationToken.
break;
}

sb.Append($", typeof({methodMetadata.Parameters[i].FullyQualifiedTypeName})");
}

Expand Down Expand Up @@ -189,6 +201,14 @@ public static string CreateMethodString(this MethodMetadata methodMetadata, Spec
};
}

private static bool IsLast(int index, int length) => index == length - 1;

private static bool LastParameterEqual(this MethodMetadata methodMetadata, ITypeSymbol typeSymbol)
{
return methodMetadata.Parameters.Count > 0
&& SymbolEqualityComparer.Default.Equals(methodMetadata.Parameters[methodMetadata.Parameters.Count - 1].Type, typeSymbol);
}

private static HubMethodType GetHubMethodType(this MethodMetadata methodMetadata, SpecialSymbols specialSymbols)
{
var returnTypeSymbol = methodMetadata.ReturnTypeSymbol as INamedTypeSymbol;
Expand Down

0 comments on commit e5e47cc

Please sign in to comment.