diff --git a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems index 5589f69c2..3c02c35ff 100644 --- a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems +++ b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems @@ -6,10 +6,12 @@ b591423d-f92d-4e00-b0eb-615c9853506c - InterfaceStubGenerator.Shared + Refit.Generator + + \ No newline at end of file diff --git a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs index e59526513..1538f1f1f 100644 --- a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs +++ b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs @@ -86,6 +86,7 @@ public void GenerateInterfaceStubs( return; } + var refitMetadata = new RefitMetadata(disposableInterfaceSymbol, httpMethodBaseAttributeSymbol); // Check the candidates and keep the ones we're actually interested in @@ -100,7 +101,7 @@ public void GenerateInterfaceStubs( { // Get the symbol being declared by the method var methodSymbol = model.GetDeclaredSymbol(method); - if (IsRefitMethod(methodSymbol, httpMethodBaseAttributeSymbol)) + if (refitMetadata.IsRefitMethod(methodSymbol)) { var isAnnotated = compilation.Options.NullableContextOptions == NullableContextOptions.Enable || model.GetNullableContext(method.SpanStart) == NullableContext.Enabled; @@ -131,7 +132,7 @@ public void GenerateInterfaceStubs( // The interface has no refit methods, but its base interfaces might var hasDerivedRefit = ifaceSymbol.AllInterfaces .SelectMany(i => i.GetMembers().OfType()) - .Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)) + .Where(refitMetadata.IsRefitMethod) .Any(); if (hasDerivedRefit) @@ -212,18 +213,15 @@ internal static partial class Generated // each group is keyed by the Interface INamedTypeSymbol and contains the members // with a refit attribute on them. Types may contain other members, without the attribute, which we'll // need to check for and error out on - + var model = new RefitClientModel(group.Key, group.Value, refitMetadata); var classSource = ProcessInterface(context, reportDiagnostic, - group.Key, - group.Value, + model, preserveAttributeSymbol, - disposableInterfaceSymbol, - httpMethodBaseAttributeSymbol, supportsNullable, - interfaceToNullableEnabledMap[group.Key]); + interfaceToNullableEnabledMap[model.RefitInterface]); - var keyName = group.Key.Name; + var keyName = model.FileName; if(keyCount.TryGetValue(keyName, out var value)) { keyName = $"{keyName}{++value}"; @@ -237,37 +235,13 @@ internal static partial class Generated string ProcessInterface(TContext context, Action reportDiagnostic, - INamedTypeSymbol interfaceSymbol, - List refitMethods, + RefitClientModel interfaceModel, ISymbol preserveAttributeSymbol, - ISymbol disposableInterfaceSymbol, - INamedTypeSymbol httpMethodBaseAttributeSymbol, bool supportsNullable, bool nullableEnabled) { - - // Get the class name with the type parameters, then remove the namespace - var className = interfaceSymbol.ToDisplayString(); - var lastDot = className.LastIndexOf('.'); - if(lastDot > 0) - { - className = className.Substring(lastDot+1); - } - var classDeclaration = $"{interfaceSymbol.ContainingType?.Name}{className}"; - - - // Get the class name itself - var classSuffix = $"{interfaceSymbol.ContainingType?.Name}{interfaceSymbol.Name}"; - var ns = interfaceSymbol.ContainingNamespace?.ToDisplayString(); - - // if it's the global namespace, our lookup rules say it should be the same as the class name - if(interfaceSymbol.ContainingNamespace != null && interfaceSymbol.ContainingNamespace.IsGlobalNamespace) - { - ns = string.Empty; - } - - // Remove dots - ns = ns!.Replace(".", ""); + INamedTypeSymbol interfaceSymbol = interfaceModel.RefitInterface; + List refitMethods = interfaceModel.RefitMethods; // See what the nullable context is @@ -301,8 +275,8 @@ partial class Generated [{preserveAttributeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}] [global::System.Reflection.Obfuscation(Exclude=true)] [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - partial class {ns}{classDeclaration} - : {interfaceSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}{GenerateConstraints(interfaceSymbol.TypeParameters, false)} + partial class {interfaceModel.NamespacePrefix}{interfaceModel.ClassDeclaration} + : {interfaceModel.BaseInterfaceDeclaration}{GenerateConstraints(interfaceSymbol.TypeParameters, false)} {{ /// @@ -310,30 +284,13 @@ partial class {ns}{classDeclaration} readonly global::Refit.IRequestBuilder requestBuilder; /// - public {ns}{classSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder) + public {interfaceModel.NamespacePrefix}{interfaceModel.ClassSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder) {{ Client = client; this.requestBuilder = requestBuilder; }} "); - // Get any other methods on the refit interfaces. We'll need to generate something for them and warn - var nonRefitMethods = interfaceSymbol.GetMembers().OfType().Except(refitMethods, SymbolEqualityComparer.Default).Cast().ToList(); - - // get methods for all inherited - var derivedMethods = interfaceSymbol.AllInterfaces.SelectMany(i => i.GetMembers().OfType()).ToList(); - - // Look for disposable - var disposeMethod = derivedMethods.Find(m => m.ContainingType?.Equals(disposableInterfaceSymbol, SymbolEqualityComparer.Default) == true); - if(disposeMethod != null) - { - //remove it from the derived methods list so we don't process it with the rest - derivedMethods.Remove(disposeMethod); - } - - // Pull out the refit methods from the derived types - var derivedRefitMethods = derivedMethods.Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)).ToList(); - var derivedNonRefitMethods = derivedMethods.Except(derivedMethods, SymbolEqualityComparer.Default).Cast().ToList(); // Handle Refit Methods foreach(var method in refitMethods) @@ -341,28 +298,22 @@ partial class {ns}{classDeclaration} ProcessRefitMethod(source, method, true); } - foreach (var method in refitMethods.Concat(derivedRefitMethods)) + foreach (var method in interfaceModel.AllRefitMethods) { ProcessRefitMethod(source, method, false); } // Handle non-refit Methods that aren't static or properties or have a method body - foreach (var method in nonRefitMethods.Concat(derivedNonRefitMethods)) + foreach (var method in interfaceModel.NonRefitMethods) { - if (method.IsStatic || - method.MethodKind == MethodKind.PropertyGet || - method.MethodKind == MethodKind.PropertySet || - !method.IsAbstract) // If an interface method has a body, it won't be abstract - continue; - ProcessNonRefitMethod(context, reportDiagnostic, source, method); } // Handle Dispose - if(disposeMethod != null) + if (interfaceModel.DisposeMethod != null) { - ProcessDisposableMethod(source, disposeMethod); + ProcessDisposableMethod(source, interfaceModel.DisposeMethod); } source.Append(@" @@ -432,7 +383,7 @@ string GenerateConstraints(ImmutableArray typeParameters, { var source = new StringBuilder(); // Need to loop over the constraints and create them - foreach(var typeParameter in typeParameters) + foreach (var typeParameter in typeParameters) { WriteConstraitsForTypeParameter(source, typeParameter, isOverrideOrExplicitImplementation); } @@ -445,7 +396,7 @@ void WriteConstraitsForTypeParameter(StringBuilder source, ITypeParameterSymbol // Explicit interface implementations and ovverrides can only have class or struct constraints var parameters = new List(); - if(typeParameter.HasReferenceTypeConstraint) + if (typeParameter.HasReferenceTypeConstraint) { parameters.Add("class"); } @@ -468,7 +419,7 @@ void WriteConstraitsForTypeParameter(StringBuilder source, ITypeParameterSymbol parameters.Add(typeConstraint.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); } } - + // new constraint has to be last if (typeParameter.HasConstructorConstraint && !isOverrideOrExplicitImplementation) { @@ -534,12 +485,6 @@ void WriteMethodOpening(StringBuilder source, IMethodSymbol methodSymbol, bool i void WriteMethodClosing(StringBuilder source) => source.Append(@" }"); - - bool IsRefitMethod(IMethodSymbol? methodSymbol, INamedTypeSymbol httpMethodAttibute) - { - return methodSymbol?.GetAttributes().Any(ad => ad.AttributeClass?.InheritsFromOrEquals(httpMethodAttibute) == true) == true; - } - #if ROSLYN_4 public void Initialize(IncrementalGeneratorInitializationContext context) diff --git a/InterfaceStubGenerator.Shared/RefitClientModel.cs b/InterfaceStubGenerator.Shared/RefitClientModel.cs new file mode 100644 index 000000000..b3192d05f --- /dev/null +++ b/InterfaceStubGenerator.Shared/RefitClientModel.cs @@ -0,0 +1,100 @@ +using System.Collections.Generic; +using System.Linq; + +using Microsoft.CodeAnalysis; + +namespace Refit.Generator; + +internal class RefitClientModel +{ + readonly RefitMetadata refitMetadata; + + public RefitClientModel(INamedTypeSymbol refitInterface, List refitMethods, RefitMetadata refitMetadata) + { + RefitInterface = refitInterface; + RefitMethods = refitMethods; + this.refitMetadata = refitMetadata; + + // Get any other methods on the refit interfaces. We'll need to generate something for them and warn + var nonRefitMethods = refitInterface.GetMembers().OfType().Except(refitMethods, SymbolEqualityComparer.Default).Cast().ToList(); + + // get methods for all refitInterface + var derivedMethods = refitInterface.AllInterfaces.SelectMany(i => i.GetMembers().OfType()).ToList(); + + // Look for disposable + DisposeMethod = derivedMethods.Find(m => m.ContainingType?.Equals(refitMetadata.DisposableInterfaceSymbol, SymbolEqualityComparer.Default) == true); + if (DisposeMethod != null) + { + //remove it from the derived methods list so we don't process it with the rest + derivedMethods.Remove(DisposeMethod); + } + + // Pull out the refit methods from the derived types + var derivedRefitMethods = derivedMethods.Where(refitMetadata.IsRefitMethod).ToList(); + var derivedNonRefitMethods = derivedMethods.Except(derivedMethods, SymbolEqualityComparer.Default).Cast().ToList(); + + AllRefitMethods = refitMethods.Concat(derivedRefitMethods); + NonRefitMethods = nonRefitMethods.Concat(derivedNonRefitMethods) + .Where(static method => + { + return !(method.IsStatic || + method.MethodKind == MethodKind.PropertyGet || + method.MethodKind == MethodKind.PropertySet || + !method.IsAbstract); + }); + } + + public INamedTypeSymbol RefitInterface { get; } + public List RefitMethods { get; } + public IEnumerable AllRefitMethods { get; } + public IEnumerable NonRefitMethods { get; } + + public string FileName => RefitInterface.Name; + + public string ClassDeclaration + { + get + { + // Get the class name with the type parameters, then remove the namespace + var className = RefitInterface.ToDisplayString(); + var lastDot = className.LastIndexOf('.'); + if (lastDot > 0) + { + className = className.Substring(lastDot + 1); + } + var classDeclaration = $"{RefitInterface.ContainingType?.Name}{className}"; + return classDeclaration; + } + } + + public string ClassSuffix + { + get + { + // Get the class name itself + var classSuffix = $"{RefitInterface.ContainingType?.Name}{RefitInterface.Name}"; + return classSuffix; + } + } + + public string NamespacePrefix + { + get + { + var ns = RefitInterface.ContainingNamespace?.ToDisplayString(); + + // if it's the global namespace, our lookup rules say it should be the same as the class name + if (RefitInterface.ContainingNamespace != null && RefitInterface.ContainingNamespace.IsGlobalNamespace) + { + return string.Empty; + } + + // Remove dots + ns = ns!.Replace(".", ""); + return ns; + } + } + public string BaseInterfaceDeclaration => $"{RefitInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}"; + + public IMethodSymbol DisposeMethod { get; } +} diff --git a/InterfaceStubGenerator.Shared/RefitMetadata.cs b/InterfaceStubGenerator.Shared/RefitMetadata.cs new file mode 100644 index 000000000..b6fa0bbd0 --- /dev/null +++ b/InterfaceStubGenerator.Shared/RefitMetadata.cs @@ -0,0 +1,22 @@ +using System.Linq; + +using Microsoft.CodeAnalysis; + +namespace Refit.Generator; + +internal class RefitMetadata +{ + public RefitMetadata(INamedTypeSymbol? disposableInterfaceSymbol, INamedTypeSymbol httpMethodBaseAttributeSymbol) + { + DisposableInterfaceSymbol = disposableInterfaceSymbol; + HttpMethodBaseAttributeSymbol = httpMethodBaseAttributeSymbol; + } + + public INamedTypeSymbol? DisposableInterfaceSymbol { get; } + public INamedTypeSymbol HttpMethodBaseAttributeSymbol { get; } + + public bool IsRefitMethod(IMethodSymbol? methodSymbol) + { + return methodSymbol?.GetAttributes().Any(ad => ad.AttributeClass?.InheritsFromOrEquals(HttpMethodBaseAttributeSymbol) == true) == true; + } +}