From 0378ba95c0c614c96e0a165aaf8dba3d6cd2bf77 Mon Sep 17 00:00:00 2001 From: Matt Edwards Date: Fri, 30 Aug 2024 17:55:57 -0400 Subject: [PATCH] Clean and add block --- .../AsyncBlockExpression.cs | 241 ++++++++++++++++++ .../AsyncInvocationExpression.cs | 14 +- .../AsyncInvokeExpression.cs | 26 +- .../AsyncMethodCallExpression.cs | 19 +- .../AwaitExpression.cs | 6 +- .../UnitTests.cs | 46 ++-- 6 files changed, 302 insertions(+), 50 deletions(-) create mode 100644 src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs diff --git a/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs new file mode 100644 index 0000000..437797a --- /dev/null +++ b/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs @@ -0,0 +1,241 @@ +using System.Diagnostics; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; + +namespace Hyperbee.AsyncExpressions; + +[DebuggerDisplay( "{_body}" )] +[DebuggerTypeProxy( typeof( AsyncBlockExpressionProxy ) )] +public class AsyncBlockExpression : Expression +{ + private readonly Expression _body; + private Expression _reducedBody; + private bool _isReduced; + private static int _stateMachineCounter; + + private static readonly Expression VoidResult = Constant(Task.FromResult(new VoidTaskResult())); + + private static MethodInfo GenericGenerateExecuteAsync => typeof( AsyncInvokeExpression ) + .GetMethod( nameof( GenerateExecuteAsyncExpression ), BindingFlags.Static | BindingFlags.NonPublic ); + + internal AsyncBlockExpression( Expression body ) + { + ArgumentNullException.ThrowIfNull( body, nameof( body ) ); + + if ( !IsAsync( body.Type ) ) + throw new ArgumentException( $"The specified {nameof( body )} is not an async.", nameof( body ) ); + + _body = body; + } + + + public override ExpressionType NodeType => ExpressionType.Extension; + + public override Type Type => _body.Type; + + public override bool CanReduce => true; + + public override Expression Reduce() + { + if (_isReduced) + return _reducedBody; + + _isReduced = true; + + var (type, result) = GetTypeResult(_body); + var methodInfo = GenericGenerateExecuteAsync?.MakeGenericMethod(type); + + _reducedBody = (Expression)methodInfo!.Invoke(null, [result]); + + return _reducedBody!; + } + + private static (Type Type, Expression Expression) GetTypeResult(Expression expression) + { + return expression.Type == typeof(Task) + ? (typeof(VoidTaskResult), Block(expression, VoidResult)) + : (expression.Type.GetGenericArguments()[0], expression); + } + + + private static BlockExpression GenerateExecuteAsyncExpression( Expression task ) + { + // Generating code block: + /* + internal static Task ExecuteAsync(Task task) + { + var stateMachine = new StateMachine(task); + stateMachine.MoveNext(); + return stateMachine.Task; + } + */ + + // Create unique variable names to avoid conflicts + var id = Interlocked.Increment( ref _stateMachineCounter ); + var stateMachineVar = Variable( typeof(MultiTaskStateMachine ), $"stateMachine_{id}" ); + + // Constructor for state machine + var stateMachineCtor = typeof(MultiTaskStateMachine ) + .GetConstructor( [typeof( Task )] ); + + var assignStateMachine = Assign( + stateMachineVar, + New( stateMachineCtor!, task ) + ); + + // Call MoveNext + var moveNextMethod = typeof(MultiTaskStateMachine ).GetMethod( nameof(MultiTaskStateMachine.MoveNext ) ); + var moveNextCall = Call( stateMachineVar, moveNextMethod! ); + + // Return task property + var taskProperty = typeof(MultiTaskStateMachine ).GetProperty( nameof(MultiTaskStateMachine.Task ) ); + var returnTask = Property( stateMachineVar, taskProperty! ); + + // Explicitly use nested blocks to handle variable scoping + var resultBlock = Block( + [stateMachineVar], + assignStateMachine, + moveNextCall, + returnTask + ); + + return resultBlock; + } + + private struct MultiTaskStateMachine : IAsyncStateMachine + { + private readonly Task[] _tasks; + private readonly bool _isLastTaskGeneric; + private AsyncTaskMethodBuilder _builder; + private int _state; + + public MultiTaskStateMachine( Task[] tasks ) + { + _builder = AsyncTaskMethodBuilder.Create(); + _state = -1; + _tasks = tasks; + + // Determine if the last task is generic or not + var lastTaskType = tasks[^1].GetType(); + _isLastTaskGeneric = lastTaskType.IsGenericType && lastTaskType.GetGenericTypeDefinition() == typeof( Task<> ); + + SetStateMachine( this ); + } + + public Task Task => _builder.Task; + + public void MoveNext() + { + try + { + if ( _state == -1 ) + { + // Initial state: + _state = 0; + } + + if ( _state >= 0 && _state < _tasks.Length ) + { + var currentTask = _tasks[_state]; + + if ( _state == _tasks.Length - 1 && _isLastTaskGeneric ) + { + // Last task is generic + var genericAwaiter = ((Task) currentTask).ConfigureAwait( false ).GetAwaiter(); + if ( !genericAwaiter.IsCompleted ) + { + _builder.AwaitUnsafeOnCompleted( ref genericAwaiter, ref this ); + return; + } + + // Get the result directly if the task is already completed + var result = genericAwaiter.GetResult(); + _state = -2; + _builder.SetResult( result ); + } + else + { + // Intermediate non-generic task or last non-generic task + var awaiter = currentTask.ConfigureAwait( false ).GetAwaiter(); + if ( !awaiter.IsCompleted ) + { + _builder.AwaitUnsafeOnCompleted( ref awaiter, ref this ); + return; + } + + // Continue directly if the task is already completed + awaiter.GetResult(); + _state++; + MoveNext(); + } + } + else if ( _state == _tasks.Length && !_isLastTaskGeneric ) + { + // All tasks completed, last task was non-generic + _state = -2; + _builder.SetResult( default! ); + } + } + catch ( Exception ex ) + { + // Final state: error + _state = -2; + _builder.SetException( ex ); + } + } + + public void SetStateMachine( IAsyncStateMachine stateMachine ) + { + _builder.SetStateMachine( stateMachine ); + } + } + + private static bool IsAsync( Type returnType ) + { + return returnType == typeof( Task ) || + (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof( Task<> )) || + (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof( ValueTask<> )); + } + + public class AsyncBlockExpressionProxy( AsyncBlockExpression node ) + { + public Expression Body => node._body; + } + + public static AsyncBlockExpression BlockAsync( BlockExpression expression ) + { + //expression.Expressions.Count.. + + /* + { + + var result1 = { + [ex1Task] + expression1, //Task Assign( ex1Task, expression1 ) + expression2, + awaitExpression3 ( expression3 /// Expression ), + }, + + { + [ex1Task, result1] + await( ex1Task,void,T ) + } + + var result3 = { + [result2] + expression4, + } + ... + } + */ + + //var d = Task.Delay( 10 ); + // ... + //await d; + + + return new AsyncBlockExpression( expression ); + } + +} diff --git a/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs index 80c3914..b04aefa 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs @@ -1,4 +1,5 @@ using System.Linq.Expressions; +using System.Reflection; namespace Hyperbee.AsyncExpressions; @@ -8,11 +9,14 @@ internal AsyncInvocationExpression( InvocationExpression body ) : base( body ) { } - public static AsyncInvokeExpression InvokeAsync( LambdaExpression lambdaExpression, params Expression[] arguments ) +} +public static partial class AsyncExpression +{ + public static AsyncInvokeExpression InvokeAsync(LambdaExpression lambdaExpression, params Expression[] arguments) { - if ( !IsAsync( lambdaExpression.ReturnType ) ) - throw new ArgumentException( "The specified lambda is not an async.", nameof( lambdaExpression ) ); + if (!AsyncInvokeExpression.IsAsync(lambdaExpression.ReturnType)) + throw new ArgumentException("The specified lambda is not an async.", nameof(lambdaExpression)); - return new AsyncInvokeExpression( Invoke( lambdaExpression, arguments ) ); + return new AsyncInvokeExpression(Expression.Invoke(lambdaExpression, arguments)); } -} +} \ No newline at end of file diff --git a/src/Hyperbee.AsyncExpressions/AsyncInvokeExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncInvokeExpression.cs index 2b8f8da..928b580 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncInvokeExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncInvokeExpression.cs @@ -31,7 +31,7 @@ internal AsyncInvokeExpression( Expression body ) public override ExpressionType NodeType => ExpressionType.Extension; - public override Type Type => _body.Type == typeof( Task ) ? typeof( Task ) : _body.Type; + public override Type Type => _body.Type; public override bool CanReduce => true; @@ -42,19 +42,19 @@ public override Expression Reduce() _isReduced = true; - var (type, result) = GetTypeResult( _body ); - var methodInfo = GenericGenerateExecuteAsync?.MakeGenericMethod( type ); - - _reducedBody = (Expression) methodInfo!.Invoke( null, [result] ); + _reducedBody = _body.Type.IsGenericType switch + { + true => GetReduceBody(_body.Type.GetGenericArguments()[0], _body), + false => GetReduceBody( typeof(VoidTaskResult), Block( _body, VoidResult ) ) + }; - return _reducedBody!; - } + return _reducedBody; - private static (Type Type, Expression Expression) GetTypeResult( Expression expression ) - { - return expression.Type == typeof( Task ) - ? (typeof( VoidTaskResult ), Block( expression, VoidResult )) - : (expression.Type.GetGenericArguments()[0], expression); + static Expression GetReduceBody( Type type, Expression body ) + { + var methodInfo = GenericGenerateExecuteAsync.MakeGenericMethod( type ); + return (Expression) methodInfo!.Invoke( null, [body] ); + } } private static BlockExpression GenerateExecuteAsyncExpression( Expression task ) @@ -167,7 +167,7 @@ public void SetStateMachine( IAsyncStateMachine stateMachine ) } } - protected static bool IsAsync( Type returnType ) + internal static bool IsAsync( Type returnType ) { return returnType == typeof( Task ) || (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof( Task<> )) || diff --git a/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs index b7d3354..2126405 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs @@ -8,22 +8,25 @@ public class AsyncMethodCallExpression : AsyncInvokeExpression internal AsyncMethodCallExpression( MethodCallExpression body ) : base( body ) { } +} + +public static partial class AsyncExpression +{ public static AsyncInvokeExpression CallAsync( MethodInfo methodInfo, params Expression[] arguments ) { - if ( !IsAsync( methodInfo.ReturnType ) ) - throw new ArgumentException( "The specified method is not an async.", nameof( methodInfo ) ); + if ( !AsyncInvokeExpression.IsAsync( methodInfo.ReturnType ) ) + throw new ArgumentException( "The specified method is not an async.", nameof(methodInfo) ); - return new AsyncInvokeExpression( Call( methodInfo, arguments ) ); + return new AsyncInvokeExpression( Expression.Call( methodInfo, arguments ) ); } public static AsyncInvokeExpression CallAsync( Expression instance, MethodInfo methodInfo, params Expression[] arguments ) { - if ( !IsAsync( methodInfo.ReturnType ) ) - throw new ArgumentException( "The specified method is not an async.", nameof( methodInfo ) ); + if ( !AsyncInvokeExpression.IsAsync( methodInfo.ReturnType ) ) + throw new ArgumentException( "The specified method is not an async.", nameof(methodInfo) ); - return new AsyncInvokeExpression( Call( instance, methodInfo, arguments ) ); + return new AsyncInvokeExpression( Expression.Call( instance, methodInfo, arguments ) ); } -} - +} \ No newline at end of file diff --git a/src/Hyperbee.AsyncExpressions/AwaitExpression.cs b/src/Hyperbee.AsyncExpressions/AwaitExpression.cs index 2344241..1f96add 100644 --- a/src/Hyperbee.AsyncExpressions/AwaitExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AwaitExpression.cs @@ -69,8 +69,12 @@ private static T Await( Task task, bool configureAwait ) return result; } +} + +public static partial class AsyncExpression +{ public static AwaitExpression Await(Expression expression, bool configureAwait) { return new AwaitExpression(expression, configureAwait); } -} +} \ No newline at end of file diff --git a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs b/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs index 76f5f7c..5f1d2b1 100644 --- a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs +++ b/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs @@ -14,7 +14,7 @@ public class AsyncExpressionUnitTests { private static async Task Delay() { - await Task.Delay( 100 ); + await Task.Delay( 10 ); } private static async Task GetNumberAsync() @@ -63,10 +63,10 @@ private static AsyncInvokeExpression GetAsyncExpression( ExpressionKind kind, Me { case ExpressionKind.Lambda: var (lambdaExpression, parameters) = GetLambdaExpression( methodInfo, arguments ); - return AsyncInvocationExpression.InvokeAsync( lambdaExpression, parameters ); + return AsyncExpression.InvokeAsync( lambdaExpression, parameters ); case ExpressionKind.Method: - return AsyncMethodCallExpression.CallAsync( methodInfo, arguments ); + return AsyncExpression.CallAsync( methodInfo, arguments ); default: throw new ArgumentOutOfRangeException( nameof( kind ) ); @@ -96,7 +96,7 @@ public void TestAsyncExpression_NoParameters( ExpressionKind kind ) var methodInfo = GetMethodInfo( nameof( GetNumberAsync ) ); var asyncExpression = GetAsyncExpression( kind, methodInfo! ); - var awaitExpression = AwaitExpression.Await( asyncExpression, configureAwait: false ); + var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false ); var lambda = Expression.Lambda>( awaitExpression ); var compiledLambda = lambda.Compile(); @@ -113,7 +113,7 @@ public void TestAsyncExpression_NoResults( ExpressionKind kind ) var methodInfo = GetMethodInfo( nameof( Delay ) ); var asyncExpression = GetAsyncExpression( kind, methodInfo! ); - var awaitExpression = AwaitExpression.Await( asyncExpression, configureAwait: false ); + var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false ); var lambda = Expression.Lambda( awaitExpression ); var compiledLambda = lambda.Compile(); @@ -133,7 +133,7 @@ public void TestAsyncExpression_WithParameters( ExpressionKind kind ) var paramExpr3 = Expression.Parameter( typeof( int ), "c" ); var asyncExpression = GetAsyncExpression( kind, methodInfo!, paramExpr1, paramExpr2, paramExpr3 ); - var awaitExpression = AwaitExpression.Await( asyncExpression, configureAwait: false ); + var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false ); var lambda = Expression.Lambda>( awaitExpression, paramExpr1, paramExpr2, paramExpr3 ); var compiledLambda = lambda.Compile(); @@ -154,7 +154,7 @@ public void TestAsyncExpression_WithConstants( ExpressionKind kind ) var paramExpr3 = Expression.Constant( 12 ); var asyncExpression = GetAsyncExpression( kind, methodInfo!, paramExpr1, paramExpr2, paramExpr3 ); - var awaitExpression = AwaitExpression.Await( asyncExpression, configureAwait: false ); + var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false ); var lambda = Expression.Lambda>( awaitExpression ); var compiledLambda = lambda.Compile(); @@ -181,7 +181,7 @@ public void TestAsyncExpression_WithMethodParameters( ExpressionKind kind ) var incrementValueCall = Expression.Call( incrementMethodInfo!, paramExpr3 ); var asyncExpression = GetAsyncExpression( kind, methodInfo!, paramExpr1, paramExpr2, incrementValueCall ); - var awaitExpression = AwaitExpression.Await( asyncExpression, configureAwait: false ); + var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false ); var lambda = Expression.Lambda>( awaitExpression, paramExpr1, paramExpr2, paramExpr3 ); var compiledLambda = lambda.Compile(); @@ -204,10 +204,10 @@ public void TestAsyncExpression_AsParameter( ExpressionKind kind ) var paramB = Expression.Parameter( typeof( int ), "b" ); var asyncExpressionAdd = GetAsyncExpression( kind, addTwoNumbersMethod, paramA, paramB ); - var awaitExpressionAdd = AwaitExpression.Await( asyncExpressionAdd, configureAwait: false ); + var awaitExpressionAdd = AsyncExpression.Await( asyncExpressionAdd, configureAwait: false ); var asyncExpressionSayHello = GetAsyncExpression( kind, sayHelloMethod, awaitExpressionAdd ); - var awaitExpressionSayHello = AwaitExpression.Await( asyncExpressionSayHello, configureAwait: false ); + var awaitExpressionSayHello = AsyncExpression.Await( asyncExpressionSayHello, configureAwait: false ); var lambda = Expression.Lambda>( awaitExpressionSayHello, paramA, paramB ); var compiledLambda = lambda.Compile(); @@ -232,8 +232,8 @@ public void TestMultipleAsyncExpressions_SeparateAwaits( ExpressionKind kind ) var asyncExpression1 = GetAsyncExpression( kind, methodInfo1! ); var asyncExpression2 = GetAsyncExpression( kind, methodInfo2!, paramExpr1, paramExpr2, paramExpr3 ); - var awaitExpression1 = AwaitExpression.Await( asyncExpression1, configureAwait: false ); - var awaitExpression2 = AwaitExpression.Await( asyncExpression2, configureAwait: false ); + var awaitExpression1 = AsyncExpression.Await( asyncExpression1, configureAwait: false ); + var awaitExpression2 = AsyncExpression.Await( asyncExpression2, configureAwait: false ); var lambda1 = Expression.Lambda>( awaitExpression1 ); var lambda2 = Expression.Lambda>( awaitExpression2, paramExpr1, paramExpr2, paramExpr3 ); @@ -260,7 +260,7 @@ public async Task TestScopedAwaitExpressions( ExpressionKind kind ) var paramB = Expression.Parameter( typeof( int ), "b" ); var asyncExpressionAdd = GetAsyncExpression( kind, addTwoNumbersMethod!, paramA, paramB ); - var awaitExpressionAdd = AwaitExpression.Await( asyncExpressionAdd, configureAwait: false ); + var awaitExpressionAdd = AsyncExpression.Await( asyncExpressionAdd, configureAwait: false ); var resultFromAdd = Expression.Variable( typeof( int ), "resultFromAdd" ); @@ -286,7 +286,7 @@ public async Task TestScopedAwaitExpressions( ExpressionKind kind ) // Compile the nested expression into a lambda and execute it var lambda = Expression.Lambda>>( combinedExpression, paramA, paramB ); - var asyncLambda = AsyncInvocationExpression.InvokeAsync( lambda, paramA, paramB ); + var asyncLambda = AsyncExpression.InvokeAsync( lambda, paramA, paramB ); var compiledLambda = Expression.Lambda>>( asyncLambda, paramA, paramB ).Compile(); var result = await compiledLambda( 32, 10 ); // Execute with parameters 32 and 10 @@ -306,9 +306,9 @@ public async Task TestMultipleAsyncExpressions_WithDeepNestingAsync() // var l2 = Expression.Invoke( incrementExpression, l1 ); // var l3 = Expression.Invoke( incrementExpression, l2 ); - var l1 = AsyncInvocationExpression.InvokeAsync( incrementExpression, paramA ); - var l2 = AsyncInvocationExpression.InvokeAsync( incrementExpression, l1 ); - var l3 = AsyncInvocationExpression.InvokeAsync( incrementExpression, l2 ); + var l1 = AsyncExpression.InvokeAsync( incrementExpression, paramA ); + var l2 = AsyncExpression.InvokeAsync( incrementExpression, l1 ); + var l3 = AsyncExpression.InvokeAsync( incrementExpression, l2 ); var compiled = Expression.Lambda, Task>>( l3, paramA ).Compile(); var expressionResult = await compiled( Task.FromResult( 2 ) ); @@ -335,9 +335,9 @@ public async Task TestMultipleAsyncExpressions_WithDeepNestingAsyncAwait() var paramA = Expression.Parameter( typeof( int ), "a" ); - var l1 = AwaitExpression.Await( AsyncInvocationExpression.InvokeAsync( incrementExpression, paramA ), false ); - var l2 = AwaitExpression.Await( AsyncInvocationExpression.InvokeAsync( incrementExpression, l1 ), false ); - var l3 = AsyncInvocationExpression.InvokeAsync( incrementExpression, l2 ); + var l1 = AsyncExpression.Await(AsyncExpression.InvokeAsync( incrementExpression, paramA ), false ); + var l2 = AsyncExpression.Await(AsyncExpression.InvokeAsync( incrementExpression, l1 ), false ); + var l3 = AsyncExpression.InvokeAsync( incrementExpression, l2 ); var compiled = Expression.Lambda>>( l3, paramA ).Compile(); var expressionResult = await compiled( 2 ); @@ -370,13 +370,13 @@ public void TestChainedAwaitExpressions( ExpressionKind kind ) var paramB = Expression.Parameter( typeof( int ), "b" ); var asyncExpressionAdd = GetAsyncExpression( kind, addTwoNumbersMethod!, paramA, paramB ); - var awaitExpressionAdd = AwaitExpression.Await( asyncExpressionAdd, configureAwait: false ); + var awaitExpressionAdd = AsyncExpression.Await( asyncExpressionAdd, configureAwait: false ); var resultFromAdd = Expression.Variable( typeof( int ), "resultFromAdd" ); // Create AsyncExpression and AwaitExpression for SayHello var asyncExpressionSayHello = GetAsyncExpression( kind, sayHelloMethod!, resultFromAdd ); - var awaitExpressionSayHello = AwaitExpression.Await( asyncExpressionSayHello, configureAwait: false ); + var awaitExpressionSayHello = AsyncExpression.Await( asyncExpressionSayHello, configureAwait: false ); // Combine both expressions in a block var combinedExpression = Expression.Block( @@ -403,7 +403,7 @@ public void TestAsyncExpression_ExceptionHandling( ExpressionKind kind ) var methodInfo = GetMethodInfo( nameof( ThrowExceptionAsync ) ); var asyncExpression = GetAsyncExpression( kind, methodInfo! ); - var awaitExpression = AwaitExpression.Await( asyncExpression, configureAwait: false ); + var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false ); var lambda = Expression.Lambda>( awaitExpression ); var compiledLambda = lambda.Compile();