-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Playing with an await visitor to rewrite code that contains nested aw…
…aits into multiple blocks
- Loading branch information
1 parent
0baf252
commit 7d29926
Showing
4 changed files
with
160 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
using System.Linq.Expressions; | ||
|
||
namespace Hyperbee.AsyncExpressions; | ||
|
||
public class AwaitVisitor : ExpressionVisitor | ||
{ | ||
private readonly List<Expression> _expressions = []; | ||
private int _variableCounter; | ||
|
||
public IReadOnlyList<Expression> Expressions => _expressions; | ||
|
||
protected override Expression VisitBinary( BinaryExpression node ) | ||
{ | ||
if ( node.Right is not AwaitExpression awaitExpression) | ||
{ | ||
_expressions.Add( node ); | ||
return base.VisitBinary( node ); | ||
} | ||
|
||
// Create a variable to hold the result of the Await expression | ||
var variable = Expression.Variable( awaitExpression.Type, TempVariableName() ); | ||
|
||
// Create a new block with the variable and the await expression | ||
var assignAwait = Expression.Assign( variable, awaitExpression ); | ||
var assignBlock = Expression.Block( [variable], assignAwait, variable ); | ||
var awaitBlock = AsyncExpression.Await( assignBlock, false ); | ||
_expressions.Add( awaitBlock ); | ||
|
||
var newAssignment = Expression.Assign( node.Left, variable ); | ||
_expressions.Add( newAssignment ); | ||
|
||
return base.VisitBinary( node ); | ||
} | ||
|
||
protected override Expression VisitMethodCall( MethodCallExpression node ) | ||
{ | ||
var arguments = new List<Expression>(); | ||
var variables = new List<ParameterExpression>(); | ||
|
||
// Visit each argument in the method call | ||
foreach ( var argument in node.Arguments ) | ||
{ | ||
if ( argument is AwaitExpression ) | ||
{ | ||
var variable = Expression.Variable( argument.Type, TempVariableName() ); | ||
var assign = Expression.Assign( variable, argument ); | ||
var awaitBlock = AsyncExpression.Awaitable( Expression.Block( [variable], assign ) ); | ||
_expressions.Add( awaitBlock ); | ||
|
||
// Replace the AwaitExpression in the method call with the variable | ||
arguments.Add( variable ); | ||
variables.Add( variable ); | ||
} | ||
else | ||
{ | ||
// If not an AwaitExpression, just add the original argument | ||
arguments.Add( Visit( argument ) ); | ||
} | ||
} | ||
|
||
// Rewrite the method call | ||
var updatedCall = node.Update( Visit( node.Object ), arguments ); | ||
|
||
// Create a new block that represents the rewritten method call | ||
if ( variables.Count > 0 ) | ||
{ | ||
_expressions.Add( Expression.Block( variables, updatedCall ) ); | ||
} | ||
|
||
return base.VisitMethodCall( node ); | ||
} | ||
|
||
|
||
protected override Expression VisitConditional( ConditionalExpression node ) | ||
{ | ||
if ( node.Test is not AwaitExpression ) | ||
return base.VisitConditional( node ); | ||
|
||
// Create a variable to hold the result of the Await expression | ||
var variable = Expression.Variable( node.Test.Type, TempVariableName() ); | ||
|
||
// Create a new block with the variable and the await expression | ||
var assignAwait = Expression.Assign( variable, node.Test ); | ||
var awaitBlock = AsyncExpression.Awaitable( Expression.Block( [variable], assignAwait ) ); | ||
_expressions.Add( awaitBlock ); | ||
|
||
var updateConditional = node.Update( assignAwait, node.IfTrue, node.IfFalse ); | ||
_expressions.Add( updateConditional ); | ||
|
||
return base.VisitConditional( node ); | ||
} | ||
|
||
private string TempVariableName() => $"__var{_variableCounter++}"; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
using System.Reflection; | ||
using static System.Linq.Expressions.Expression; | ||
|
||
namespace Hyperbee.AsyncExpressions.Tests; | ||
|
||
[TestClass] | ||
public class AwaitVisitorTests | ||
{ | ||
static int Test( int a, int b ) => a + b; | ||
|
||
[TestMethod] | ||
public void ShouldFindAwait_WhenUsingCall() | ||
{ | ||
// Arrange | ||
var callExpr = Call( | ||
typeof(AwaitVisitorTests).GetMethod( nameof(Test), BindingFlags.Static | BindingFlags.NonPublic )!, | ||
AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ), | ||
AsyncExpression.Await( Constant( Task.FromResult( 2 ) ), false ) ); | ||
|
||
// Act | ||
var visitor = new AwaitVisitor(); | ||
visitor.Visit( callExpr ); | ||
|
||
// Assert | ||
Assert.AreEqual( 3, visitor.Expressions.Count ); | ||
} | ||
|
||
[TestMethod] | ||
public void ShouldFindAwait_WhenUsingAssign() | ||
{ | ||
// Arrange | ||
var varExpr = Variable( typeof(int), "x" ); | ||
var assignExpr = Assign( varExpr, | ||
AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ) ); | ||
|
||
// Act | ||
var visitor = new AwaitVisitor(); | ||
visitor.Visit( assignExpr ); | ||
|
||
// Assert | ||
Assert.AreEqual( 2, visitor.Expressions.Count ); | ||
} | ||
|
||
[TestMethod] | ||
public void ShouldFindAwait_WhenUsingConditions() | ||
{ | ||
// Arrange | ||
var assignExpr = IfThen( AsyncExpression.Await( Constant( Task.FromResult( true ) ), false ), | ||
AsyncExpression.Await( Constant( Task.FromResult( 1 ) ), false ) ); | ||
|
||
// Act | ||
var visitor = new AwaitVisitor(); | ||
visitor.Visit( assignExpr ); | ||
|
||
// Assert | ||
Assert.AreEqual( 2, visitor.Expressions.Count ); | ||
} | ||
} |