Skip to content

Commit

Permalink
Playing with an await visitor to rewrite code that contains nested aw…
Browse files Browse the repository at this point in the history
…aits into multiple blocks
  • Loading branch information
MattEdwardsWaggleBee committed Sep 6, 2024
1 parent 0baf252 commit 7d29926
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 1 deletion.
7 changes: 7 additions & 0 deletions src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ public AsyncBlockExpression( ParameterExpression[] variables, Expression[] expre

_initialVariables = variables;
_expressions = expressions;

// var testing = new AwaitVisitor();
// foreach ( var expr in expressions )
// {
// testing.Visit( expr );
// }
// _expressions = testing.Expressions.ToArray();
}

public override Type Type
Expand Down
94 changes: 94 additions & 0 deletions src/Hyperbee.AsyncExpressions/AwaitVisitor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using System.Linq.Expressions;

namespace Hyperbee.AsyncExpressions;

public class AwaitVisitor : ExpressionVisitor
{
private readonly List<Expression> _expressions = [];
private int _variableCounter;

public IReadOnlyList<Expression> Expressions => _expressions;

protected override Expression VisitBinary( BinaryExpression node )
{
if ( node.Right is not AwaitExpression awaitExpression)
{
_expressions.Add( node );
return base.VisitBinary( node );
}

// Create a variable to hold the result of the Await expression
var variable = Expression.Variable( awaitExpression.Type, TempVariableName() );

// Create a new block with the variable and the await expression
var assignAwait = Expression.Assign( variable, awaitExpression );
var assignBlock = Expression.Block( [variable], assignAwait, variable );
var awaitBlock = AsyncExpression.Await( assignBlock, false );
_expressions.Add( awaitBlock );

var newAssignment = Expression.Assign( node.Left, variable );
_expressions.Add( newAssignment );

return base.VisitBinary( node );
}

protected override Expression VisitMethodCall( MethodCallExpression node )
{
var arguments = new List<Expression>();
var variables = new List<ParameterExpression>();

// Visit each argument in the method call
foreach ( var argument in node.Arguments )
{
if ( argument is AwaitExpression )
{
var variable = Expression.Variable( argument.Type, TempVariableName() );
var assign = Expression.Assign( variable, argument );
var awaitBlock = AsyncExpression.Awaitable( Expression.Block( [variable], assign ) );
_expressions.Add( awaitBlock );

// Replace the AwaitExpression in the method call with the variable
arguments.Add( variable );
variables.Add( variable );
}
else
{
// If not an AwaitExpression, just add the original argument
arguments.Add( Visit( argument ) );
}
}

// Rewrite the method call
var updatedCall = node.Update( Visit( node.Object ), arguments );

// Create a new block that represents the rewritten method call
if ( variables.Count > 0 )
{
_expressions.Add( Expression.Block( variables, updatedCall ) );
}

return base.VisitMethodCall( node );
}


protected override Expression VisitConditional( ConditionalExpression node )
{
if ( node.Test is not AwaitExpression )
return base.VisitConditional( node );

// Create a variable to hold the result of the Await expression
var variable = Expression.Variable( node.Test.Type, TempVariableName() );

// Create a new block with the variable and the await expression
var assignAwait = Expression.Assign( variable, node.Test );
var awaitBlock = AsyncExpression.Awaitable( Expression.Block( [variable], assignAwait ) );
_expressions.Add( awaitBlock );

var updateConditional = node.Update( assignAwait, node.IfTrue, node.IfFalse );
_expressions.Add( updateConditional );

return base.VisitConditional( node );
}

private string TempVariableName() => $"__var{_variableCounter++}";
}
2 changes: 1 addition & 1 deletion src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Linq.Expressions;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;

Expand Down
58 changes: 58 additions & 0 deletions test/Hyperbee.AsyncExpressions.Tests/AwaitVisitorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using System.Reflection;
using static System.Linq.Expressions.Expression;

namespace Hyperbee.AsyncExpressions.Tests;

[TestClass]
public class AwaitVisitorTests
{
static int Test( int a, int b ) => a + b;

[TestMethod]
public void ShouldFindAwait_WhenUsingCall()
{
// Arrange
var callExpr = Call(
typeof(AwaitVisitorTests).GetMethod( nameof(Test), BindingFlags.Static | BindingFlags.NonPublic )!,
AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ),
AsyncExpression.Await( Constant( Task.FromResult( 2 ) ), false ) );

// Act
var visitor = new AwaitVisitor();
visitor.Visit( callExpr );

// Assert
Assert.AreEqual( 3, visitor.Expressions.Count );
}

[TestMethod]
public void ShouldFindAwait_WhenUsingAssign()
{
// Arrange
var varExpr = Variable( typeof(int), "x" );
var assignExpr = Assign( varExpr,
AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ) );

// Act
var visitor = new AwaitVisitor();
visitor.Visit( assignExpr );

// Assert
Assert.AreEqual( 2, visitor.Expressions.Count );
}

[TestMethod]
public void ShouldFindAwait_WhenUsingConditions()
{
// Arrange
var assignExpr = IfThen( AsyncExpression.Await( Constant( Task.FromResult( true ) ), false ),
AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ) );

// Act
var visitor = new AwaitVisitor();
visitor.Visit( assignExpr );

// Assert
Assert.AreEqual( 2, visitor.Expressions.Count );
}
}

0 comments on commit 7d29926

Please sign in to comment.