Skip to content

Commit

Permalink
Add missing generic constraints to generated methods (#1216)
Browse files Browse the repository at this point in the history
  • Loading branch information
trejjam authored Apr 8, 2024
1 parent 1e78844 commit d6cfd93
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Emit;
using Riok.Mapperly.Emit.Syntax;
using Riok.Mapperly.Helpers;
using Riok.Mapperly.Symbols;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
Expand Down Expand Up @@ -35,7 +37,9 @@ ITypeSymbol objectType
public override MethodDeclarationSyntax BuildMethod(SourceEmitterContext ctx)
{
var methodSyntax = (MethodDeclarationSyntax)Method.DeclaringSyntaxReferences.First().GetSyntax();
return base.BuildMethod(ctx).WithTypeParameterList(methodSyntax.TypeParameterList);
return base.BuildMethod(ctx)
.WithTypeParameterList(methodSyntax.TypeParameterList)
.WithConstraintClauses(List(GetTypeParameterConstraintClauses()));
}

protected override ExpressionSyntax BuildTargetType()
Expand All @@ -44,6 +48,52 @@ protected override ExpressionSyntax BuildTargetType()
return TypeOfExpression(FullyQualifiedIdentifier(Method.ReturnType.NonNullable()));
}

protected virtual IEnumerable<TypeParameterConstraintClauseSyntax> GetTypeParameterConstraintClauses()
{
foreach (var tp in Method.TypeParameters)
{
var constraints = new List<TypeParameterConstraintSyntax>();

if (tp.HasUnmanagedTypeConstraint)
{
constraints.Add(TypeConstraint(IdentifierName("unmanaged")).AddLeadingSpace());
}
else if (tp.HasValueTypeConstraint)
{
constraints.Add(ClassOrStructConstraint(SyntaxKind.StructConstraint).AddLeadingSpace());
}
else if (tp.HasNotNullConstraint)
{
constraints.Add(TypeConstraint(IdentifierName("notnull")).AddLeadingSpace());
}
else if (tp.HasReferenceTypeConstraint)
{
constraints.Add(ClassOrStructConstraint(SyntaxKind.ClassConstraint).AddLeadingSpace());
}

foreach (var c in tp.ConstraintTypes)
{
constraints.Add(TypeConstraint(FullyQualifiedIdentifier(c)).AddLeadingSpace());
}

if (tp.HasConstructorConstraint)
{
constraints.Add(ConstructorConstraint().AddLeadingSpace());
}

if (!constraints.Any())
{
continue;
}

yield return TypeParameterConstraintClause(
IdentifierName(tp.Name).AddLeadingSpace().AddTrailingSpace(),
SeparatedList(constraints)
)
.AddLeadingSpace();
}
}

protected override ExpressionSyntax? BuildSwitchArmWhenClause(ExpressionSyntax targetType, RuntimeTargetTypeMapping mapping)
{
return mapping.IsAssignableToMethodTargetType ? null : base.BuildSwitchArmWhenClause(targetType, mapping);
Expand Down
3 changes: 3 additions & 0 deletions src/Riok.Mapperly/Emit/Syntax/SyntaxIndentationExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ public static TSyntax AddTrailingLineFeed<TSyntax>(this TSyntax syntax, int inde
return syntax.WithTrailingTrivia(trivia);
}

public static TSyntax AddLeadingSpace<TSyntax>(this TSyntax syntax)
where TSyntax : SyntaxNode => syntax.WithLeadingTrivia(syntax.GetLeadingTrivia().Add(ElasticSpace));

public static TSyntax AddTrailingSpace<TSyntax>(this TSyntax syntax)
where TSyntax : SyntaxNode => syntax.WithTrailingTrivia(syntax.GetTrailingTrivia().Add(ElasticSpace));

Expand Down
12 changes: 12 additions & 0 deletions test/Riok.Mapperly.Tests/GeneratedMethod.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Riok.Mapperly.Tests;
Expand All @@ -8,15 +9,26 @@ public GeneratedMethod(MethodDeclarationSyntax declarationSyntax)
{
Name = declarationSyntax.Identifier.ToString();
Signature = $"{declarationSyntax.ReturnType.ToString()} {Name}{declarationSyntax.ParameterList.ToString().Trim()}";
ConstraintClauses = ExtractParameterConstraints(declarationSyntax.ConstraintClauses);
Body = ExtractBody(declarationSyntax);
}

public string Name { get; }

public string Signature { get; }

public string? ConstraintClauses { get; }

public string Body { get; }

private static string? ExtractParameterConstraints(SyntaxList<TypeParameterConstraintClauseSyntax> typeParameterConstraints)
{
if (typeParameterConstraints.Count == 0)
return null;

return typeParameterConstraints.ToFullString().Trim(' ', '\r', '\n').ReplaceLineEndings();
}

/// <summary>
/// Builds the method body without the method body braces and without the method body level indentation.
/// </summary>
Expand Down
13 changes: 13 additions & 0 deletions test/Riok.Mapperly.Tests/MapperGenerationResultAssertions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,19 @@ public MapperGenerationResultAssertions HaveMethodBody(string methodName, [Strin
public MapperGenerationResultAssertions HaveMapMethodBody([StringSyntax(StringSyntax.CSharp)] string mapperMethodBody) =>
HaveMethodBody(TestSourceBuilder.DefaultMapMethodName, mapperMethodBody);

public MapperGenerationResultAssertions HaveMapMethodWithGenericConstraints(
string methodName,
[StringSyntax(StringSyntax.CSharp)] string? constraintClauses
)
{
_mapper.Methods[methodName].ConstraintClauses.Should().Be(constraintClauses);
return this;
}

public MapperGenerationResultAssertions HaveMapMethodWithGenericConstraints(
[StringSyntax(StringSyntax.CSharp)] string? constraintClauses
) => HaveMapMethodWithGenericConstraints(TestSourceBuilder.DefaultMapMethodName, constraintClauses);

private IReadOnlyCollection<Diagnostic> GetDiagnostics(DiagnosticDescriptor descriptor)
{
if (_mapper.DiagnosticsByDescriptorId.TryGetValue(descriptor.Id, out var diagnostics))
Expand Down
73 changes: 54 additions & 19 deletions test/Riok.Mapperly.Tests/Mapping/GenericTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ public void WithGenericSource()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
Expand All @@ -135,11 +136,14 @@ partial object Map<TSource>(TSource source)

partial B MapToB(A source);
partial D MapToD(C source);
partial F MapToF(E source);
""",
"record struct A(string Value);",
"record A(string Value);",
"record struct B(string Value);",
"record C(string Value1);",
"record D(string Value1);"
"record D(string Value1);",
"record E(string Value) : A(Value);",
"record struct F(string Value) : B(Value);"
);
TestHelper
.GenerateMapper(source)
Expand All @@ -148,13 +152,13 @@ partial object Map<TSource>(TSource source)
"""
return source switch
{
global::E x => MapToF(x),
global::A x => MapToB(x),
global::C x => MapToD(x),
null => throw new System.ArgumentNullException(nameof(source)),
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : global::A");
}

[Fact]
Expand Down Expand Up @@ -185,7 +189,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : notnull");
}

[Fact]
Expand Down Expand Up @@ -215,7 +220,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : struct");
}

[Fact]
Expand Down Expand Up @@ -245,7 +251,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : unmanaged");
}

[Fact]
Expand Down Expand Up @@ -275,7 +282,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : class");
}

