Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing generic constraints to generated methods #1216

Merged
merged 2 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Emit;
using Riok.Mapperly.Emit.Syntax;
using Riok.Mapperly.Helpers;
using Riok.Mapperly.Symbols;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
Expand Down Expand Up @@ -35,7 +37,9 @@ ITypeSymbol objectType
public override MethodDeclarationSyntax BuildMethod(SourceEmitterContext ctx)
{
var methodSyntax = (MethodDeclarationSyntax)Method.DeclaringSyntaxReferences.First().GetSyntax();
return base.BuildMethod(ctx).WithTypeParameterList(methodSyntax.TypeParameterList);
return base.BuildMethod(ctx)
.WithTypeParameterList(methodSyntax.TypeParameterList)
.WithConstraintClauses(List(GetTypeParameterConstraintClauses()));
latonz marked this conversation as resolved.
Show resolved Hide resolved
}

protected override ExpressionSyntax BuildTargetType()
Expand All @@ -44,6 +48,52 @@ protected override ExpressionSyntax BuildTargetType()
return TypeOfExpression(FullyQualifiedIdentifier(Method.ReturnType.NonNullable()));
}

protected virtual IEnumerable<TypeParameterConstraintClauseSyntax> GetTypeParameterConstraintClauses()
{
foreach (var tp in Method.TypeParameters)
{
var constraints = new List<TypeParameterConstraintSyntax>();

if (tp.HasUnmanagedTypeConstraint)
{
constraints.Add(TypeConstraint(IdentifierName("unmanaged")).AddLeadingSpace());
}
else if (tp.HasValueTypeConstraint)
{
constraints.Add(ClassOrStructConstraint(SyntaxKind.StructConstraint).AddLeadingSpace());
}
else if (tp.HasNotNullConstraint)
{
constraints.Add(TypeConstraint(IdentifierName("notnull")).AddLeadingSpace());
}
else if (tp.HasReferenceTypeConstraint)
{
constraints.Add(ClassOrStructConstraint(SyntaxKind.ClassConstraint).AddLeadingSpace());
}

foreach (var c in tp.ConstraintTypes)
{
constraints.Add(TypeConstraint(FullyQualifiedIdentifier(c)).AddLeadingSpace());
}

if (tp.HasConstructorConstraint)
{
constraints.Add(ConstructorConstraint().AddLeadingSpace());
}

if (!constraints.Any())
{
continue;
}

yield return TypeParameterConstraintClause(
IdentifierName(tp.Name).AddLeadingSpace().AddTrailingSpace(),
SeparatedList(constraints)
)
.AddLeadingSpace();
}
}

protected override ExpressionSyntax? BuildSwitchArmWhenClause(ExpressionSyntax targetType, RuntimeTargetTypeMapping mapping)
{
return mapping.IsAssignableToMethodTargetType ? null : base.BuildSwitchArmWhenClause(targetType, mapping);
Expand Down
3 changes: 3 additions & 0 deletions src/Riok.Mapperly/Emit/Syntax/SyntaxIndentationExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ public static TSyntax AddTrailingLineFeed<TSyntax>(this TSyntax syntax, int inde
return syntax.WithTrailingTrivia(trivia);
}

public static TSyntax AddLeadingSpace<TSyntax>(this TSyntax syntax)
where TSyntax : SyntaxNode => syntax.WithLeadingTrivia(syntax.GetLeadingTrivia().Add(ElasticSpace));

public static TSyntax AddTrailingSpace<TSyntax>(this TSyntax syntax)
where TSyntax : SyntaxNode => syntax.WithTrailingTrivia(syntax.GetTrailingTrivia().Add(ElasticSpace));

Expand Down
12 changes: 12 additions & 0 deletions test/Riok.Mapperly.Tests/GeneratedMethod.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Riok.Mapperly.Tests;
Expand All @@ -8,15 +9,26 @@ public GeneratedMethod(MethodDeclarationSyntax declarationSyntax)
{
Name = declarationSyntax.Identifier.ToString();
Signature = $"{declarationSyntax.ReturnType.ToString()} {Name}{declarationSyntax.ParameterList.ToString().Trim()}";
ConstraintClauses = ExtractParameterConstraints(declarationSyntax.ConstraintClauses);
Body = ExtractBody(declarationSyntax);
}

public string Name { get; }

public string Signature { get; }

public string? ConstraintClauses { get; }
latonz marked this conversation as resolved.
Show resolved Hide resolved

public string Body { get; }

private static string? ExtractParameterConstraints(SyntaxList<TypeParameterConstraintClauseSyntax> typeParameterConstraints)
{
if (typeParameterConstraints.Count == 0)
return null;

return typeParameterConstraints.ToFullString().Trim(' ', '\r', '\n').ReplaceLineEndings();
}

/// <summary>
/// Builds the method body without the method body braces and without the method body level indentation.
/// </summary>
Expand Down
13 changes: 13 additions & 0 deletions test/Riok.Mapperly.Tests/MapperGenerationResultAssertions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,19 @@ public MapperGenerationResultAssertions HaveMethodBody(string methodName, [Strin
public MapperGenerationResultAssertions HaveMapMethodBody([StringSyntax(StringSyntax.CSharp)] string mapperMethodBody) =>
HaveMethodBody(TestSourceBuilder.DefaultMapMethodName, mapperMethodBody);

public MapperGenerationResultAssertions HaveMapMethodWithGenericConstraints(
string methodName,
[StringSyntax(StringSyntax.CSharp)] string? constraintClauses
)
{
_mapper.Methods[methodName].ConstraintClauses.Should().Be(constraintClauses);
return this;
}

public MapperGenerationResultAssertions HaveMapMethodWithGenericConstraints(
[StringSyntax(StringSyntax.CSharp)] string? constraintClauses
) => HaveMapMethodWithGenericConstraints(TestSourceBuilder.DefaultMapMethodName, constraintClauses);

private IReadOnlyCollection<Diagnostic> GetDiagnostics(DiagnosticDescriptor descriptor)
{
if (_mapper.DiagnosticsByDescriptorId.TryGetValue(descriptor.Id, out var diagnostics))
Expand Down
73 changes: 54 additions & 19 deletions test/Riok.Mapperly.Tests/Mapping/GenericTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ public void WithGenericSource()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
Expand All @@ -135,11 +136,14 @@ partial object Map<TSource>(TSource source)

partial B MapToB(A source);
partial D MapToD(C source);
partial F MapToF(E source);
""",
"record struct A(string Value);",
"record A(string Value);",
"record struct B(string Value);",
"record C(string Value1);",
"record D(string Value1);"
"record D(string Value1);",
"record E(string Value) : A(Value);",
"record struct F(string Value) : B(Value);"
);
TestHelper
.GenerateMapper(source)
Expand All @@ -148,13 +152,13 @@ partial object Map<TSource>(TSource source)
"""
return source switch
{
global::E x => MapToF(x),
global::A x => MapToB(x),
global::C x => MapToD(x),
null => throw new System.ArgumentNullException(nameof(source)),
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : global::A");
}

[Fact]
Expand Down Expand Up @@ -185,7 +189,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : notnull");
}

[Fact]
Expand Down Expand Up @@ -215,7 +220,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : struct");
}

[Fact]
Expand Down Expand Up @@ -245,7 +251,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : unmanaged");
}

[Fact]
Expand Down Expand Up @@ -275,7 +282,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : class");
}

