diff --git a/src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs b/src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs index 8a4630d601..082bdc93ee 100644 --- a/src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs @@ -167,13 +167,10 @@ or CollectionType.ReadOnlyMemory if (typeInfo is not CollectionType.None) return "Count"; - var intType = types.Get(); var member = symbolAccessor .GetAllAccessibleMappableMembers(t) .FirstOrDefault( - x => - x.Name is nameof(ICollection.Count) or nameof(Array.Length) - && SymbolEqualityComparer.IncludeNullability.Equals(intType, x.Type) + x => x.Name is nameof(ICollection.Count) or nameof(Array.Length) && x.Type.SpecialType == SpecialType.System_Int32 ); return member?.Name; } diff --git a/src/Riok.Mapperly/Descriptors/MappingBuilders/DictionaryMappingBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBuilders/DictionaryMappingBuilder.cs index b7f883e628..1d9d1af1c8 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBuilders/DictionaryMappingBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBuilders/DictionaryMappingBuilder.cs @@ -1,4 +1,3 @@ -using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Riok.Mapperly.Abstractions; using Riok.Mapperly.Descriptors.Enumerables; @@ -15,8 +14,9 @@ public static class DictionaryMappingBuilder private const string CountPropertyName = nameof(IDictionary.Count); private const string SetterIndexerPropertyName = "set_Item"; - private const string ToImmutableDictionaryMethodName = nameof(ImmutableDictionary.ToImmutableDictionary); - private const string ToImmutableSortedDictionaryMethodName = nameof(ImmutableSortedDictionary.ToImmutableSortedDictionary); + private const string ToImmutableDictionaryMethodName = "global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary"; + private const string ToImmutableSortedDictionaryMethodName = + "global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary"; public static ITypeMapping? TryBuildMapping(MappingBuilderContext ctx) { @@ -155,34 +155,21 @@ out var isExplicit return typedInter; } - private static LinqDicitonaryMapping? ResolveImmutableCollectMethod( + private static LinqDictionaryMapping? ResolveImmutableCollectMethod( MappingBuilderContext ctx, ITypeMapping keyMapping, ITypeMapping valueMapping ) { - if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(ImmutableSortedDictionary<,>)))) - return new LinqDicitonaryMapping( - ctx.Source, - ctx.Target, - ctx.Types.Get(typeof(ImmutableSortedDictionary)).GetStaticGenericMethod(ToImmutableSortedDictionaryMethodName)!, - keyMapping, - valueMapping - ); - - // if target is an ImmutableDictionary or IImmutableDictionary - if ( - SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(IImmutableDictionary<,>))) - || SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(ImmutableDictionary<,>))) - ) - return new LinqDicitonaryMapping( - ctx.Source, - ctx.Target, - ctx.Types.Get(typeof(ImmutableDictionary)).GetStaticGenericMethod(ToImmutableDictionaryMethodName)!, - keyMapping, - valueMapping - ); - - return null; + return ctx.CollectionInfos!.Target.CollectionType switch + { + CollectionType.ImmutableSortedDictionary + => new LinqDictionaryMapping(ctx.Source, ctx.Target, ToImmutableSortedDictionaryMethodName, keyMapping, valueMapping), + CollectionType.ImmutableDictionary + or CollectionType.IImmutableDictionary + => new LinqDictionaryMapping(ctx.Source, ctx.Target, ToImmutableDictionaryMethodName, keyMapping, valueMapping), + + _ => null, + }; } } diff --git a/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumerableMappingBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumerableMappingBuilder.cs index e16b955dba..be110137c0 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumerableMappingBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumerableMappingBuilder.cs @@ -1,4 +1,3 @@ -using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Riok.Mapperly.Abstractions; using Riok.Mapperly.Descriptors.Enumerables; @@ -6,24 +5,25 @@ using Riok.Mapperly.Descriptors.Mappings; using Riok.Mapperly.Descriptors.Mappings.ExistingTarget; using Riok.Mapperly.Diagnostics; +using Riok.Mapperly.Emit; using Riok.Mapperly.Helpers; namespace Riok.Mapperly.Descriptors.MappingBuilders; public static class EnumerableMappingBuilder { - private const string SelectMethodName = nameof(Enumerable.Select); - private const string ToArrayMethodName = nameof(Enumerable.ToArray); - private const string ToListMethodName = nameof(Enumerable.ToList); + private const string SelectMethodName = "global::System.Linq.Enumerable.Select"; + private const string ToArrayMethodName = "global::System.Linq.Enumerable.ToArray"; + private const string ToListMethodName = "global::System.Linq.Enumerable.ToList"; private const string ToHashSetMethodName = "ToHashSet"; private const string AddMethodName = nameof(ICollection.Add); - private const string ToImmutableArrayMethodName = nameof(ImmutableArray.ToImmutableArray); - private const string ToImmutableListMethodName = nameof(ImmutableList.ToImmutableList); - private const string ToImmutableHashSetMethodName = nameof(ImmutableHashSet.ToImmutableHashSet); - private const string CreateRangeQueueMethodName = nameof(ImmutableQueue.CreateRange); - private const string CreateRangeStackMethodName = nameof(ImmutableStack.CreateRange); - private const string ToImmutableSortedSetMethodName = nameof(ImmutableSortedSet.ToImmutableSortedSet); + private const string ToImmutableArrayMethodName = "global::System.Collections.Immutable.ImmutableArray.ToImmutableArray"; + private const string ToImmutableListMethodName = "global::System.Collections.Immutable.ImmutableList.ToImmutableList"; + private const string ToImmutableHashSetMethodName = "global::System.Collections.Immutable.ImmutableHashSet.ToImmutableHashSet"; + private const string CreateRangeQueueMethodName = "global::System.Collections.Immutable.ImmutableQueue.CreateRange"; + private const string CreateRangeStackMethodName = "global::System.Collections.Immutable.ImmutableStack.CreateRange"; + private const string ToImmutableSortedSetMethodName = "global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet"; public static TypeMapping? TryBuildMapping(MappingBuilderContext ctx) { @@ -127,10 +127,9 @@ ForEachAddEnumerableExistingTargetMapping CreateForEach(string methodName) } } - private static LinqEnumerableMapping BuildLinqMapping(MappingBuilderContext ctx, ITypeMapping elementMapping, string? collectMethodName) + private static LinqEnumerableMapping BuildLinqMapping(MappingBuilderContext ctx, ITypeMapping elementMapping, string? collectMethod) { - var collectMethod = collectMethodName == null ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(collectMethodName); - var selectMethod = elementMapping.IsSynthetic ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(SelectMethodName); + var selectMethod = elementMapping.IsSynthetic ? null : SelectMethodName; return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod); } @@ -158,7 +157,7 @@ private static LinqConstructorMapping BuildLinqConstructorMapping( ITypeMapping elementMapping ) { - var selectMethod = elementMapping.IsSynthetic ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(SelectMethodName); + var selectMethod = elementMapping.IsSynthetic ? null : SelectMethodName; return new LinqConstructorMapping(ctx.Source, ctx.Target, targetTypeToConstruct, elementMapping, selectMethod); } @@ -219,7 +218,7 @@ ctx.CollectionInfos.Target.CollectionType is CollectionType.ISet or CollectionTy && GetToHashSetLinqCollectMethod(ctx.Types) is { } toHashSetMethod ) { - return (true, toHashSetMethod.Name); + return (true, SyntaxFactoryHelper.StaticMethodString(toHashSetMethod)); } // if target is a IReadOnlyCollection, IEnumerable, IList, List or ICollection with ToList() @@ -241,31 +240,22 @@ or CollectionType.ICollection if (collectMethod is null) return null; - var selectMethod = elementMapping.IsSynthetic ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(SelectMethodName); + var selectMethod = elementMapping.IsSynthetic ? null : SelectMethodName; return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod); } - private static IMethodSymbol? ResolveImmutableCollectMethod(MappingBuilderContext ctx) + private static string? ResolveImmutableCollectMethod(MappingBuilderContext ctx) { - if (ctx.CollectionInfos!.Target.CollectionType == CollectionType.ImmutableArray) - return ctx.Types.Get(typeof(ImmutableArray)).GetStaticGenericMethod(ToImmutableArrayMethodName); - - if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableList or CollectionType.IImmutableList) - return ctx.Types.Get(typeof(ImmutableList)).GetStaticGenericMethod(ToImmutableListMethodName); - - if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableHashSet or CollectionType.IImmutableSet) - return ctx.Types.Get(typeof(ImmutableHashSet)).GetStaticGenericMethod(ToImmutableHashSetMethodName); - - if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableQueue or CollectionType.IImmutableQueue) - return ctx.Types.Get(typeof(ImmutableQueue)).GetStaticGenericMethod(CreateRangeQueueMethodName); - - if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableStack or CollectionType.IImmutableStack) - return ctx.Types.Get(typeof(ImmutableStack)).GetStaticGenericMethod(CreateRangeStackMethodName); - - if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableSortedSet) - return ctx.Types.Get(typeof(ImmutableSortedSet)).GetStaticGenericMethod(ToImmutableSortedSetMethodName); - - return null; + return ctx.CollectionInfos!.Target.CollectionType switch + { + CollectionType.ImmutableArray => ToImmutableArrayMethodName, + CollectionType.ImmutableList or CollectionType.IImmutableList => ToImmutableListMethodName, + CollectionType.ImmutableHashSet or CollectionType.IImmutableSet => ToImmutableHashSetMethodName, + CollectionType.ImmutableQueue or CollectionType.IImmutableQueue => CreateRangeQueueMethodName, + CollectionType.ImmutableStack or CollectionType.IImmutableStack => CreateRangeStackMethodName, + CollectionType.ImmutableSortedSet => ToImmutableSortedSetMethodName, + _ => null + }; } private static IMethodSymbol? GetToHashSetLinqCollectMethod(WellKnownTypes wellKnownTypes) => diff --git a/src/Riok.Mapperly/Descriptors/Mappings/LinqConstructorMapping.cs b/src/Riok.Mapperly/Descriptors/Mappings/LinqConstructorMapping.cs index 3de8e27f8d..8bb49ab350 100644 --- a/src/Riok.Mapperly/Descriptors/Mappings/LinqConstructorMapping.cs +++ b/src/Riok.Mapperly/Descriptors/Mappings/LinqConstructorMapping.cs @@ -12,14 +12,14 @@ public class LinqConstructorMapping : TypeMapping { private readonly INamedTypeSymbol _targetTypeToConstruct; private readonly ITypeMapping _elementMapping; - private readonly IMethodSymbol? _selectMethod; + private readonly string? _selectMethod; public LinqConstructorMapping( ITypeSymbol sourceType, ITypeSymbol targetType, INamedTypeSymbol targetTypeToConstruct, ITypeMapping elementMapping, - IMethodSymbol? selectMethod + string? selectMethod ) : base(sourceType, targetType) { @@ -38,7 +38,7 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx) var (lambdaCtx, lambdaSourceName) = ctx.WithNewScopedSource(); var sourceMapExpression = _elementMapping.Build(lambdaCtx); var convertLambda = SimpleLambdaExpression(Parameter(Identifier(lambdaSourceName))).WithExpressionBody(sourceMapExpression); - mappedSource = StaticInvocation(_selectMethod, ctx.Source, convertLambda); + mappedSource = Invocation(_selectMethod, ctx.Source, convertLambda); } else { diff --git a/src/Riok.Mapperly/Descriptors/Mappings/LinqDictionaryMapping.cs b/src/Riok.Mapperly/Descriptors/Mappings/LinqDictionaryMapping.cs index 240c077d6b..1ab9d4ea2b 100644 --- a/src/Riok.Mapperly/Descriptors/Mappings/LinqDictionaryMapping.cs +++ b/src/Riok.Mapperly/Descriptors/Mappings/LinqDictionaryMapping.cs @@ -8,19 +8,19 @@ namespace Riok.Mapperly.Descriptors.Mappings; /// /// Represents an enumerable mapping which works by using linq (select + collect). /// -public class LinqDicitonaryMapping : TypeMapping +public class LinqDictionaryMapping : TypeMapping { private const string KeyPropertyName = nameof(KeyValuePair.Key); private const string ValuePropertyName = nameof(KeyValuePair.Value); - private readonly IMethodSymbol _collectMethod; + private readonly string _collectMethod; private readonly ITypeMapping _keyMapping; private readonly ITypeMapping _valueMapping; - public LinqDicitonaryMapping( + public LinqDictionaryMapping( ITypeSymbol sourceType, ITypeSymbol targetType, - IMethodSymbol collectMethod, + string collectMethod, ITypeMapping keyMapping, ITypeMapping valueMapping ) @@ -36,7 +36,7 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx) // if key and value types do not change then use a simple call // ie: source.ToImmutableDictionary(); if (_keyMapping.IsSynthetic && _valueMapping.IsSynthetic) - return StaticInvocation(_collectMethod, ctx.Source); + return Invocation(_collectMethod, ctx.Source); // create expressions mapping the key and value and then create the final expression // ie: source.ToImmutableDictionary(x => x.Key, x => (int)x.Value); @@ -48,6 +48,6 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx) var valueMapExpression = _valueMapping.Build(valueLambdaCtx); var valueExpression = SimpleLambdaExpression(Parameter(Identifier(valueLambdaParamName))).WithExpressionBody(valueMapExpression); - return StaticInvocation(_collectMethod, ctx.Source, keyExpression, valueExpression); + return Invocation(_collectMethod, ctx.Source, keyExpression, valueExpression); } } diff --git a/src/Riok.Mapperly/Descriptors/Mappings/LinqEnumerableMapping.cs b/src/Riok.Mapperly/Descriptors/Mappings/LinqEnumerableMapping.cs index 4e93d99f3f..2a97e33456 100644 --- a/src/Riok.Mapperly/Descriptors/Mappings/LinqEnumerableMapping.cs +++ b/src/Riok.Mapperly/Descriptors/Mappings/LinqEnumerableMapping.cs @@ -11,15 +11,15 @@ namespace Riok.Mapperly.Descriptors.Mappings; public class LinqEnumerableMapping : TypeMapping { private readonly ITypeMapping _elementMapping; - private readonly IMethodSymbol? _selectMethod; - private readonly IMethodSymbol? _collectMethod; + private readonly string? _selectMethod; + private readonly string? _collectMethod; public LinqEnumerableMapping( ITypeSymbol sourceType, ITypeSymbol targetType, ITypeMapping elementMapping, - IMethodSymbol? selectMethod, - IMethodSymbol? collectMethod + string? selectMethod, + string? collectMethod ) : base(sourceType, targetType) { @@ -38,13 +38,13 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx) var (lambdaCtx, lambdaSourceName) = ctx.WithNewScopedSource(); var sourceMapExpression = _elementMapping.Build(lambdaCtx); var convertLambda = SimpleLambdaExpression(Parameter(Identifier(lambdaSourceName))).WithExpressionBody(sourceMapExpression); - mappedSource = StaticInvocation(_selectMethod, ctx.Source, convertLambda); + mappedSource = Invocation(_selectMethod, ctx.Source, convertLambda); } else { mappedSource = _elementMapping.Build(ctx); } - return _collectMethod == null ? mappedSource : StaticInvocation(_collectMethod, mappedSource); + return _collectMethod == null ? mappedSource : Invocation(_collectMethod, mappedSource); } } diff --git a/src/Riok.Mapperly/Emit/SyntaxFactoryHelper.cs b/src/Riok.Mapperly/Emit/SyntaxFactoryHelper.cs index 5ebf8cc2ea..5246b7d294 100644 --- a/src/Riok.Mapperly/Emit/SyntaxFactoryHelper.cs +++ b/src/Riok.Mapperly/Emit/SyntaxFactoryHelper.cs @@ -299,21 +299,26 @@ public static InvocationExpressionSyntax StaticInvocation(string receiverType, s return InvocationExpression(methodAccess).WithArgumentList(ArgumentList(arguments)); } - public static InvocationExpressionSyntax StaticInvocation(IMethodSymbol method, params ExpressionSyntax[] arguments) => - StaticInvocation( - FullyQualifiedIdentifierName(method.ReceiverType?.NonNullable()!) - ?? throw new ArgumentNullException(nameof(method.ReceiverType)), - method.Name, - arguments - ); + public static string StaticMethodString(IMethodSymbol method) + { + var receiver = method.ReceiverType ?? throw new NullReferenceException(nameof(method.ReceiverType) + " is null"); + var qualifiedReceiverName = FullyQualifiedIdentifierName(receiver.NonNullable()); + return $"{qualifiedReceiverName}.{method.Name}"; + } + + public static InvocationExpressionSyntax StaticInvocation(IMethodSymbol method, params ExpressionSyntax[] arguments) + { + var receiver = method.ReceiverType ?? throw new NullReferenceException(nameof(method.ReceiverType) + " is null"); + var qualifiedReceiverName = FullyQualifiedIdentifierName(receiver.NonNullable()); + return StaticInvocation(qualifiedReceiverName, method.Name, arguments); + } public static InvocationExpressionSyntax StaticInvocation(IMethodSymbol method, params ArgumentSyntax[] arguments) { - var receiverType = - FullyQualifiedIdentifierName(method.ReceiverType?.NonNullable()!) - ?? throw new ArgumentNullException(nameof(method.ReceiverType)); + var receiver = method.ReceiverType ?? throw new NullReferenceException(nameof(method.ReceiverType) + " is null"); + var qualifiedReceiverName = FullyQualifiedIdentifierName(receiver.NonNullable()); - var receiverTypeIdentifier = IdentifierName(receiverType); + var receiverTypeIdentifier = IdentifierName(qualifiedReceiverName); var methodAccess = MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, receiverTypeIdentifier,