From 3201fb9ce5a5b42d0a328c83640988691d7d4e24 Mon Sep 17 00:00:00 2001 From: Timothy Makkison Date: Tue, 27 Jun 2023 22:52:32 +0100 Subject: [PATCH] chore: refactor MapperGenerator to cache outputs. --- .../SourceGeneratorBenchmarks.cs | 4 +- .../Descriptors/DescriptorBuilder.cs | 8 +- .../SimpleMappingBuilderContext.cs | 22 ++-- src/Riok.Mapperly/Emit/SourceEmitter.cs | 5 +- .../IncrementalValuesProviderExtensions.cs | 23 ++++ src/Riok.Mapperly/MapperGenerator.cs | 51 ++++++-- src/Riok.Mapperly/MapperNode.cs | 27 ++++ src/Riok.Mapperly/MapperResults.cs | 50 +++++++ .../Generator/IncrementalGeneratorTest.cs | 123 ++++++++++++++++++ test/Riok.Mapperly.Tests/TestHelper.cs | 40 ++++-- test/Riok.Mapperly.Tests/TestSourceBuilder.cs | 7 + 11 files changed, 315 insertions(+), 45 deletions(-) create mode 100644 src/Riok.Mapperly/MapperNode.cs create mode 100644 src/Riok.Mapperly/MapperResults.cs create mode 100644 test/Riok.Mapperly.Tests/Generator/IncrementalGeneratorTest.cs diff --git a/benchmarks/Riok.Mapperly.Benchmarks/SourceGeneratorBenchmarks.cs b/benchmarks/Riok.Mapperly.Benchmarks/SourceGeneratorBenchmarks.cs index fbb8216f82..e2047ea913 100644 --- a/benchmarks/Riok.Mapperly.Benchmarks/SourceGeneratorBenchmarks.cs +++ b/benchmarks/Riok.Mapperly.Benchmarks/SourceGeneratorBenchmarks.cs @@ -19,10 +19,10 @@ public class SourceGeneratorBenchmarks private MSBuildWorkspace? _workspace; - private CSharpGeneratorDriver? _sampleDriver; + private GeneratorDriver? _sampleDriver; private Compilation? _sampleCompilation; - private CSharpGeneratorDriver? _largeDriver; + private GeneratorDriver? _largeDriver; private Compilation? _largeCompilation; public SourceGeneratorBenchmarks() diff --git a/src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs b/src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs index 423866fb9d..bdb0d34689 100644 --- a/src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs @@ -17,11 +17,11 @@ public class DescriptorBuilder private readonly MethodNameBuilder _methodNameBuilder = new(); private readonly MappingBodyBuilder _mappingBodyBuilder; private readonly SimpleMappingBuilderContext _builderContext; + private readonly List _diagnostics = new(); private ObjectFactoryCollection _objectFactories = ObjectFactoryCollection.Empty; public DescriptorBuilder( - SourceProductionContext sourceContext, Compilation compilation, ClassDeclarationSyntax mapperSyntax, INamedTypeSymbol mapperSymbol, @@ -35,13 +35,13 @@ WellKnownTypes wellKnownTypes new MapperConfiguration(wellKnownTypes, mapperSymbol), wellKnownTypes, _mapperDescriptor, - sourceContext, + _diagnostics, new MappingBuilder(_mappings), new ExistingTargetMappingBuilder(_mappings) ); } - public MapperDescriptor Build() + public (MapperDescriptor descriptor, List errors) Build() { ReserveMethodNames(); ExtractObjectFactories(); @@ -50,7 +50,7 @@ public MapperDescriptor Build() BuildMappingMethodNames(); BuildReferenceHandlingParameters(); AddMappingsToDescriptor(); - return _mapperDescriptor; + return (_mapperDescriptor, _diagnostics); } private void ExtractObjectFactories() diff --git a/src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs b/src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs index 7d66a09fbe..eaecc73bc0 100644 --- a/src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs +++ b/src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs @@ -11,7 +11,7 @@ namespace Riok.Mapperly.Descriptors; public class SimpleMappingBuilderContext { private readonly MapperDescriptor _descriptor; - private readonly SourceProductionContext _context; + private readonly List _diagnostics; private readonly MapperConfiguration _configuration; public SimpleMappingBuilderContext( @@ -19,7 +19,7 @@ public SimpleMappingBuilderContext( MapperConfiguration configuration, WellKnownTypes types, MapperDescriptor descriptor, - SourceProductionContext context, + List diagnostics, MappingBuilder mappingBuilder, ExistingTargetMappingBuilder existingTargetMappingBuilder ) @@ -28,7 +28,7 @@ ExistingTargetMappingBuilder existingTargetMappingBuilder Types = types; _configuration = configuration; _descriptor = descriptor; - _context = context; + _diagnostics = diagnostics; MappingBuilder = mappingBuilder; ExistingTargetMappingBuilder = existingTargetMappingBuilder; } @@ -39,7 +39,7 @@ protected SimpleMappingBuilderContext(SimpleMappingBuilderContext ctx) ctx._configuration, ctx.Types, ctx._descriptor, - ctx._context, + ctx._diagnostics, ctx.MappingBuilder, ctx.ExistingTargetMappingBuilder ) { } @@ -57,14 +57,12 @@ protected SimpleMappingBuilderContext(SimpleMappingBuilderContext ctx) public virtual bool IsConversionEnabled(MappingConversionType conversionType) => MapperConfiguration.EnabledConversions.HasFlag(conversionType); - public void ReportDiagnostic(DiagnosticDescriptor descriptor, ISymbol? location, params object[] messageArgs) => - ReportDiagnostic(descriptor, location?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(), messageArgs); - - public void ReportDiagnostic(DiagnosticDescriptor descriptor, SyntaxNode? location, params object[] messageArgs) => - ReportDiagnostic(descriptor, location?.GetLocation(), messageArgs); + public void ReportDiagnostic(DiagnosticDescriptor descriptor, ISymbol? location, params object[] messageArgs) + { + var syntaxNode = location?.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(); + var nodeLocation = syntaxNode?.GetLocation(); + _diagnostics.Add(Diagnostic.Create(descriptor, nodeLocation ?? _descriptor.Syntax.GetLocation(), messageArgs)); + } protected MappingConfiguration ReadConfiguration(IMethodSymbol? userSymbol) => _configuration.ForMethod(userSymbol); - - private void ReportDiagnostic(DiagnosticDescriptor descriptor, Location? location, params object[] messageArgs) => - _context.ReportDiagnostic(Diagnostic.Create(descriptor, location ?? _descriptor.Syntax.GetLocation(), messageArgs)); } diff --git a/src/Riok.Mapperly/Emit/SourceEmitter.cs b/src/Riok.Mapperly/Emit/SourceEmitter.cs index 0412adb66e..e68fbbdce0 100644 --- a/src/Riok.Mapperly/Emit/SourceEmitter.cs +++ b/src/Riok.Mapperly/Emit/SourceEmitter.cs @@ -18,10 +18,7 @@ public static CompilationUnitSyntax Build(MapperDescriptor descriptor) member = WrapInClassesAsNeeded(descriptor.Symbol, member); member = WrapInNamespaceIfNeeded(descriptor.Namespace, member); - return CompilationUnit() - .WithMembers(SingletonList(member)) - .WithLeadingTrivia(Comment("// "), Nullable(true)) - .NormalizeWhitespace(); + return CompilationUnit().WithMembers(SingletonList(member)).WithLeadingTrivia(Comment("// "), Nullable(true)); } private static IEnumerable BuildMembers(MapperDescriptor descriptor, SourceEmitterContext sourceEmitterContext) diff --git a/src/Riok.Mapperly/Helpers/IncrementalValuesProviderExtensions.cs b/src/Riok.Mapperly/Helpers/IncrementalValuesProviderExtensions.cs index 993baca390..32402346de 100644 --- a/src/Riok.Mapperly/Helpers/IncrementalValuesProviderExtensions.cs +++ b/src/Riok.Mapperly/Helpers/IncrementalValuesProviderExtensions.cs @@ -1,3 +1,4 @@ +using System.Collections.Immutable; using Microsoft.CodeAnalysis; namespace Riok.Mapperly.Helpers; @@ -10,4 +11,26 @@ public static IncrementalValuesProvider WhereNotNull(this Incr return source.Where(x => x != null); #nullable enable } + + /// + /// Registers an output node into an to output diagnostics. + /// + /// The input instance. + /// The input sequence of diagnostics. + public static void ReportDiagnostics( + this IncrementalGeneratorInitializationContext context, + IncrementalValueProvider> diagnostics + ) + { + context.RegisterSourceOutput( + diagnostics, + static (context, diagnostics) => + { + foreach (var diagnostic in diagnostics) + { + context.ReportDiagnostic(diagnostic); + } + } + ); + } } diff --git a/src/Riok.Mapperly/MapperGenerator.cs b/src/Riok.Mapperly/MapperGenerator.cs index b848ad3738..e9be66ddca 100644 --- a/src/Riok.Mapperly/MapperGenerator.cs +++ b/src/Riok.Mapperly/MapperGenerator.cs @@ -22,25 +22,55 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var mapperClassDeclarations = SyntaxProvider.GetClassDeclarations(context); var compilationAndMappers = context.CompilationProvider.Combine(mapperClassDeclarations.Collect()); - context.RegisterImplementationSourceOutput(compilationAndMappers, static (spc, source) => Execute(source.Left, source.Right, spc)); + var mappersWithDiagnostics = compilationAndMappers.Select( + static (x, cancellationToken) => BuildDescriptors(x.Left, x.Right, cancellationToken) + ); + + // output the diagnostics +#if NET7_0_OR_GREATER + context.ReportDiagnostics(mappersWithDiagnostics.Select(static (source, _) => source.Diagnostics).WithTrackingName("diagnostics")); +#else + context.ReportDiagnostics(mappersWithDiagnostics.Select(static (source, _) => source.Diagnostics)); +#endif + + // split into mapper name pairs + var mappers = mappersWithDiagnostics.SelectMany(static (x, _) => x.Mappers); + + context.RegisterImplementationSourceOutput( + mappers, + static (spc, source) => + { + var mapperText = source.Body.NormalizeWhitespace().ToFullString(); + spc.AddSource(source.FileName, SourceText.From(mapperText, Encoding.UTF8)); + } + ); } - private static void Execute(Compilation compilation, ImmutableArray mappers, SourceProductionContext ctx) + private static MapperResults BuildDescriptors( + Compilation compilation, + ImmutableArray mappers, + CancellationToken cancellationToken + ) { if (mappers.IsDefaultOrEmpty) - return; + return MapperResults.Empty; #if DEBUG_SOURCE_GENERATOR DebuggerUtil.AttachDebugger(); #endif var mapperAttributeSymbol = compilation.GetTypeByMetadataName(MapperAttributeName); if (mapperAttributeSymbol == null) - return; + return MapperResults.Empty; var wellKnownTypes = new WellKnownTypes(compilation); var uniqueNameBuilder = new UniqueNameBuilder(); + + var diagnostics = new List(); + var members = new List(); + foreach (var mapperSyntax in mappers.Distinct()) { + cancellationToken.ThrowIfCancellationRequested(); var mapperModel = compilation.GetSemanticModel(mapperSyntax.SyntaxTree); if (mapperModel.GetDeclaredSymbol(mapperSyntax) is not INamedTypeSymbol mapperSymbol) continue; @@ -48,13 +78,14 @@ private static void Execute(Compilation compilation, ImmutableArray +{ + public MapperNode(CompilationUnitSyntax body, string fileName) + { + Body = body; + FileName = fileName; + } + + public CompilationUnitSyntax Body { get; } + public string FileName { get; } + + public bool Equals(MapperNode other) => Body.IsEquivalentTo(other.Body) && FileName == other.FileName; + + public override bool Equals(object? obj) => obj is MapperNode other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + return (Body.GetHashCode() * 397) ^ FileName.GetHashCode(); + } + } +} diff --git a/src/Riok.Mapperly/MapperResults.cs b/src/Riok.Mapperly/MapperResults.cs new file mode 100644 index 0000000000..c9e99f6c14 --- /dev/null +++ b/src/Riok.Mapperly/MapperResults.cs @@ -0,0 +1,50 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Riok.Mapperly; + +public readonly struct MapperResults : IEquatable +{ + public static readonly MapperResults Empty = new(ImmutableArray.Empty, ImmutableArray.Empty); + + public MapperResults(ImmutableArray mappers, ImmutableArray diagnostics) + { + Mappers = mappers; + Diagnostics = diagnostics; + } + + public ImmutableArray Mappers { get; } + public ImmutableArray Diagnostics { get; } + + public bool Equals(MapperResults other) + { + return Mappers.SequenceEqual(other.Mappers) && Diagnostics.SequenceEqual(other.Diagnostics); + } + + public override bool Equals(object? obj) => obj is MapperResults other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = 0; + foreach (var mapper in Mappers) + { + hash = Combine(hash, mapper.GetHashCode()); + } + foreach (var diagnostic in Diagnostics) + { + hash = Combine(hash, diagnostic.GetHashCode()); + } + return hash; + } + + static int Combine(int h1, int h2) + { + // RyuJIT optimizes this to use the ROL instruction + // Related GitHub pull request: https://github.com/dotnet/coreclr/pull/1830 + uint rol5 = ((uint)h1 << 5) | ((uint)h1 >> 27); + return ((int)rol5 + h1) ^ h2; + } + } +} diff --git a/test/Riok.Mapperly.Tests/Generator/IncrementalGeneratorTest.cs b/test/Riok.Mapperly.Tests/Generator/IncrementalGeneratorTest.cs new file mode 100644 index 0000000000..7cc05b7644 --- /dev/null +++ b/test/Riok.Mapperly.Tests/Generator/IncrementalGeneratorTest.cs @@ -0,0 +1,123 @@ +#if NET7_0_OR_GREATER +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Riok.Mapperly.Tests.Generator; + +[UsesVerify] +public class IncrementalGeneratorTest +{ + private const string AddMappersStep = "ImplementationSourceOutput"; + private const string ReportDiagnosticsStep = "diagnostics"; + + [Fact] + public void AddingUnrelatedTypeDoesNotRegenerateOriginal() + { + var source = TestSourceBuilder.Mapping("string", "string"); + + var syntaxTree = CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default); + var compilation1 = TestHelper.BuildCompilation(TestHelperOptions.NoDiagnostics.NullableOption, syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + + var compilation2 = compilation1.AddSyntaxTrees(TestSourceBuilder.SyntaxTree("struct MyValue {}")); + var driver2 = driver1.RunGenerators(compilation2); + + AssertRunResults(AddMappersStep, driver2, IncrementalStepRunReason.Cached); + AssertRunResults(ReportDiagnosticsStep, driver2, IncrementalStepRunReason.Cached); + } + + [Fact] + public void AddingNewMapperDoesNotRegenerateOriginal() + { + var source = TestSourceBuilder.MapperWithBodyAndTypes( + "[MapperIgnoreSource(\"not_found\")] partial B Map(A source);", + "class A { }", + "class B { }" + ); + + var secondary = TestSourceBuilder.SyntaxTree( + """ +using Riok.Mapperly.Abstractions; + +namespace Test.B +{ + [Mapper] + internal partial class BarFooMapper + { + internal partial string BarToFoo(string value); + } +} +""" + ); + + var syntaxTree = CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default); + var compilation1 = TestHelper.BuildCompilation(TestHelperOptions.NoDiagnostics.NullableOption, syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + + var compilation2 = compilation1.AddSyntaxTrees(secondary); + var driver2 = driver1.RunGenerators(compilation2); + + AssertRunResults(AddMappersStep, driver2, IncrementalStepRunReason.Cached, IncrementalStepRunReason.New); + AssertRunResults(ReportDiagnosticsStep, driver2, IncrementalStepRunReason.Modified); + } + + [Fact] + public void AppendingUnrelatedTypeDoesNotRegenerateOriginal() + { + var source = TestSourceBuilder.Mapping("string", "string"); + var syntaxTree = CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default); + var compilation1 = TestHelper.BuildCompilation(TestHelperOptions.NoDiagnostics.NullableOption, syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + + var newTree = syntaxTree.WithRootAndOptions( + syntaxTree.GetCompilationUnitRoot().AddMembers(SyntaxFactory.ParseMemberDeclaration("struct Foo {}")!), + syntaxTree.Options + ); + + var compilation2 = compilation1.ReplaceSyntaxTree(compilation1.SyntaxTrees.First(), newTree); + var driver2 = driver1.RunGenerators(compilation2); + + AssertRunResults(AddMappersStep, driver2, IncrementalStepRunReason.Cached); + AssertRunResults(ReportDiagnosticsStep, driver2, IncrementalStepRunReason.Cached); + } + + [Fact] + public void ModifyingMapperDoesRegenerateOriginal() + { + var source = TestSourceBuilder.MapperWithBodyAndTypes( + "[MapperIgnoreSource(\"not_found\")] partial B Map(A source);", + "class A { }", + "class B { }" + ); + var syntaxTree = CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default); + var compilation1 = TestHelper.BuildCompilation(TestHelperOptions.NoDiagnostics.NullableOption, syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + + var classDeclaration = syntaxTree.GetCompilationUnitRoot().Members.First() as ClassDeclarationSyntax; + var member = SyntaxFactory.ParseMemberDeclaration("internal partial int BarToBaz(int value);")!; + var updatedClass = classDeclaration!.AddMembers(member); + + var newRoot = syntaxTree.GetCompilationUnitRoot().ReplaceNode(classDeclaration, updatedClass); + var newTree = syntaxTree.WithRootAndOptions(newRoot, syntaxTree.Options); + + var compilation2 = compilation1.ReplaceSyntaxTree(compilation1.SyntaxTrees.First(), newTree); + var driver2 = driver1.RunGenerators(compilation2); + + AssertRunResults(AddMappersStep, driver2, IncrementalStepRunReason.New); + AssertRunResults(ReportDiagnosticsStep, driver2, IncrementalStepRunReason.Modified); + } + + private static void AssertRunResults(string name, GeneratorDriver driver, params IncrementalStepRunReason[] runReasons) + { + var runResult = driver.GetRunResult().Results[0]; + + var step = runResult.TrackedSteps[name].SelectMany(x => x.Outputs); + step.Select(x => x.Reason).Should().BeEquivalentTo(runReasons, o => o.WithStrictOrdering()); + } +} +#endif diff --git a/test/Riok.Mapperly.Tests/TestHelper.cs b/test/Riok.Mapperly.Tests/TestHelper.cs index 4fcc845556..02a98f2e43 100644 --- a/test/Riok.Mapperly.Tests/TestHelper.cs +++ b/test/Riok.Mapperly.Tests/TestHelper.cs @@ -7,6 +7,9 @@ namespace Riok.Mapperly.Tests; public static class TestHelper { + private static readonly GeneratorDriverOptions _enableIncrementalTrackingDriverOptions = + new(IncrementalGeneratorOutputKind.None, trackIncrementalGeneratorSteps: true); + public static Task VerifyGenerator(string source, TestHelperOptions? options = null, params object?[] args) { var driver = Generate(source, options); @@ -50,19 +53,7 @@ public static MapperGenerationResult GenerateMapper(string source, TestHelperOpt return mapperResult; } - private static GeneratorDriver Generate(string source, TestHelperOptions? options) - { - options ??= TestHelperOptions.NoDiagnostics; - - var syntaxTree = CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(options.LanguageVersion)); - var compilation = BuildCompilation(options.NullableOption, syntaxTree); - var generator = new MapperGenerator(); - - GeneratorDriver driver = CSharpGeneratorDriver.Create(generator); - return driver.RunGenerators(compilation); - } - - private static CSharpCompilation BuildCompilation(NullableContextOptions nullableOption, params SyntaxTree[] syntaxTrees) + public static CSharpCompilation BuildCompilation(NullableContextOptions nullableOption, params SyntaxTree[] syntaxTrees) { var references = AppDomain.CurrentDomain .GetAssemblies() @@ -80,4 +71,27 @@ private static CSharpCompilation BuildCompilation(NullableContextOptions nullabl return CSharpCompilation.Create("Tests", syntaxTrees, references, compilationOptions); } + + public static GeneratorDriver GenerateTracked(Compilation compilation) + { + var generator = new MapperGenerator(); + + var driver = CSharpGeneratorDriver.Create( + new[] { generator.AsSourceGenerator() }, + driverOptions: _enableIncrementalTrackingDriverOptions + ); + return driver.RunGenerators(compilation); + } + + private static GeneratorDriver Generate(string source, TestHelperOptions? options) + { + options ??= TestHelperOptions.NoDiagnostics; + + var syntaxTree = CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(options.LanguageVersion)); + var compilation = BuildCompilation(options.NullableOption, syntaxTree); + var generator = new MapperGenerator(); + + GeneratorDriver driver = CSharpGeneratorDriver.Create(generator); + return driver.RunGenerators(compilation); + } } diff --git a/test/Riok.Mapperly.Tests/TestSourceBuilder.cs b/test/Riok.Mapperly.Tests/TestSourceBuilder.cs index bcb313570f..724cc92778 100644 --- a/test/Riok.Mapperly.Tests/TestSourceBuilder.cs +++ b/test/Riok.Mapperly.Tests/TestSourceBuilder.cs @@ -1,5 +1,7 @@ using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; namespace Riok.Mapperly.Tests; @@ -66,6 +68,11 @@ public static string MapperWithBodyAndTypes( return MapperWithBody(body, options) + sep + string.Join(sep, types); } + public static SyntaxTree SyntaxTree([StringSyntax(StringSyntax.CSharp)] string source) + { + return CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default); + } + private static string BuildAttribute(TestSourceBuilderOptions options) { var attrs = new[]