[Fact]
Expand Down Expand Up @@ -306,7 +314,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : class");
}

[Fact]
Expand Down Expand Up @@ -337,7 +346,8 @@ partial object Map<TSource>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(object)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : class");
}

[Fact]
Expand Down Expand Up @@ -369,7 +379,8 @@ partial TTarget Map<TSource, TTarget>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : class where TTarget : class");
}

[Fact]
Expand Down Expand Up @@ -404,7 +415,8 @@ public void WithGenericSourceSpecificTarget()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(global::BaseDto)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
Expand Down Expand Up @@ -434,7 +446,8 @@ public void WithGenericTarget()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
Expand Down Expand Up @@ -464,7 +477,8 @@ partial TTarget Map<TTarget>(object source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TTarget : global::D");
}

[Fact]
Expand Down Expand Up @@ -498,7 +512,8 @@ public void WithGenericTargetSpecificSource()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
Expand Down Expand Up @@ -529,7 +544,8 @@ partial TTarget Map<TSource, TTarget>(TSource source)
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints("where TSource : global::C where TTarget : global::D");
}

[Fact]
Expand Down Expand Up @@ -562,7 +578,26 @@ public void WithUserImplementedMethodsShouldBeIncluded()
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
"""
);
)
.HaveMapMethodWithGenericConstraints(null);
}

[Fact]
public Task WithGenericConstructorConstraint()
{
var source = TestSourceBuilder.MapperWithBodyAndTypes(
"""
private partial TTarget Map<TSource, TTarget>(TSource source) where TSource : new() where TTarget : new();

private partial B MapToB(A source);
private partial D MapToD(C source);
""",
"record struct A(string Value) { public A() : this(default!) {} }",
"record struct B(string Value) { public B() : this(default!) {} }",
"record C(string Value1);",
"record D(string Value1);"
);
return TestHelper.VerifyGenerator(source);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//HintName: Mapper.g.cs
// <auto-generated />
#nullable enable
public partial class Mapper
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial TTarget Map<TSource, TTarget>(TSource source) where TSource : new() where TTarget : new()
{
return source switch
{
global::A x when typeof(TTarget).IsAssignableFrom(typeof(global::B)) => (TTarget)(object)MapToB(x),
null => throw new System.ArgumentNullException(nameof(source)),
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to {typeof(TTarget)} as there is no known type mapping", nameof(source)),
};
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial global::B MapToB(global::A source)
{
var target = new global::B();
target.Value = source.Value;
return target;
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial global::D MapToD(global::C source)
{
var target = new global::D(source.Value1);
return target;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
public partial class Mapper
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
private partial TTarget Map<TSource, TTarget, TSource2, TTarget2>(TSource source)
private partial TTarget Map<TSource, TTarget, TSource2, TTarget2>(TSource source) where TSource : global::System.Linq.IQueryable<TSource2> where TTarget : global::System.Linq.IQueryable<TTarget2>
{
return source switch
{
Expand Down

0 comments on commit d6cfd93

Please sign in to comment.