Skip to content

Commit

Permalink
Enhanced LINQ query handling in Visitor class
Browse files Browse the repository at this point in the history
- Added and adjusted using directives for expanded functionality.
- Implemented support for `Queryable` methods like `Where` and `Select` in `ContentItemQueryExpressionVisitor`.
- Introduced `ExtractValues` for better value extraction from collections.
- Improved `GetMemberValue` for null safety and robust value retrieval.
- Added `AddWhereInCondition` for simplified `WhereIn` condition addition.
- New methods for collection and field value extraction enhance complex query support.
- Implemented `GetExpressionValue` for constant and member expression value retrieval.
- Enhanced method call processing for `Contains`, `StartsWith`, etc., with better null safety.
- General code improvements for null safety, additional LINQ method support, and better error handling.
  • Loading branch information
bluemodus-brandon committed Jul 12, 2024
1 parent 19380e7 commit e07cd81
Showing 1 changed file with 226 additions and 24 deletions.
250 changes: 226 additions & 24 deletions src/XperienceCommunity.DataContext/ContentItemQueryExpressionVisitor.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Linq.Expressions;
using System.Collections;
using System.Linq.Expressions;
using System.Reflection;
using CMS.ContentEngine;
using XperienceCommunity.DataContext.Extensions;
Expand Down Expand Up @@ -110,6 +111,24 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
throw new NotSupportedException($"The method '{node.Method.Name}' is not supported.");
}
}
else if (node.Method.DeclaringType == typeof(Queryable))
{
switch (node.Method.Name)
{
case nameof(Queryable.Where):
return ProcessQueryableWhere(node);

case nameof(Queryable.Select):
return ProcessQueryableSelect(node);
// Add other Queryable methods as needed
default:
throw new NotSupportedException($"The method call '{node.Method.Name}' is not supported.");
}
}
else if (node.Method.Name == nameof(Enumerable.Contains))
{
ProcessEnumerableContains(node);
}
else
{
throw new NotSupportedException($"The method '{node.Method.Name}' is not supported.");
Expand All @@ -132,7 +151,72 @@ protected override Expression VisitUnary(UnaryExpression node)
return node;
}

private static object? GetMemberValue(Expression expression)
private static IEnumerable<object> ExtractValues(object? value)
{
if (value is IEnumerable<object> objectEnumerable)
{
return objectEnumerable;
}

if (value is IEnumerable<int> intEnumerable)
{
return intEnumerable.Cast<object>();
}

if (value is IEnumerable<string> stringEnumerable)
{
return stringEnumerable.Cast<object>();
}

if (value is IEnumerable<Guid> guidEnumerable)
{
return guidEnumerable.Cast<object>();
}

if (value is IEnumerable enumerable)
{
var list = new List<object>();

foreach (var item in enumerable)
{
var itemValues = ExtractValues(item);
list.AddRange(itemValues);
}

return list;
}

if (value is null)
{
return [];
}

// Check if the object has a property that is a collection
var properties = value.GetType().GetProperties();

var collectionProperty = properties.FirstOrDefault(p =>
p.PropertyType.IsGenericType && p.PropertyType.GetGenericTypeDefinition() == typeof(IEnumerable<>));

if (collectionProperty != null)
{
var collectionValue = collectionProperty.GetValue(value);
return ExtractValues(collectionValue);
}

return new[] { value };
}

private static string? GetMemberNameFromMethodCall(MethodCallExpression methodCall)
{
if (methodCall.Object is MemberExpression memberExpression)
{
return memberExpression.Member.Name;
}

return null;
}

private static object? GetMemberValue(Expression? expression)
{
switch (expression)
{
Expand All @@ -141,15 +225,18 @@ protected override Expression VisitUnary(UnaryExpression node)

case MemberExpression memberExpression:
var member = memberExpression.Member;

var objectValue =
GetMemberValue(memberExpression.Expression!); // Recursively process the expression
GetMemberValue(memberExpression?.Expression); // Recursively process the expression

if (objectValue == null)
{
throw new InvalidOperationException("The target object for the member expression is null.");
}

return member switch
{
System.Reflection.FieldInfo fieldInfo => fieldInfo.GetValue(objectValue),
FieldInfo fieldInfo => fieldInfo.GetValue(objectValue),
PropertyInfo propertyInfo => propertyInfo.GetValue(objectValue),
_ => throw new NotSupportedException(
$"The member type '{member.GetType().Name}' is not supported.")
Expand All @@ -164,18 +251,97 @@ protected override Expression VisitUnary(UnaryExpression node)

default:
throw new NotSupportedException(
$"The expression type '{expression.GetType().Name}' is not supported.");
$"The expression type '{expression?.GetType().Name}' is not supported.");
}
}

private static object? GetMethodCallValue(MethodCallExpression methodCall)
{
// Evaluate the method call expression to get the resulting value
var lambda = Expression.Lambda(methodCall).Compile();

return lambda.DynamicInvoke();
}

private void AddWhereInCondition(string columnName, IEnumerable<object>? values)
{
if (values == null)
{
return;
}

if (!values.Any())
{
// Pass an empty array to WhereIn
_queryParameters.Where(where => where.WhereIn(columnName, Array.Empty<string>()));
return;
}

var firstValue = values.First();

if (firstValue is int)
{
_queryParameters.Where(where => where.WhereIn(columnName, values.Cast<int>().ToArray()));
}
else if (firstValue is string)
{
_queryParameters.Where(where => where.WhereIn(columnName, values.Cast<string>().ToArray()));
}
else if (firstValue is Guid)
{
_queryParameters.Where(where => where.WhereIn(columnName, values.Cast<Guid>().ToArray()));
}
else
{
return;
}
}

