Skip to content

Commit

Permalink
chore: remove IMethodSymbol from Dictionary and Enumerable and use …
Browse files Browse the repository at this point in the history
…strings
  • Loading branch information
TimothyMakkison committed Jul 11, 2023
1 parent ff51d7d commit 7e33cc1
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,10 @@ or CollectionType.ReadOnlyMemory
if (typeInfo is not CollectionType.None)
return "Count";

var intType = types.Get<int>();
var member = symbolAccessor
.GetAllAccessibleMappableMembers(t)
.FirstOrDefault(
x =>
x.Name is nameof(ICollection<object>.Count) or nameof(Array.Length)
&& SymbolEqualityComparer.IncludeNullability.Equals(intType, x.Type)
x => x.Name is nameof(ICollection<object>.Count) or nameof(Array.Length) && x.Type.SpecialType == SpecialType.System_Int32
);
return member?.Name;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.Descriptors.Enumerables;
Expand All @@ -15,8 +14,9 @@ public static class DictionaryMappingBuilder
private const string CountPropertyName = nameof(IDictionary<object, object>.Count);
private const string SetterIndexerPropertyName = "set_Item";

private const string ToImmutableDictionaryMethodName = nameof(ImmutableDictionary.ToImmutableDictionary);
private const string ToImmutableSortedDictionaryMethodName = nameof(ImmutableSortedDictionary.ToImmutableSortedDictionary);
private const string ToImmutableDictionaryMethodName = "global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary";
private const string ToImmutableSortedDictionaryMethodName =
"global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary";

public static ITypeMapping? TryBuildMapping(MappingBuilderContext ctx)
{
Expand Down Expand Up @@ -155,34 +155,21 @@ out var isExplicit
return typedInter;
}

private static LinqDicitonaryMapping? ResolveImmutableCollectMethod(
private static LinqDictionaryMapping? ResolveImmutableCollectMethod(
MappingBuilderContext ctx,
ITypeMapping keyMapping,
ITypeMapping valueMapping
)
{
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(ImmutableSortedDictionary<,>))))
return new LinqDicitonaryMapping(
ctx.Source,
ctx.Target,
ctx.Types.Get(typeof(ImmutableSortedDictionary)).GetStaticGenericMethod(ToImmutableSortedDictionaryMethodName)!,
keyMapping,
valueMapping
);

// if target is an ImmutableDictionary or IImmutableDictionary
if (
SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(IImmutableDictionary<,>)))
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(ImmutableDictionary<,>)))
)
return new LinqDicitonaryMapping(
ctx.Source,
ctx.Target,
ctx.Types.Get(typeof(ImmutableDictionary)).GetStaticGenericMethod(ToImmutableDictionaryMethodName)!,
keyMapping,
valueMapping
);

return null;
return ctx.CollectionInfos!.Target.CollectionType switch
{
CollectionType.ImmutableSortedDictionary
=> new LinqDictionaryMapping(ctx.Source, ctx.Target, ToImmutableSortedDictionaryMethodName, keyMapping, valueMapping),
CollectionType.ImmutableDictionary
or CollectionType.IImmutableDictionary
=> new LinqDictionaryMapping(ctx.Source, ctx.Target, ToImmutableDictionaryMethodName, keyMapping, valueMapping),

_ => null,
};
}
}
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.Descriptors.Enumerables;
using Riok.Mapperly.Descriptors.Enumerables.EnsureCapacity;
using Riok.Mapperly.Descriptors.Mappings;
using Riok.Mapperly.Descriptors.Mappings.ExistingTarget;
using Riok.Mapperly.Diagnostics;
using Riok.Mapperly.Emit;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.MappingBuilders;