[Fact]
Expand Down Expand Up @@ -306,7 +314,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : class");
}

[Fact]
Expand Down Expand Up @@ -337,7 +346,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : class");
}

[Fact]
Expand Down Expand Up @@ -369,7 +379,8 @@ partial TTarget Map<TSource, TTarget>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : class where TTarget : class");
}

[Fact]
Expand Down Expand Up @@ -404,7 +415,8 @@ public void WithGenericSourceSpecificTarget()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(global::BaseDto)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
Expand Down Expand Up @@ -434,7 +446,8 @@ public void WithGenericTarget()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
Expand Down Expand Up @@ -464,7 +477,8 @@ partial TTarget Map<TTarget>(object source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TTarget : global::D");
}

[Fact]
Expand Down Expand Up @@ -498,7 +512,8 @@ public void WithGenericTargetSpecificSource()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
Expand Down Expand Up @@ -529,7 +544,8 @@ partial TTarget Map<TSource, TTarget>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : global::C where TTarget : global::D");
}

[Fact]
Expand Down Expand Up @@ -562,7 +578,26 @@ public void WithUserImplementedMethodsShouldBeIncluded()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
public Task WithGenericConstructorConstraint()
{
var source = TestSourceBuilder.MapperWithBodyAndTypes(
"""
private partial TTarget Map<TSource, TTarget>(TSource source) where TSource : new() where TTarget : new();

private partial B MapToB(A source);
private partial D MapToD(C source);
""",
"record struct A(string Value) { public A() : this(default!) {} }",
"record struct B(string Value) { public B() : this(default!) {} }",
"record C(string Value1);",
"record D(string Value1);"
);
return TestHelper.VerifyGenerator(source);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//HintName: Mapper.g.cs
// <auto-generated />
#nullable enable
public partial class Mapper
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial TTarget Map<TSource, TTarget>(TSource source) where TSource : new() where TTarget : new()
{
return source switch
{
global::A x when typeof(TTarget).IsAssignableFrom(typeof(global::B)) => (TTarget)(object)MapToB(x),
null => throw new System.ArgumentNullException(nameof(source)),
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial global::B MapToB(global::A source)
{
var target = new global::B();
target.Value = source.Value;
return target;
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial global::D MapToD(global::C source)
{
var target = new global::D(source.Value1);
return target;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
public partial class Mapper
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial TTarget Map<TSource, TTarget, TSource2, TTarget2>(TSource source)
private partial TTarget Map<TSource, TTarget, TSource2, TTarget2>(TSource source) where TSource : global::System.Linq.IQueryable<TSource2> where TTarget : global::System.Linq.IQueryable<TTarget2>
{
return source switch
{
Expand Down
Loading