private IEnumerable<object>? ExtractCollectionValues(MemberExpression collectionExpression)
{
if (collectionExpression.Expression != null)
{
var value = GetExpressionValue(collectionExpression.Expression);

return ExtractValues(value);
}

return null;
}

private IEnumerable<object> ExtractFieldValues(MemberExpression fieldExpression)
{
var value = GetExpressionValue(fieldExpression);
return ExtractValues(value);
}

private object? GetExpressionValue(Expression expression)
{
switch (expression)
{
case ConstantExpression constantExpression:
return constantExpression.Value!;

case MemberExpression memberExpression:
var container = GetExpressionValue(memberExpression.Expression!);
var member = memberExpression.Member;
switch (member)
{
case FieldInfo fieldInfo:
return fieldInfo.GetValue(container);

case PropertyInfo propertyInfo:
return propertyInfo.GetValue(container);

default:
throw new NotSupportedException(
$"The member type '{member.GetType().Name}' is not supported.");
}
default:
throw new NotSupportedException(
$"The expression type '{expression.GetType().Name}' is not supported.");
}
}

private void ProcessComparison(BinaryExpression node, bool isGreaterThan, bool isEqual = false)
{
if (node.Left is MemberExpression left)
Expand Down Expand Up @@ -294,7 +460,7 @@ private void ProcessComparison(BinaryExpression node, bool isGreaterThan, bool i
}
else if (node.Left is MethodCallExpression leftMethod && node.Right is ConstantExpression rightConst)
{
var memberName = leftMethod.GetMemberNameFromMethodCall();
var memberName = GetMemberNameFromMethodCall(leftMethod);

if (memberName != null)
{
Expand Down Expand Up @@ -336,24 +502,28 @@ private void ProcessComparison(BinaryExpression node, bool isGreaterThan, bool i

private void ProcessEnumerableContains(MethodCallExpression node)
{
if (node.Arguments[0] is MemberExpression memberExpression &&
if (node.Arguments.Count == 2 &&
node.Arguments[0] is MemberExpression memberExpression &&
node.Arguments[1] is ConstantExpression listExpression)
{
var columnName = memberExpression.Member.Name;
var values = (IEnumerable<object>)listExpression.Value!;
var values = ExtractValues(listExpression.Value);

if (listExpression.Type.GenericTypeArguments[0] == typeof(int))
{
_queryParameters.Where(where => where.WhereIn(columnName, values.Cast<int>().ToArray()));
}
else if (listExpression.Type.GenericTypeArguments[0] == typeof(string))
{
_queryParameters.Where(where => where.WhereIn(columnName, values.Cast<string>().ToArray()));
}
else if (listExpression.Type.GenericTypeArguments[0] == typeof(Guid))
{
_queryParameters.Where(where => where.WhereIn(columnName, values.Cast<Guid>().ToArray()));
}
AddWhereInCondition(columnName, values);
}
else if (node.Arguments.Count == 2 &&
node.Arguments[0] is MemberExpression collectionExpression &&
node.Arguments[1] is MemberExpression itemExpression)
{
var collection = ExtractFieldValues(collectionExpression);
var columnName = itemExpression.Member.Name;

AddWhereInCondition(columnName, collection);
}
else
{
throw new NotSupportedException(
$"The expression types '{node.Arguments[0]?.GetType().Name}' and '{node.Arguments[1]?.GetType().Name}' are not supported.");
}
}

Expand Down Expand Up @@ -396,7 +566,7 @@ private void ProcessEquality(BinaryExpression node)
throw new NotSupportedException(
$"The left expression type '{node.Left.GetType().Name}' is not supported.");
}
}
}

private void ProcessLogicalAnd(BinaryExpression node)
{
Expand Down Expand Up @@ -454,19 +624,51 @@ private void ProcessNotEquality(BinaryExpression node)
}
}

private Expression ProcessQueryableSelect(MethodCallExpression node)
{
if (node.Arguments[1] is UnaryExpression unaryExpression &&
unaryExpression.Operand is LambdaExpression lambdaExpression)
{
Visit(lambdaExpression.Body);
}
else
{
throw new NotSupportedException(
$"The expression type '{node.Arguments[1].GetType().Name}' is not supported.");
}

return node;
}

private Expression ProcessQueryableWhere(MethodCallExpression node)
{
if (node.Arguments[1] is UnaryExpression unaryExpression &&
unaryExpression.Operand is LambdaExpression lambdaExpression)
{
Visit(lambdaExpression.Body);
}
else
{
throw new NotSupportedException(
$"The expression type '{node.Arguments[1].GetType().Name}' is not supported.");
}

return node;
}

private void ProcessStringContains(MethodCallExpression node)
{
if (node.Object is MemberExpression member && node.Arguments[0] is ConstantExpression constant)
{
_queryParameters.Where(where => where.WhereContains(member.Member.Name, constant.Value?.ToString()));
_queryParameters.Where(where => where.WhereContains(member.Member.Name, constant?.Value?.ToString()));
}
}

private void ProcessStringStartsWith(MethodCallExpression node)
{
if (node.Object is MemberExpression member && node.Arguments[0] is ConstantExpression constant)
{
_queryParameters.Where(where => where.WhereStartsWith(member.Member.Name, constant.Value?.ToString()));
_queryParameters.Where(where => where.WhereStartsWith(member.Member.Name, constant?.Value?.ToString()));
}
}
}
Expand Down

0 comments on commit e07cd81

Please sign in to comment.