diff --git a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs index 484440542..6f5c98c7a 100644 --- a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs +++ b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs @@ -69,714 +69,6 @@ out var refitInternalNamespace } #endif - /// - /// Generates the interface stubs. - /// - /// The type of the context. - /// The context. - /// The report diagnostic. - /// The add source. - /// The compilation. - /// The refit internal namespace. - /// The candidate methods. - /// The candidate interfaces. - /// - public void GenerateInterfaceStubs( - TContext context, - Action reportDiagnostic, - Action addSource, - CSharpCompilation compilation, - string? refitInternalNamespace, - ImmutableArray candidateMethods, - ImmutableArray candidateInterfaces - ) - { - refitInternalNamespace = - $"{refitInternalNamespace ?? string.Empty}RefitInternalGenerated"; - - // we're going to create a new compilation that contains the attribute. - // TODO: we should allow source generators to provide source during initialize, so that this step isn't required. - var options = (CSharpParseOptions)compilation.SyntaxTrees[0].Options; - - var disposableInterfaceSymbol = compilation.GetTypeByMetadataName( - "System.IDisposable" - )!; - var httpMethodBaseAttributeSymbol = compilation.GetTypeByMetadataName( - "Refit.HttpMethodAttribute" - ); - - if (httpMethodBaseAttributeSymbol == null) - { - reportDiagnostic( - context, - Diagnostic.Create(DiagnosticDescriptors.RefitNotReferenced, null) - ); - return; - } - - // Check the candidates and keep the ones we're actually interested in - -#pragma warning disable RS1024 // Compare symbols correctly - var interfaceToNullableEnabledMap = new Dictionary( - SymbolEqualityComparer.Default - ); -#pragma warning restore RS1024 // Compare symbols correctly - var methodSymbols = new List(); - foreach (var group in candidateMethods.GroupBy(m => m.SyntaxTree)) - { - var model = compilation.GetSemanticModel(group.Key); - foreach (var method in group) - { - // Get the symbol being declared by the method - var methodSymbol = model.GetDeclaredSymbol(method); - if (IsRefitMethod(methodSymbol, httpMethodBaseAttributeSymbol)) - { - var isAnnotated = - compilation.Options.NullableContextOptions - == NullableContextOptions.Enable - || model.GetNullableContext(method.SpanStart) - == NullableContext.Enabled; - interfaceToNullableEnabledMap[methodSymbol!.ContainingType] = isAnnotated; - - methodSymbols.Add(methodSymbol!); - } - } - } - - var interfaces = methodSymbols - .GroupBy( - m => m.ContainingType, - SymbolEqualityComparer.Default - ) - .ToDictionary(g => g.Key, v => v.ToList()); - - // Look through the candidate interfaces - var interfaceSymbols = new List(); - foreach (var group in candidateInterfaces.GroupBy(i => i.SyntaxTree)) - { - var model = compilation.GetSemanticModel(group.Key); - foreach (var iface in group) - { - // get the symbol belonging to the interface - var ifaceSymbol = model.GetDeclaredSymbol(iface); - - // See if we already know about it, might be a dup - if (ifaceSymbol is null || interfaces.ContainsKey(ifaceSymbol)) - continue; - - // The interface has no refit methods, but its base interfaces might - var hasDerivedRefit = ifaceSymbol - .AllInterfaces.SelectMany(i => i.GetMembers().OfType()) - .Any(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)); - - if (hasDerivedRefit) - { - // Add the interface to the generation list with an empty set of methods - // The logic already looks for base refit methods - interfaces.Add(ifaceSymbol, []); - var isAnnotated = - model.GetNullableContext(iface.SpanStart) == NullableContext.Enabled; - - interfaceToNullableEnabledMap[ifaceSymbol] = isAnnotated; - } - } - } - - // Bail out if there aren't any interfaces to generate code for. This may be the case with transitives - if (interfaces.Count == 0) - return; - - var supportsNullable = options.LanguageVersion >= LanguageVersion.CSharp8; - - var keyCount = new Dictionary(); - - var attributeText = - @$" -#pragma warning disable -namespace {refitInternalNamespace} -{{ - [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - [global::System.AttributeUsage (global::System.AttributeTargets.Class | global::System.AttributeTargets.Struct | global::System.AttributeTargets.Enum | global::System.AttributeTargets.Constructor | global::System.AttributeTargets.Method | global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Event | global::System.AttributeTargets.Interface | global::System.AttributeTargets.Delegate)] - sealed class PreserveAttribute : global::System.Attribute - {{ - // - // Fields - // - public bool AllMembers; - - public bool Conditional; - }} -}} -#pragma warning restore -"; - - compilation = compilation.AddSyntaxTrees( - CSharpSyntaxTree.ParseText(SourceText.From(attributeText, Encoding.UTF8), options) - ); - - // add the attribute text - addSource( - context, - "PreserveAttribute.g.cs", - SourceText.From(attributeText, Encoding.UTF8) - ); - - // get the newly bound attribute - var preserveAttributeSymbol = compilation.GetTypeByMetadataName( - $"{refitInternalNamespace}.PreserveAttribute" - )!; - - var generatedClassText = - @$" -#pragma warning disable -namespace Refit.Implementation -{{ - - /// - [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] - [global::System.Diagnostics.DebuggerNonUserCode] - [{preserveAttributeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}] - [global::System.Reflection.Obfuscation(Exclude=true)] - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - internal static partial class Generated - {{ -#if NET5_0_OR_GREATER - [System.Runtime.CompilerServices.ModuleInitializer] - [System.Diagnostics.CodeAnalysis.DynamicDependency(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All, typeof(global::Refit.Implementation.Generated))] - public static void Initialize() - {{ - }} -#endif - }} -}} -#pragma warning restore -"; - addSource( - context, - "Generated.g.cs", - SourceText.From(generatedClassText, Encoding.UTF8) - ); - - compilation = compilation.AddSyntaxTrees( - CSharpSyntaxTree.ParseText( - SourceText.From(generatedClassText, Encoding.UTF8), - options - ) - ); - - // group the fields by interface and generate the source - foreach (var group in interfaces) - { - // each group is keyed by the Interface INamedTypeSymbol and contains the members - // with a refit attribute on them. Types may contain other members, without the attribute, which we'll - // need to check for and error out on - - var classSource = ProcessInterface( - context, - reportDiagnostic, - group.Key, - group.Value, - preserveAttributeSymbol, - disposableInterfaceSymbol, - httpMethodBaseAttributeSymbol, - supportsNullable, - interfaceToNullableEnabledMap[group.Key] - ); - - var keyName = group.Key.Name; - int value; - while (keyCount.TryGetValue(keyName, out value)) - { - keyName = $"{keyName}{++value}"; - } - keyCount[keyName] = value; - - addSource(context, $"{keyName}.g.cs", SourceText.From(classSource, Encoding.UTF8)); - } - } - - static string ProcessInterface( - TContext context, - Action reportDiagnostic, - INamedTypeSymbol interfaceSymbol, - List refitMethods, - ISymbol preserveAttributeSymbol, - ISymbol disposableInterfaceSymbol, - INamedTypeSymbol httpMethodBaseAttributeSymbol, - bool supportsNullable, - bool nullableEnabled - ) - { - // Get the class name with the type parameters, then remove the namespace - var className = interfaceSymbol.ToDisplayString(); - var lastDot = className.LastIndexOf('.'); - if (lastDot > 0) - { - className = className.Substring(lastDot + 1); - } - var classDeclaration = $"{interfaceSymbol.ContainingType?.Name}{className}"; - - // Get the class name itself - var classSuffix = $"{interfaceSymbol.ContainingType?.Name}{interfaceSymbol.Name}"; - var ns = interfaceSymbol.ContainingNamespace?.ToDisplayString(); - - // if it's the global namespace, our lookup rules say it should be the same as the class name - if ( - interfaceSymbol.ContainingNamespace != null - && interfaceSymbol.ContainingNamespace.IsGlobalNamespace - ) - { - ns = string.Empty; - } - - // Remove dots - ns = ns!.Replace(".", ""); - - // See what the nullable context is - - - var source = new StringBuilder(); - if (supportsNullable) - { - source.Append("#nullable "); - - if (nullableEnabled) - { - source.Append("enable"); - } - else - { - source.Append("disable"); - } - } - - source.Append( - $@" -#pragma warning disable -namespace Refit.Implementation -{{ - - partial class Generated - {{ - - /// - [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] - [global::System.Diagnostics.DebuggerNonUserCode] - [{preserveAttributeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}] - [global::System.Reflection.Obfuscation(Exclude=true)] - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - partial class {ns}{classDeclaration} - : {interfaceSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}{GenerateConstraints(interfaceSymbol.TypeParameters, false)} - - {{ - /// - public global::System.Net.Http.HttpClient Client {{ get; }} - readonly global::Refit.IRequestBuilder requestBuilder; - - /// - public {ns}{classSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder) - {{ - Client = client; - this.requestBuilder = requestBuilder; - }} -" - ); - // Get any other methods on the refit interfaces. We'll need to generate something for them and warn - var nonRefitMethods = interfaceSymbol - .GetMembers() - .OfType() - .Except(refitMethods, SymbolEqualityComparer.Default) - .Cast() - .ToList(); - - // get methods for all inherited - var derivedMethods = interfaceSymbol - .AllInterfaces.SelectMany(i => i.GetMembers().OfType()) - .ToList(); - - // Look for disposable - var disposeMethod = derivedMethods.Find( - m => - m.ContainingType?.Equals( - disposableInterfaceSymbol, - SymbolEqualityComparer.Default - ) == true - ); - if (disposeMethod != null) - { - //remove it from the derived methods list so we don't process it with the rest - derivedMethods.Remove(disposeMethod); - } - - // Pull out the refit methods from the derived types - var derivedRefitMethods = derivedMethods - .Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)) - .ToList(); - var derivedNonRefitMethods = derivedMethods - .Except(derivedMethods, SymbolEqualityComparer.Default) - .Cast() - .ToList(); - - var memberNames = new HashSet(interfaceSymbol.GetMembers().Select(x => x.Name)); - - // Handle Refit Methods - foreach (var method in refitMethods) - { - ProcessRefitMethod(source, method, true, memberNames); - } - - foreach (var method in refitMethods.Concat(derivedRefitMethods)) - { - ProcessRefitMethod(source, method, false, memberNames); - } - - // Handle non-refit Methods that aren't static or properties or have a method body - foreach (var method in nonRefitMethods.Concat(derivedNonRefitMethods)) - { - if ( - method.IsStatic - || method.MethodKind == MethodKind.PropertyGet - || method.MethodKind == MethodKind.PropertySet - || !method.IsAbstract - ) // If an interface method has a body, it won't be abstract - continue; - - ProcessNonRefitMethod(context, reportDiagnostic, source, method); - } - - // Handle Dispose - if (disposeMethod != null) - { - ProcessDisposableMethod(source, disposeMethod); - } - - source.Append( - @" - } - } -} - -#pragma warning restore -" - ); - return source.ToString(); - } - - /// - /// Generates the body of the Refit method - /// - /// - /// - /// True if directly from the type we're generating for, false for methods found on base interfaces - /// Contains the unique member names in the interface scope. - static void ProcessRefitMethod( - StringBuilder source, - IMethodSymbol methodSymbol, - bool isTopLevel, - HashSet memberNames - ) - { - var parameterTypesExpression = GenerateTypeParameterExpression( - source, - methodSymbol, - memberNames - ); - - var returnType = methodSymbol.ReturnType.ToDisplayString( - SymbolDisplayFormat.FullyQualifiedFormat - ); - var (isAsync, @return, configureAwait) = methodSymbol.ReturnType.MetadataName switch - { - "Task" => (true, "await (", ").ConfigureAwait(false)"), - "Task`1" or "ValueTask`1" => (true, "return await (", ").ConfigureAwait(false)"), - _ => (false, "return ", ""), - }; - - WriteMethodOpening(source, methodSymbol, !isTopLevel, isAsync); - - // Build the list of args for the array - var argList = new List(); - foreach (var param in methodSymbol.Parameters) - { - argList.Add($"@{param.MetadataName}"); - } - - // List of generic arguments - var genericList = new List(); - foreach (var typeParam in methodSymbol.TypeParameters) - { - genericList.Add( - $"typeof({typeParam.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})" - ); - } - - var argumentsArrayString = - argList.Count == 0 - ? "global::System.Array.Empty()" - : $"new object[] {{ {string.Join(", ", argList)} }}"; - - var genericString = - genericList.Count > 0 - ? $", new global::System.Type[] {{ {string.Join(", ", genericList)} }}" - : string.Empty; - - source.Append( - @$" - var ______arguments = {argumentsArrayString}; - var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodSymbol.Name}"", {parameterTypesExpression}{genericString} ); - - {@return}({returnType})______func(this.Client, ______arguments){configureAwait}; -" - ); - - WriteMethodClosing(source); - } - - static void ProcessDisposableMethod(StringBuilder source, IMethodSymbol methodSymbol) - { - WriteMethodOpening(source, methodSymbol, true); - - source.Append( - @" - Client?.Dispose(); -" - ); - - WriteMethodClosing(source); - } - - static string GenerateConstraints( - ImmutableArray typeParameters, - bool isOverrideOrExplicitImplementation - ) - { - var source = new StringBuilder(); - // Need to loop over the constraints and create them - foreach (var typeParameter in typeParameters) - { - WriteConstraintsForTypeParameter( - source, - typeParameter, - isOverrideOrExplicitImplementation - ); - } - - return source.ToString(); - } - - static void WriteConstraintsForTypeParameter( - StringBuilder source, - ITypeParameterSymbol typeParameter, - bool isOverrideOrExplicitImplementation - ) - { - // Explicit interface implementations and overrides can only have class or struct constraints - - var parameters = new List(); - if (typeParameter.HasReferenceTypeConstraint) - { - parameters.Add("class"); - } - if (typeParameter.HasUnmanagedTypeConstraint && !isOverrideOrExplicitImplementation) - { - parameters.Add("unmanaged"); - } - - // unmanaged constraints are both structs and unmanaged so the struct constraint is redundant - if (typeParameter.HasValueTypeConstraint && !typeParameter.HasUnmanagedTypeConstraint) - { - parameters.Add("struct"); - } - if (typeParameter.HasNotNullConstraint && !isOverrideOrExplicitImplementation) - { - parameters.Add("notnull"); - } - if (!isOverrideOrExplicitImplementation) - { - foreach (var typeConstraint in typeParameter.ConstraintTypes) - { - parameters.Add( - typeConstraint.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) - ); - } - } - - // new constraint has to be last - if (typeParameter.HasConstructorConstraint && !isOverrideOrExplicitImplementation) - { - parameters.Add("new()"); - } - - if (parameters.Count > 0) - { - source.Append( - @$" - where {typeParameter.Name} : {string.Join(", ", parameters)}" - ); - } - } - - static void ProcessNonRefitMethod( - TContext context, - Action reportDiagnostic, - StringBuilder source, - IMethodSymbol methodSymbol - ) - { - WriteMethodOpening(source, methodSymbol, true); - - source.Append( - @" - throw new global::System.NotImplementedException(""Either this method has no Refit HTTP method attribute or you've used something other than a string literal for the 'path' argument.""); -" - ); - - WriteMethodClosing(source); - - foreach (var location in methodSymbol.Locations) - { - var diagnostic = Diagnostic.Create( - DiagnosticDescriptors.InvalidRefitMember, - location, - methodSymbol.ContainingType.Name, - methodSymbol.Name - ); - reportDiagnostic(context, diagnostic); - } - } - - static string GenerateTypeParameterExpression( - StringBuilder source, - IMethodSymbol methodSymbol, - HashSet memberNames - ) - { - // use Array.Empty if method has no parameters. - if (methodSymbol.Parameters.Length == 0) - return "global::System.Array.Empty()"; - - // if one of the parameters is/contains a type parameter then it cannot be cached as it will change type between calls. - if (methodSymbol.Parameters.Any(x => ContainsTypeParameter(x.Type))) - { - var typeEnumerable = methodSymbol.Parameters.Select( - param => - $"typeof({param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})" - ); - return $"new global::System.Type[] {{ {string.Join(", ", typeEnumerable)} }}"; - } - - // find a name and generate field declaration. - var typeParameterFieldName = UniqueName(TypeParameterVariableName, memberNames); - var types = string.Join( - ", ", - methodSymbol.Parameters.Select( - x => - $"typeof({x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})" - ) - ); - source.Append( - $$""" - - - private static readonly global::System.Type[] {{typeParameterFieldName}} = new global::System.Type[] {{{types}} }; - """ - ); - - return typeParameterFieldName; - - static bool ContainsTypeParameter(ITypeSymbol symbol) - { - if (symbol is ITypeParameterSymbol) - return true; - - if (symbol is not INamedTypeSymbol { TypeParameters.Length: > 0 } namedType) - return false; - - foreach (var typeArg in namedType.TypeArguments) - { - if (ContainsTypeParameter(typeArg)) - return true; - } - - return false; - } - } - - static void WriteMethodOpening( - StringBuilder source, - IMethodSymbol methodSymbol, - bool isExplicitInterface, - bool isAsync = false - ) - { - var visibility = !isExplicitInterface ? "public " : string.Empty; - var async = isAsync ? "async " : ""; - - source.Append( - @$" - - /// - {visibility}{async}{methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)} " - ); - - if (isExplicitInterface) - { - source.Append( - @$"{methodSymbol.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}." - ); - } - source.Append( - @$"{methodSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}(" - ); - - if (methodSymbol.Parameters.Length > 0) - { - var list = new List(); - foreach (var param in methodSymbol.Parameters) - { - var annotation = - !param.Type.IsValueType - && param.NullableAnnotation == NullableAnnotation.Annotated; - - list.Add( - $@"{param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}{(annotation ? '?' : string.Empty)} @{param.MetadataName}" - ); - } - - source.Append(string.Join(", ", list)); - } - - source.Append( - @$"){GenerateConstraints(methodSymbol.TypeParameters, isExplicitInterface)} - {{" - ); - } - - static void WriteMethodClosing(StringBuilder source) => source.Append(@" }"); - - static string UniqueName(string name, HashSet methodNames) - { - var candidateName = name; - var counter = 0; - while (methodNames.Contains(candidateName)) - { - candidateName = $"{name}{counter}"; - counter++; - } - - methodNames.Add(candidateName); - return candidateName; - } - - static bool IsRefitMethod(IMethodSymbol? methodSymbol, INamedTypeSymbol httpMethodAttribute) - { - return methodSymbol - ?.GetAttributes() - .Any(ad => ad.AttributeClass?.InheritsFromOrEquals(httpMethodAttribute) == true) - == true; - } - #if ROSLYN_4 /// @@ -867,23 +159,6 @@ out var refitInternalNamespace Emitter.EmitSharedCode(model, (name, code) => spc.AddSource(name, code)); } ); - - // context.RegisterSourceOutput( - // inputs, - // (context, collectedValues) => - // { - // GenerateInterfaceStubs( - // context, - // static (context, diagnostic) => context.ReportDiagnostic(diagnostic), - // static (context, hintName, sourceText) => - // context.AddSource(hintName, sourceText), - // (CSharpCompilation)collectedValues.compilation, - // collectedValues.refitInternalNamespace, - // collectedValues.candidateMethods, - // collectedValues.candidateInterfaces - // ); - // } - // ); } #else diff --git a/InterfaceStubGenerator.Shared/Parser.cs b/InterfaceStubGenerator.Shared/Parser.cs index 0f4e78a23..aafd7c6ac 100644 --- a/InterfaceStubGenerator.Shared/Parser.cs +++ b/InterfaceStubGenerator.Shared/Parser.cs @@ -157,14 +157,13 @@ sealed class PreserveAttribute : global::System.Attribute #pragma warning restore "; - compilation = compilation.AddSyntaxTrees( - CSharpSyntaxTree.ParseText(SourceText.From(attributeText, Encoding.UTF8), options) - ); - // TODO: Delete? // Is it necessary to add the attributes to the compilation now, does it affect the users ide experience? // Is it needed in order to get the preserve attribute display name. // Will the compilation ever change this. + compilation = compilation.AddSyntaxTrees( + CSharpSyntaxTree.ParseText(SourceText.From(attributeText, Encoding.UTF8), options) + ); // get the newly bound attribute var preserveAttributeSymbol = compilation.GetTypeByMetadataName( @@ -314,7 +313,7 @@ bool nullableEnabled ) // If an interface method has a body, it won't be abstract continue; - nonRefitMethodModelList.Add(ProcessNonRefitMethod(method, diagnostics)); + nonRefitMethodModelList.Add(ParseNonRefitMethod(method, diagnostics)); } var nonRefitMethodModels = nonRefitMethodModelList.ToImmutableEquatableArray(); @@ -345,7 +344,7 @@ bool nullableEnabled ); } - private static MethodModel ProcessNonRefitMethod( + private static MethodModel ParseNonRefitMethod( IMethodSymbol methodSymbol, List diagnostics )