public static class EnumerableMappingBuilder
{
private const string SelectMethodName = nameof(Enumerable.Select);
private const string ToArrayMethodName = nameof(Enumerable.ToArray);
private const string ToListMethodName = nameof(Enumerable.ToList);
private const string SelectMethodName = "global::System.Linq.Enumerable.Select";
private const string ToArrayMethodName = "global::System.Linq.Enumerable.ToArray";
private const string ToListMethodName = "global::System.Linq.Enumerable.ToList";
private const string ToHashSetMethodName = "ToHashSet";
private const string AddMethodName = nameof(ICollection<object>.Add);

private const string ToImmutableArrayMethodName = nameof(ImmutableArray.ToImmutableArray);
private const string ToImmutableListMethodName = nameof(ImmutableList.ToImmutableList);
private const string ToImmutableHashSetMethodName = nameof(ImmutableHashSet.ToImmutableHashSet);
private const string CreateRangeQueueMethodName = nameof(ImmutableQueue.CreateRange);
private const string CreateRangeStackMethodName = nameof(ImmutableStack.CreateRange);
private const string ToImmutableSortedSetMethodName = nameof(ImmutableSortedSet.ToImmutableSortedSet);
private const string ToImmutableArrayMethodName = "global::System.Collections.Immutable.ImmutableArray.ToImmutableArray";
private const string ToImmutableListMethodName = "global::System.Collections.Immutable.ImmutableList.ToImmutableList";
private const string ToImmutableHashSetMethodName = "global::System.Collections.Immutable.ImmutableHashSet.ToImmutableHashSet";
private const string CreateRangeQueueMethodName = "global::System.Collections.Immutable.ImmutableQueue.CreateRange";
private const string CreateRangeStackMethodName = "global::System.Collections.Immutable.ImmutableStack.CreateRange";
private const string ToImmutableSortedSetMethodName = "global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet";

public static TypeMapping? TryBuildMapping(MappingBuilderContext ctx)
{
Expand Down Expand Up @@ -127,10 +127,9 @@ ForEachAddEnumerableExistingTargetMapping CreateForEach(string methodName)
}
}

private static LinqEnumerableMapping BuildLinqMapping(MappingBuilderContext ctx, ITypeMapping elementMapping, string? collectMethodName)
private static LinqEnumerableMapping BuildLinqMapping(MappingBuilderContext ctx, ITypeMapping elementMapping, string? collectMethod)
{
var collectMethod = collectMethodName == null ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(collectMethodName);
var selectMethod = elementMapping.IsSynthetic ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(SelectMethodName);
var selectMethod = elementMapping.IsSynthetic ? null : SelectMethodName;
return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod);
}

Expand Down Expand Up @@ -158,7 +157,7 @@ private static LinqConstructorMapping BuildLinqConstructorMapping(
ITypeMapping elementMapping
)
{
var selectMethod = elementMapping.IsSynthetic ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(SelectMethodName);
var selectMethod = elementMapping.IsSynthetic ? null : SelectMethodName;
return new LinqConstructorMapping(ctx.Source, ctx.Target, targetTypeToConstruct, elementMapping, selectMethod);
}

Expand Down Expand Up @@ -219,7 +218,7 @@ ctx.CollectionInfos.Target.CollectionType is CollectionType.ISet or CollectionTy
&& GetToHashSetLinqCollectMethod(ctx.Types) is { } toHashSetMethod
)
{
return (true, toHashSetMethod.Name);
return (true, SyntaxFactoryHelper.StaticMethodString(toHashSetMethod));
}

// if target is a IReadOnlyCollection<T>, IEnumerable<T>, IList<T>, List<T> or ICollection<T> with ToList()
Expand All @@ -241,31 +240,22 @@ or CollectionType.ICollection
if (collectMethod is null)
return null;

var selectMethod = elementMapping.IsSynthetic ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(SelectMethodName);
var selectMethod = elementMapping.IsSynthetic ? null : SelectMethodName;
return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod);
}

private static IMethodSymbol? ResolveImmutableCollectMethod(MappingBuilderContext ctx)
private static string? ResolveImmutableCollectMethod(MappingBuilderContext ctx)
{
if (ctx.CollectionInfos!.Target.CollectionType == CollectionType.ImmutableArray)
return ctx.Types.Get(typeof(ImmutableArray)).GetStaticGenericMethod(ToImmutableArrayMethodName);

if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableList or CollectionType.IImmutableList)
return ctx.Types.Get(typeof(ImmutableList)).GetStaticGenericMethod(ToImmutableListMethodName);

if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableHashSet or CollectionType.IImmutableSet)
return ctx.Types.Get(typeof(ImmutableHashSet)).GetStaticGenericMethod(ToImmutableHashSetMethodName);

