From b910c1811bb76fd5c8b50da2539590ce373c85ca Mon Sep 17 00:00:00 2001 From: Timothy Makkison Date: Mon, 10 Jul 2023 17:17:09 +0100 Subject: [PATCH] fix: add `GetAttributesCore`, modify `GetAttributes` and `HasAttribute` --- .../Configuration/AttributeDataAccessor.cs | 5 +-- .../Descriptors/SymbolAccessor.cs | 36 ++++++++++--------- .../Descriptors/WellKnownTypes.cs | 4 +-- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs b/src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs index b59f961fa3..f6d0080a83 100644 --- a/src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs +++ b/src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs @@ -44,11 +44,8 @@ public IEnumerable Access(ISymbol symbol) { var attrType = typeof(TAttribute); var dataType = typeof(TData); - var attrSymbol = _types.Get($"{attrType.Namespace}.{attrType.Name}"); - var attrDatas = _symbolAccessor - .GetAttributes(symbol) - .Where(x => SymbolEqualityComparer.Default.Equals(x.AttributeClass?.ConstructedFrom ?? x.AttributeClass, attrSymbol)); + var attrDatas = _symbolAccessor.GetAttributes(symbol); foreach (var attrData in attrDatas) { diff --git a/src/Riok.Mapperly/Descriptors/SymbolAccessor.cs b/src/Riok.Mapperly/Descriptors/SymbolAccessor.cs index 33a316e634..858c4503c4 100644 --- a/src/Riok.Mapperly/Descriptors/SymbolAccessor.cs +++ b/src/Riok.Mapperly/Descriptors/SymbolAccessor.cs @@ -18,26 +18,17 @@ public SymbolAccessor(WellKnownTypes types) _types = types; } - internal ImmutableArray GetAttributes(ISymbol symbol) - { - if (_attributes.TryGetValue(symbol, out var attributes)) - { - return attributes; - } - - attributes = symbol.GetAttributes(); - _attributes.Add(symbol, attributes); - - return attributes; - } - - internal bool HasAttribute(ISymbol symbol) + internal IEnumerable GetAttributes(ISymbol symbol) where T : Attribute { var attributeSymbol = _types.Get(typeof(T)); - return GetAttributes(symbol).Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeSymbol)); + return GetAttributesCore(symbol) + .Where(x => SymbolEqualityComparer.Default.Equals(x.AttributeClass?.ConstructedFrom ?? x.AttributeClass, attributeSymbol)); } + internal bool HasAttribute(ISymbol symbol) + where T : Attribute => GetAttributes(symbol).Any(); + internal IEnumerable GetAllMethods(ITypeSymbol symbol) => GetAllMembers(symbol).OfType(); internal IEnumerable GetAllMethods(ITypeSymbol symbol, string name) => @@ -81,6 +72,19 @@ internal IEnumerable GetMappableMembers(ITypeSymbol symbol, str private IEnumerable GetAllMembers(ITypeSymbol symbol, string name) => GetAllMembers(symbol).Where(x => name.Equals(x.Name)); + private ImmutableArray GetAttributesCore(ISymbol symbol) + { + if (_attributes.TryGetValue(symbol, out var attributes)) + { + return attributes; + } + + attributes = symbol.GetAttributes(); + _attributes.Add(symbol, attributes); + + return attributes; + } + private IEnumerable GetAllMembersCore(ITypeSymbol symbol) { var members = symbol.GetMembers(); @@ -97,7 +101,7 @@ private IEnumerable GetAllMembersCore(ITypeSymbol symbol) private IEnumerable GetAllAccessibleMappableMembersCore(ITypeSymbol symbol) { return GetAllMembers(symbol) - .Where(x => !x.IsStatic && x.IsAccessible() && x.Kind is SymbolKind.Property or SymbolKind.Field) + .Where(x => x is { IsStatic: false, Kind: SymbolKind.Property or SymbolKind.Field } && x.IsAccessible()) .DistinctBy(x => x.Name) .Select(MappableMember.Create) .WhereNotNull(); diff --git a/src/Riok.Mapperly/Descriptors/WellKnownTypes.cs b/src/Riok.Mapperly/Descriptors/WellKnownTypes.cs index 8ddf15f8b3..0df2d8bae6 100644 --- a/src/Riok.Mapperly/Descriptors/WellKnownTypes.cs +++ b/src/Riok.Mapperly/Descriptors/WellKnownTypes.cs @@ -21,10 +21,10 @@ internal WellKnownTypes(Compilation compilation) public ITypeSymbol GetArrayType(ITypeSymbol type) => _compilation.CreateArrayTypeSymbol(type, elementNullableAnnotation: type.NullableAnnotation).NonNullable(); - public INamedTypeSymbol Get() => Get(typeof(T).FullName); + public INamedTypeSymbol Get() => Get(typeof(T)); public INamedTypeSymbol Get(Type type) => - Get(type.FullName ?? throw new InvalidOperationException("Could not get name of type " + type)); + Get($"{type.Namespace}.{type.Name}" ?? throw new InvalidOperationException("Could not get name of type " + type)); public INamedTypeSymbol Get(string typeFullName) => TryGet(typeFullName) ?? throw new InvalidOperationException("Could not get type " + typeFullName);