diff --git a/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs index 2291c9a..bee194a 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs @@ -28,6 +28,13 @@ public AsyncBlockExpression( ParameterExpression[] variables, Expression[] expre _initialVariables = variables; _expressions = expressions; + + // var testing = new AwaitVisitor(); + // foreach ( var expr in expressions ) + // { + // testing.Visit( expr ); + // } + // _expressions = testing.Expressions.ToArray(); } public override Type Type diff --git a/src/Hyperbee.AsyncExpressions/AwaitVisitor.cs b/src/Hyperbee.AsyncExpressions/AwaitVisitor.cs new file mode 100644 index 0000000..1c66def --- /dev/null +++ b/src/Hyperbee.AsyncExpressions/AwaitVisitor.cs @@ -0,0 +1,94 @@ +using System.Linq.Expressions; + +namespace Hyperbee.AsyncExpressions; + +public class AwaitVisitor : ExpressionVisitor +{ + private readonly List _expressions = []; + private int _variableCounter; + + public IReadOnlyList Expressions => _expressions; + + protected override Expression VisitBinary( BinaryExpression node ) + { + if ( node.Right is not AwaitExpression awaitExpression) + { + _expressions.Add( node ); + return base.VisitBinary( node ); + } + + // Create a variable to hold the result of the Await expression + var variable = Expression.Variable( awaitExpression.Type, TempVariableName() ); + + // Create a new block with the variable and the await expression + var assignAwait = Expression.Assign( variable, awaitExpression ); + var assignBlock = Expression.Block( [variable], assignAwait, variable ); + var awaitBlock = AsyncExpression.Await( assignBlock, false ); + _expressions.Add( awaitBlock ); + + var newAssignment = Expression.Assign( node.Left, variable ); + _expressions.Add( newAssignment ); + + return base.VisitBinary( node ); + } + + protected override Expression VisitMethodCall( MethodCallExpression node ) + { + var arguments = new List(); + var variables = new List(); + + // Visit each argument in the method call + foreach ( var argument in node.Arguments ) + { + if ( argument is AwaitExpression ) + { + var variable = Expression.Variable( argument.Type, TempVariableName() ); + var assign = Expression.Assign( variable, argument ); + var awaitBlock = AsyncExpression.Awaitable( Expression.Block( [variable], assign ) ); + _expressions.Add( awaitBlock ); + + // Replace the AwaitExpression in the method call with the variable + arguments.Add( variable ); + variables.Add( variable ); + } + else + { + // If not an AwaitExpression, just add the original argument + arguments.Add( Visit( argument ) ); + } + } + + // Rewrite the method call + var updatedCall = node.Update( Visit( node.Object ), arguments ); + + // Create a new block that represents the rewritten method call + if ( variables.Count > 0 ) + { + _expressions.Add( Expression.Block( variables, updatedCall ) ); + } + + return base.VisitMethodCall( node ); + } + + + protected override Expression VisitConditional( ConditionalExpression node ) + { + if ( node.Test is not AwaitExpression ) + return base.VisitConditional( node ); + + // Create a variable to hold the result of the Await expression + var variable = Expression.Variable( node.Test.Type, TempVariableName() ); + + // Create a new block with the variable and the await expression + var assignAwait = Expression.Assign( variable, node.Test ); + var awaitBlock = AsyncExpression.Awaitable( Expression.Block( [variable], assignAwait ) ); + _expressions.Add( awaitBlock ); + + var updateConditional = node.Update( assignAwait, node.IfTrue, node.IfFalse ); + _expressions.Add( updateConditional ); + + return base.VisitConditional( node ); + } + + private string TempVariableName() => $"__var{_variableCounter++}"; +} diff --git a/src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs b/src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs index b1885e3..0b026ac 100644 --- a/src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs +++ b/src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs @@ -1,4 +1,4 @@ -using System.Linq.Expressions; +using System.Linq.Expressions; using System.Reflection; using System.Reflection.Emit; diff --git a/test/Hyperbee.AsyncExpressions.Tests/AwaitVisitorTests.cs b/test/Hyperbee.AsyncExpressions.Tests/AwaitVisitorTests.cs new file mode 100644 index 0000000..183f868 --- /dev/null +++ b/test/Hyperbee.AsyncExpressions.Tests/AwaitVisitorTests.cs @@ -0,0 +1,58 @@ +using System.Reflection; +using static System.Linq.Expressions.Expression; + +namespace Hyperbee.AsyncExpressions.Tests; + +[TestClass] +public class AwaitVisitorTests +{ + static int Test( int a, int b ) => a + b; + + [TestMethod] + public void ShouldFindAwait_WhenUsingCall() + { + // Arrange + var callExpr = Call( + typeof(AwaitVisitorTests).GetMethod( nameof(Test), BindingFlags.Static | BindingFlags.NonPublic )!, + AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ), + AsyncExpression.Await( Constant( Task.FromResult( 2 ) ), false ) ); + + // Act + var visitor = new AwaitVisitor(); + visitor.Visit( callExpr ); + + // Assert + Assert.AreEqual( 3, visitor.Expressions.Count ); + } + + [TestMethod] + public void ShouldFindAwait_WhenUsingAssign() + { + // Arrange + var varExpr = Variable( typeof(int), "x" ); + var assignExpr = Assign( varExpr, + AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ) ); + + // Act + var visitor = new AwaitVisitor(); + visitor.Visit( assignExpr ); + + // Assert + Assert.AreEqual( 2, visitor.Expressions.Count ); + } + + [TestMethod] + public void ShouldFindAwait_WhenUsingConditions() + { + // Arrange + var assignExpr = IfThen( AsyncExpression.Await( Constant( Task.FromResult( true ) ), false ), + AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ) ); + + // Act + var visitor = new AwaitVisitor(); + visitor.Visit( assignExpr ); + + // Assert + Assert.AreEqual( 2, visitor.Expressions.Count ); + } +}