if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableQueue or CollectionType.IImmutableQueue)
return ctx.Types.Get(typeof(ImmutableQueue)).GetStaticGenericMethod(CreateRangeQueueMethodName);

if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableStack or CollectionType.IImmutableStack)
return ctx.Types.Get(typeof(ImmutableStack)).GetStaticGenericMethod(CreateRangeStackMethodName);

if (ctx.CollectionInfos.Target.CollectionType is CollectionType.ImmutableSortedSet)
return ctx.Types.Get(typeof(ImmutableSortedSet)).GetStaticGenericMethod(ToImmutableSortedSetMethodName);

return null;
return ctx.CollectionInfos!.Target.CollectionType switch
{
CollectionType.ImmutableArray => ToImmutableArrayMethodName,
CollectionType.ImmutableList or CollectionType.IImmutableList => ToImmutableListMethodName,
CollectionType.ImmutableHashSet or CollectionType.IImmutableSet => ToImmutableHashSetMethodName,
CollectionType.ImmutableQueue or CollectionType.IImmutableQueue => CreateRangeQueueMethodName,
CollectionType.ImmutableStack or CollectionType.IImmutableStack => CreateRangeStackMethodName,
CollectionType.ImmutableSortedSet => ToImmutableSortedSetMethodName,
_ => null
};
}

