diff --git a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs
index 484440542..6f5c98c7a 100644
--- a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs
+++ b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs
@@ -69,714 +69,6 @@ out var refitInternalNamespace
}
#endif
- ///
- /// Generates the interface stubs.
- ///
- /// The type of the context.
- /// The context.
- /// The report diagnostic.
- /// The add source.
- /// The compilation.
- /// The refit internal namespace.
- /// The candidate methods.
- /// The candidate interfaces.
- ///
- public void GenerateInterfaceStubs(
- TContext context,
- Action reportDiagnostic,
- Action addSource,
- CSharpCompilation compilation,
- string? refitInternalNamespace,
- ImmutableArray candidateMethods,
- ImmutableArray candidateInterfaces
- )
- {
- refitInternalNamespace =
- $"{refitInternalNamespace ?? string.Empty}RefitInternalGenerated";
-
- // we're going to create a new compilation that contains the attribute.
- // TODO: we should allow source generators to provide source during initialize, so that this step isn't required.
- var options = (CSharpParseOptions)compilation.SyntaxTrees[0].Options;
-
- var disposableInterfaceSymbol = compilation.GetTypeByMetadataName(
- "System.IDisposable"
- )!;
- var httpMethodBaseAttributeSymbol = compilation.GetTypeByMetadataName(
- "Refit.HttpMethodAttribute"
- );
-
- if (httpMethodBaseAttributeSymbol == null)
- {
- reportDiagnostic(
- context,
- Diagnostic.Create(DiagnosticDescriptors.RefitNotReferenced, null)
- );
- return;
- }
-
- // Check the candidates and keep the ones we're actually interested in
-
-#pragma warning disable RS1024 // Compare symbols correctly
- var interfaceToNullableEnabledMap = new Dictionary(
- SymbolEqualityComparer.Default
- );
-#pragma warning restore RS1024 // Compare symbols correctly
- var methodSymbols = new List();
- foreach (var group in candidateMethods.GroupBy(m => m.SyntaxTree))
- {
- var model = compilation.GetSemanticModel(group.Key);
- foreach (var method in group)
- {
- // Get the symbol being declared by the method
- var methodSymbol = model.GetDeclaredSymbol(method);
- if (IsRefitMethod(methodSymbol, httpMethodBaseAttributeSymbol))
- {
- var isAnnotated =
- compilation.Options.NullableContextOptions
- == NullableContextOptions.Enable
- || model.GetNullableContext(method.SpanStart)
- == NullableContext.Enabled;
- interfaceToNullableEnabledMap[methodSymbol!.ContainingType] = isAnnotated;
-
- methodSymbols.Add(methodSymbol!);
- }
- }
- }
-
- var interfaces = methodSymbols
- .GroupBy(
- m => m.ContainingType,
- SymbolEqualityComparer.Default
- )
- .ToDictionary(g => g.Key, v => v.ToList());
-
- // Look through the candidate interfaces
- var interfaceSymbols = new List();
- foreach (var group in candidateInterfaces.GroupBy(i => i.SyntaxTree))
- {
- var model = compilation.GetSemanticModel(group.Key);
- foreach (var iface in group)
- {
- // get the symbol belonging to the interface
- var ifaceSymbol = model.GetDeclaredSymbol(iface);
-
- // See if we already know about it, might be a dup
- if (ifaceSymbol is null || interfaces.ContainsKey(ifaceSymbol))
- continue;
-
- // The interface has no refit methods, but its base interfaces might
- var hasDerivedRefit = ifaceSymbol
- .AllInterfaces.SelectMany(i => i.GetMembers().OfType())
- .Any(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol));
-
- if (hasDerivedRefit)
- {
- // Add the interface to the generation list with an empty set of methods
- // The logic already looks for base refit methods
- interfaces.Add(ifaceSymbol, []);
- var isAnnotated =
- model.GetNullableContext(iface.SpanStart) == NullableContext.Enabled;
-
- interfaceToNullableEnabledMap[ifaceSymbol] = isAnnotated;
- }
- }
- }
-
- // Bail out if there aren't any interfaces to generate code for. This may be the case with transitives
- if (interfaces.Count == 0)
- return;
-
- var supportsNullable = options.LanguageVersion >= LanguageVersion.CSharp8;
-
- var keyCount = new Dictionary();
-
- var attributeText =
- @$"
-#pragma warning disable
-namespace {refitInternalNamespace}
-{{
- [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
- [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
- [global::System.AttributeUsage (global::System.AttributeTargets.Class | global::System.AttributeTargets.Struct | global::System.AttributeTargets.Enum | global::System.AttributeTargets.Constructor | global::System.AttributeTargets.Method | global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Event | global::System.AttributeTargets.Interface | global::System.AttributeTargets.Delegate)]
- sealed class PreserveAttribute : global::System.Attribute
- {{
- //
- // Fields
- //
- public bool AllMembers;
-
- public bool Conditional;
- }}
-}}
-#pragma warning restore
-";
-
- compilation = compilation.AddSyntaxTrees(
- CSharpSyntaxTree.ParseText(SourceText.From(attributeText, Encoding.UTF8), options)
- );
-
- // add the attribute text
- addSource(
- context,
- "PreserveAttribute.g.cs",
- SourceText.From(attributeText, Encoding.UTF8)
- );
-
- // get the newly bound attribute
- var preserveAttributeSymbol = compilation.GetTypeByMetadataName(
- $"{refitInternalNamespace}.PreserveAttribute"
- )!;
-
- var generatedClassText =
- @$"
-#pragma warning disable
-namespace Refit.Implementation
-{{
-
- ///
- [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
- [global::System.Diagnostics.DebuggerNonUserCode]
- [{preserveAttributeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}]
- [global::System.Reflection.Obfuscation(Exclude=true)]
- [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
- internal static partial class Generated
- {{
-#if NET5_0_OR_GREATER
- [System.Runtime.CompilerServices.ModuleInitializer]
- [System.Diagnostics.CodeAnalysis.DynamicDependency(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All, typeof(global::Refit.Implementation.Generated))]
- public static void Initialize()
- {{
- }}
-#endif
- }}
-}}
-#pragma warning restore
-";
- addSource(
- context,
- "Generated.g.cs",
- SourceText.From(generatedClassText, Encoding.UTF8)
- );
-
- compilation = compilation.AddSyntaxTrees(
- CSharpSyntaxTree.ParseText(
- SourceText.From(generatedClassText, Encoding.UTF8),
- options
- )
- );
-
- // group the fields by interface and generate the source
- foreach (var group in interfaces)
- {
- // each group is keyed by the Interface INamedTypeSymbol and contains the members
- // with a refit attribute on them. Types may contain other members, without the attribute, which we'll
- // need to check for and error out on
-
- var classSource = ProcessInterface(
- context,
- reportDiagnostic,
- group.Key,
- group.Value,
- preserveAttributeSymbol,
- disposableInterfaceSymbol,
- httpMethodBaseAttributeSymbol,
- supportsNullable,
- interfaceToNullableEnabledMap[group.Key]
- );
-
- var keyName = group.Key.Name;
- int value;
- while (keyCount.TryGetValue(keyName, out value))
- {
- keyName = $"{keyName}{++value}";
- }
- keyCount[keyName] = value;
-
- addSource(context, $"{keyName}.g.cs", SourceText.From(classSource, Encoding.UTF8));
- }
- }
-
- static string ProcessInterface(
- TContext context,
- Action reportDiagnostic,
- INamedTypeSymbol interfaceSymbol,
- List refitMethods,
- ISymbol preserveAttributeSymbol,
- ISymbol disposableInterfaceSymbol,
- INamedTypeSymbol httpMethodBaseAttributeSymbol,
- bool supportsNullable,
- bool nullableEnabled
- )
- {
- // Get the class name with the type parameters, then remove the namespace
- var className = interfaceSymbol.ToDisplayString();
- var lastDot = className.LastIndexOf('.');
- if (lastDot > 0)
- {
- className = className.Substring(lastDot + 1);
- }
- var classDeclaration = $"{interfaceSymbol.ContainingType?.Name}{className}";
-
- // Get the class name itself
- var classSuffix = $"{interfaceSymbol.ContainingType?.Name}{interfaceSymbol.Name}";
- var ns = interfaceSymbol.ContainingNamespace?.ToDisplayString();
-
- // if it's the global namespace, our lookup rules say it should be the same as the class name
- if (
- interfaceSymbol.ContainingNamespace != null
- && interfaceSymbol.ContainingNamespace.IsGlobalNamespace
- )
- {
- ns = string.Empty;
- }
-
- // Remove dots
- ns = ns!.Replace(".", "");
-
- // See what the nullable context is
-
-
- var source = new StringBuilder();
- if (supportsNullable)
- {
- source.Append("#nullable ");
-
- if (nullableEnabled)
- {
- source.Append("enable");
- }
- else
- {
- source.Append("disable");
- }
- }
-
- source.Append(
- $@"
-#pragma warning disable
-namespace Refit.Implementation
-{{
-
- partial class Generated
- {{
-
- ///
- [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
- [global::System.Diagnostics.DebuggerNonUserCode]
- [{preserveAttributeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}]
- [global::System.Reflection.Obfuscation(Exclude=true)]
- [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
- partial class {ns}{classDeclaration}
- : {interfaceSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}{GenerateConstraints(interfaceSymbol.TypeParameters, false)}
-
- {{
- ///
- public global::System.Net.Http.HttpClient Client {{ get; }}
- readonly global::Refit.IRequestBuilder requestBuilder;
-
- ///
- public {ns}{classSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder)
- {{
- Client = client;
- this.requestBuilder = requestBuilder;
- }}
-"
- );
- // Get any other methods on the refit interfaces. We'll need to generate something for them and warn
- var nonRefitMethods = interfaceSymbol
- .GetMembers()
- .OfType()
- .Except(refitMethods, SymbolEqualityComparer.Default)
- .Cast()
- .ToList();
-
- // get methods for all inherited
- var derivedMethods = interfaceSymbol
- .AllInterfaces.SelectMany(i => i.GetMembers().OfType())
- .ToList();
-
- // Look for disposable
- var disposeMethod = derivedMethods.Find(
- m =>
- m.ContainingType?.Equals(
- disposableInterfaceSymbol,
- SymbolEqualityComparer.Default
- ) == true
- );
- if (disposeMethod != null)
- {
- //remove it from the derived methods list so we don't process it with the rest
- derivedMethods.Remove(disposeMethod);
- }
-
- // Pull out the refit methods from the derived types
- var derivedRefitMethods = derivedMethods
- .Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol))
- .ToList();
- var derivedNonRefitMethods = derivedMethods
- .Except(derivedMethods, SymbolEqualityComparer.Default)
- .Cast()
- .ToList();
-
- var memberNames = new HashSet(interfaceSymbol.GetMembers().Select(x => x.Name));
-
- // Handle Refit Methods
- foreach (var method in refitMethods)
- {
- ProcessRefitMethod(source, method, true, memberNames);
- }
-
- foreach (var method in refitMethods.Concat(derivedRefitMethods))
- {
- ProcessRefitMethod(source, method, false, memberNames);
- }
-
- // Handle non-refit Methods that aren't static or properties or have a method body
- foreach (var method in nonRefitMethods.Concat(derivedNonRefitMethods))
- {
- if (
- method.IsStatic
- || method.MethodKind == MethodKind.PropertyGet
- || method.MethodKind == MethodKind.PropertySet
- || !method.IsAbstract
- ) // If an interface method has a body, it won't be abstract
- continue;
-
- ProcessNonRefitMethod(context, reportDiagnostic, source, method);
- }
-
- // Handle Dispose
- if (disposeMethod != null)
- {
- ProcessDisposableMethod(source, disposeMethod);
- }
-
- source.Append(
- @"
- }
- }
-}
-
-#pragma warning restore
-"
- );
- return source.ToString();
- }
-
- ///
- /// Generates the body of the Refit method
- ///
- ///
- ///
- /// True if directly from the type we're generating for, false for methods found on base interfaces
- /// Contains the unique member names in the interface scope.
- static void ProcessRefitMethod(
- StringBuilder source,
- IMethodSymbol methodSymbol,
- bool isTopLevel,
- HashSet memberNames
- )
- {
- var parameterTypesExpression = GenerateTypeParameterExpression(
- source,
- methodSymbol,
- memberNames
- );
-
- var returnType = methodSymbol.ReturnType.ToDisplayString(
- SymbolDisplayFormat.FullyQualifiedFormat
- );
- var (isAsync, @return, configureAwait) = methodSymbol.ReturnType.MetadataName switch
- {
- "Task" => (true, "await (", ").ConfigureAwait(false)"),
- "Task`1" or "ValueTask`1" => (true, "return await (", ").ConfigureAwait(false)"),
- _ => (false, "return ", ""),
- };
-
- WriteMethodOpening(source, methodSymbol, !isTopLevel, isAsync);
-
- // Build the list of args for the array
- var argList = new List();
- foreach (var param in methodSymbol.Parameters)
- {
- argList.Add($"@{param.MetadataName}");
- }
-
- // List of generic arguments
- var genericList = new List();
- foreach (var typeParam in methodSymbol.TypeParameters)
- {
- genericList.Add(
- $"typeof({typeParam.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"
- );
- }
-
- var argumentsArrayString =
- argList.Count == 0
- ? "global::System.Array.Empty