private static IMethodSymbol? GetToHashSetLinqCollectMethod(WellKnownTypes wellKnownTypes) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ public class LinqConstructorMapping : TypeMapping
{
private readonly INamedTypeSymbol _targetTypeToConstruct;
private readonly ITypeMapping _elementMapping;
private readonly IMethodSymbol? _selectMethod;
private readonly string? _selectMethod;

public LinqConstructorMapping(
ITypeSymbol sourceType,
ITypeSymbol targetType,
INamedTypeSymbol targetTypeToConstruct,
ITypeMapping elementMapping,
IMethodSymbol? selectMethod
string? selectMethod
)
: base(sourceType, targetType)
{
Expand All @@ -38,7 +38,7 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
var (lambdaCtx, lambdaSourceName) = ctx.WithNewScopedSource();
var sourceMapExpression = _elementMapping.Build(lambdaCtx);
var convertLambda = SimpleLambdaExpression(Parameter(Identifier(lambdaSourceName))).WithExpressionBody(sourceMapExpression);
mappedSource = StaticInvocation(_selectMethod, ctx.Source, convertLambda);
mappedSource = Invocation(_selectMethod, ctx.Source, convertLambda);
}
else
{
Expand Down
12 changes: 6 additions & 6 deletions src/Riok.Mapperly/Descriptors/Mappings/LinqDictionaryMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ namespace Riok.Mapperly.Descriptors.Mappings;
/// <summary>
/// Represents an enumerable mapping which works by using linq (select + collect).
/// </summary>
public class LinqDicitonaryMapping : TypeMapping
public class LinqDictionaryMapping : TypeMapping
{
private const string KeyPropertyName = nameof(KeyValuePair<object, object>.Key);
private const string ValuePropertyName = nameof(KeyValuePair<object, object>.Value);

private readonly IMethodSymbol _collectMethod;
private readonly string _collectMethod;
private readonly ITypeMapping _keyMapping;
private readonly ITypeMapping _valueMapping;

public LinqDicitonaryMapping(
public LinqDictionaryMapping(
ITypeSymbol sourceType,
ITypeSymbol targetType,
IMethodSymbol collectMethod,
string collectMethod,
ITypeMapping keyMapping,
ITypeMapping valueMapping
)
Expand All @@ -36,7 +36,7 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
// if key and value types do not change then use a simple call
// ie: source.ToImmutableDictionary();
if (_keyMapping.IsSynthetic && _valueMapping.IsSynthetic)
return StaticInvocation(_collectMethod, ctx.Source);
return Invocation(_collectMethod, ctx.Source);

// create expressions mapping the key and value and then create the final expression
// ie: source.ToImmutableDictionary(x => x.Key, x => (int)x.Value);
Expand All @@ -48,6 +48,6 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
var valueMapExpression = _valueMapping.Build(valueLambdaCtx);
var valueExpression = SimpleLambdaExpression(Parameter(Identifier(valueLambdaParamName))).WithExpressionBody(valueMapExpression);

return StaticInvocation(_collectMethod, ctx.Source, keyExpression, valueExpression);
return Invocation(_collectMethod, ctx.Source, keyExpression, valueExpression);
}
}
12 changes: 6 additions & 6 deletions src/Riok.Mapperly/Descriptors/Mappings/LinqEnumerableMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ namespace Riok.Mapperly.Descriptors.Mappings;
public class LinqEnumerableMapping : TypeMapping
{
private readonly ITypeMapping _elementMapping;
private readonly IMethodSymbol? _selectMethod;
private readonly IMethodSymbol? _collectMethod;
private readonly string? _selectMethod;
private readonly string? _collectMethod;

public LinqEnumerableMapping(
ITypeSymbol sourceType,
ITypeSymbol targetType,
ITypeMapping elementMapping,
IMethodSymbol? selectMethod,
IMethodSymbol? collectMethod
string? selectMethod,
string? collectMethod
)
: base(sourceType, targetType)
{
Expand All @@ -38,13 +38,13 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
var (lambdaCtx, lambdaSourceName) = ctx.WithNewScopedSource();
var sourceMapExpression = _elementMapping.Build(lambdaCtx);
var convertLambda = SimpleLambdaExpression(Parameter(Identifier(lambdaSourceName))).WithExpressionBody(sourceMapExpression);
mappedSource = StaticInvocation(_selectMethod, ctx.Source, convertLambda);
mappedSource = Invocation(_selectMethod, ctx.Source, convertLambda);
}
else
{
mappedSource = _elementMapping.Build(ctx);
}

return _collectMethod == null ? mappedSource : StaticInvocation(_collectMethod, mappedSource);
return _collectMethod == null ? mappedSource : Invocation(_collectMethod, mappedSource);
}
}
27 changes: 16 additions & 11 deletions src/Riok.Mapperly/Emit/SyntaxFactoryHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -299,21 +299,26 @@ public static InvocationExpressionSyntax StaticInvocation(string receiverType, s
return InvocationExpression(methodAccess).WithArgumentList(ArgumentList(arguments));
}

public static InvocationExpressionSyntax StaticInvocation(IMethodSymbol method, params ExpressionSyntax[] arguments) =>
StaticInvocation(
FullyQualifiedIdentifierName(method.ReceiverType?.NonNullable()!)
?? throw new ArgumentNullException(nameof(method.ReceiverType)),
method.Name,
arguments
);
public static string StaticMethodString(IMethodSymbol method)
{
var receiver = method.ReceiverType ?? throw new NullReferenceException(nameof(method.ReceiverType) + " is null");
var qualifiedReceiverName = FullyQualifiedIdentifierName(receiver.NonNullable());
return $"{qualifiedReceiverName}.{method.Name}";
}

public static InvocationExpressionSyntax StaticInvocation(IMethodSymbol method, params ExpressionSyntax[] arguments)
{
var receiver = method.ReceiverType ?? throw new NullReferenceException(nameof(method.ReceiverType) + " is null");
var qualifiedReceiverName = FullyQualifiedIdentifierName(receiver.NonNullable());
return StaticInvocation(qualifiedReceiverName, method.Name, arguments);
}

public static InvocationExpressionSyntax StaticInvocation(IMethodSymbol method, params ArgumentSyntax[] arguments)
{
var receiverType =
FullyQualifiedIdentifierName(method.ReceiverType?.NonNullable()!)
?? throw new ArgumentNullException(nameof(method.ReceiverType));
var receiver = method.ReceiverType ?? throw new NullReferenceException(nameof(method.ReceiverType) + " is null");
var qualifiedReceiverName = FullyQualifiedIdentifierName(receiver.NonNullable());

var receiverTypeIdentifier = IdentifierName(receiverType);
var receiverTypeIdentifier = IdentifierName(qualifiedReceiverName);
var methodAccess = MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
receiverTypeIdentifier,
Expand Down

0 comments on commit 7e33cc1

Please sign in to comment.