diff --git a/src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs index a3791d1..2e040fe 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs @@ -1,6 +1,8 @@ -using System.Diagnostics; +using System.Diagnostics; using System.Linq.Expressions; using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; namespace Hyperbee.AsyncExpressions; @@ -23,7 +25,7 @@ protected AsyncBaseExpression( Expression body ) protected abstract Type GetFinalResultType(); - protected abstract Expression BuildStateMachine(); + protected abstract void ConfigureStateMachine( StateMachineBuilder builder ); public override Expression Reduce() { @@ -31,10 +33,11 @@ public override Expression Reduce() return _stateMachineBody; var finalResultType = GetFinalResultType(); - + var stateMachineResultType = finalResultType == typeof(void) ? typeof(VoidResult) : finalResultType; + var buildStateMachine = typeof(AsyncBaseExpression) .GetMethod( nameof(BuildStateMachine), BindingFlags.NonPublic | BindingFlags.Instance )! - .MakeGenericMethod( finalResultType ); + .MakeGenericMethod( stateMachineResultType ); _stateMachineBody = (Expression) buildStateMachine.Invoke( this, null ); _isReduced = true; @@ -42,6 +45,34 @@ public override Expression Reduce() return _stateMachineBody!; } + private MethodCallExpression BuildStateMachine() + { + // Create a dynamic assembly and module for the state machine + var assemblyName = new AssemblyName( "DynamicStateMachineAssembly" ); + var assemblyBuilder = AssemblyBuilder.DefineDynamicAssembly( assemblyName, AssemblyBuilderAccess.Run ); + var moduleBuilder = assemblyBuilder.DefineDynamicModule( "MainModule" ); + + // Create a state machine builder + var stateMachineBuilder = new StateMachineBuilder( moduleBuilder, "DynamicStateMachine" ); + + // Delegate to the derived class to configure the builder + ConfigureStateMachine( stateMachineBuilder ); + + // Create the state machine type + var stateMachineType = stateMachineBuilder.CreateStateMachineType(); + + // Create a proxy expression for handling MoveNext and SetStateMachine calls + var proxyConstructor = typeof(StateMachineProxy).GetConstructor( new[] { typeof(IAsyncStateMachine) } ); + var stateMachineInstance = Expression.New( stateMachineType ); + var proxyInstance = Expression.New( proxyConstructor!, stateMachineInstance ); + + // Build an expression that represents invoking the MoveNext method on the proxy + var moveNextMethod = typeof(IAsyncStateMachine).GetMethod( nameof(IAsyncStateMachine.MoveNext) ); + var moveNextCall = Expression.Call( proxyInstance, moveNextMethod! ); + + return moveNextCall; + } + internal static bool IsTask( Type returnType ) { return returnType == typeof(Task) || diff --git a/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs index 616227d..8449825 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs @@ -4,78 +4,77 @@ namespace Hyperbee.AsyncExpressions; public class AsyncBlockExpression : AsyncBaseExpression { - private readonly Expression[] _expressions; + private readonly BlockExpression _reducedBlock; + private readonly Type _finalResultType; - public AsyncBlockExpression( Expression[] expressions) : base(null) + public AsyncBlockExpression( Expression[] expressions ) : base( null ) { - _expressions = expressions; + if ( expressions == null || expressions.Length == 0 ) + { + throw new ArgumentException( "AsyncBlockExpression must contain at least one expression.", nameof(expressions) ); + } + + _reducedBlock = ReduceBlock( expressions, out _finalResultType ); } protected override Type GetFinalResultType() { - // Get the final result type from the last block - var (_, finalResultType) = ReduceBlock(_expressions); - return finalResultType; + return _finalResultType; + } + + protected override void ConfigureStateMachine( StateMachineBuilder builder ) + { + builder.GenerateMoveNextMethod( _reducedBlock ); } - protected override Expression BuildStateMachine() + private static BlockExpression ReduceBlock( Expression[] expressions, out Type finalResultType ) { - var (blocks, finalResultType) = ReduceBlock(_expressions); + var parentBlockExpressions = new List(); + var currentBlockExpressions = new List(); + var awaitEncountered = false; - var builder = new StateMachineBuilder(); + // Collect all variables declared in the block + var variables = new HashSet(); + finalResultType = typeof(void); // Default to void, adjust if task found - foreach (var block in blocks) + foreach ( var expr in expressions ) { - var lastExpr = block.Expressions.Last(); - if (IsTask(lastExpr.Type)) - { - if (lastExpr.Type == typeof(Task)) - { - builder.AddTaskBlock(block); // Block with Task - } - else if (lastExpr.Type.IsGenericType && lastExpr.Type.GetGenericTypeDefinition() == typeof(Task<>)) - { - builder.AddTaskResultBlock(block); // Block with Task - } - } - else + if ( expr is AsyncBlockExpression asyncBlock ) { - builder.AddBlock(block); // Regular code block + // Recursively reduce the inner async block + var reducedInnerBlock = asyncBlock.Reduce(); + currentBlockExpressions.Add( reducedInnerBlock ); + continue; } - } - - return builder.Build(); - } - // ReduceBlock method to split the block into sub-blocks - private (List blocks, Type finalResultType) ReduceBlock( Expression[] expressions) - { - var blocks = new List(); - var currentBlock = new List(); - Type finalResultType = typeof(void); - - foreach (var expr in expressions) - { - currentBlock.Add(expr); + currentBlockExpressions.Add( expr ); - if (expr is AwaitExpression) + switch ( expr ) { - // Finalize the current block and add it to the list - blocks.Add(Block(currentBlock)); - currentBlock.Clear(); + case BinaryExpression binaryExpr when binaryExpr.Left is ParameterExpression varExpr: + variables.Add( varExpr ); + break; + case AwaitExpression: + { + awaitEncountered = true; + var currentBlock = Block( currentBlockExpressions ); + parentBlockExpressions.Add( currentBlock ); + currentBlockExpressions = []; + break; + } } } - // Add the last block if it exists - if (currentBlock.Count > 0) + if ( currentBlockExpressions.Count > 0 ) { - blocks.Add(Block(currentBlock)); - var lastExpr = currentBlock.Last(); + var finalBlock = Block( currentBlockExpressions ); + parentBlockExpressions.Add( finalBlock ); - // Determine the final result type from the last expression - if (IsTask(lastExpr.Type)) + // Update the final result type based on the last expression in the final block + var lastExpr = currentBlockExpressions.Last(); + if ( IsTask( lastExpr.Type ) ) { - if (lastExpr.Type.IsGenericType) + if ( lastExpr.Type.IsGenericType ) { finalResultType = lastExpr.Type.GetGenericArguments()[0]; } @@ -86,7 +85,13 @@ protected override Expression BuildStateMachine() } } - return (blocks, finalResultType); + if ( !awaitEncountered ) + { + throw new InvalidOperationException( $"{nameof(AsyncBlockExpression)} must contain at least one {nameof(AwaitExpression)}." ); + } + + // Combine all child blocks into a single parent block, with variables declared at the parent level + return Block( variables, parentBlockExpressions ); // Declare variables only once at the top level } } @@ -97,3 +102,4 @@ public static AsyncBaseExpression BlockAsync( params Expression[] expressions ) return new AsyncBlockExpression( expressions ); } } + diff --git a/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs index 952546c..4669c57 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs @@ -17,35 +17,22 @@ protected override Type GetFinalResultType() { return typeof(void); // No result to return } - else if ( _invocationExpression.Type.IsGenericType && _invocationExpression.Type.GetGenericTypeDefinition() == typeof(Task<>) ) + + if ( _invocationExpression.Type.IsGenericType && _invocationExpression.Type.GetGenericTypeDefinition() == typeof(Task<>) ) { return _invocationExpression.Type.GetGenericArguments()[0]; // Return T from Task } - else - { - throw new InvalidOperationException( "Invocation must return Task or Task" ); - } + + throw new InvalidOperationException( "Invocation must return Task or Task" ); } - protected override Expression BuildStateMachine() + protected override void ConfigureStateMachine( StateMachineBuilder builder ) { - var builder = new StateMachineBuilder(); - var taskType = _invocationExpression.Type; - - if ( taskType == typeof(Task) ) - { - builder.AddTaskBlock( _invocationExpression ); // Await the single Task - } - else if ( taskType.IsGenericType && taskType.GetGenericTypeDefinition() == typeof(Task<>) ) - { - builder.AddTaskResultBlock( _invocationExpression ); // Await the single Task with a result - } - - return builder.Build(); // Return the built state machine + var block = Block( _invocationExpression ); + builder.GenerateMoveNextMethod( block ); } } - public static partial class AsyncExpression { public static AsyncBaseExpression InvokeAsync( LambdaExpression lambdaExpression, params Expression[] arguments ) diff --git a/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs index 785e645..b46034a 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs @@ -18,31 +18,19 @@ protected override Type GetFinalResultType() { return typeof(void); // No result to return } - else if ( _methodCallExpression.Type.IsGenericType && _methodCallExpression.Type.GetGenericTypeDefinition() == typeof(Task<>) ) + + if ( _methodCallExpression.Type.IsGenericType && _methodCallExpression.Type.GetGenericTypeDefinition() == typeof(Task<>) ) { return _methodCallExpression.Type.GetGenericArguments()[0]; // Return T from Task } - else - { - throw new InvalidOperationException( "Method call must return Task or Task" ); - } + + throw new InvalidOperationException( "Method call must return Task or Task" ); } - protected override Expression BuildStateMachine() + protected override void ConfigureStateMachine( StateMachineBuilder builder ) { - var builder = new StateMachineBuilder(); - var taskType = _methodCallExpression.Type; - - if ( taskType == typeof(Task) ) - { - builder.AddTaskBlock( _methodCallExpression ); // Await the single Task - } - else if ( taskType.IsGenericType && taskType.GetGenericTypeDefinition() == typeof(Task<>) ) - { - builder.AddTaskResultBlock( _methodCallExpression ); // Await the single Task with a result - } - - return builder.Build(); // Return the built state machine + var block = Block( _methodCallExpression ); + builder.GenerateMoveNextMethod( block ); } } diff --git a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs index 76f98d4..5018093 100644 --- a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs +++ b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs @@ -1,127 +1,155 @@ using System.Linq.Expressions; +using System.Reflection; +using System.Reflection.Emit; using System.Runtime.CompilerServices; namespace Hyperbee.AsyncExpressions; +// ReSharper disable once InconsistentNaming +internal interface VoidResult; + public class StateMachineBuilder { - private readonly List _blocks = []; - private readonly List _stateVariables = []; - private readonly ParameterExpression _state = Expression.Variable( typeof(int), "state" ); - private readonly ParameterExpression _builder = Expression.Variable( typeof(AsyncTaskMethodBuilder), "builder" ); - - public StateMachineBuilder() + private readonly TypeBuilder _typeBuilder; + private readonly FieldBuilder _stateField; + private readonly FieldBuilder _builderField; + private readonly FieldBuilder _finalResultField; + private readonly MethodBuilder _moveNextMethod; + private readonly FieldBuilder _proxyField; + + public StateMachineBuilder( ModuleBuilder moduleBuilder, string typeName ) { - _stateVariables.Add( _state ); - _stateVariables.Add( _builder ); + // Define a new type that implements IAsyncStateMachine + _typeBuilder = moduleBuilder.DefineType( typeName, TypeAttributes.Public, typeof( object ), [typeof( IAsyncStateMachine )] ); + _stateField = _typeBuilder.DefineField( "_state", typeof( int ), FieldAttributes.Private ); + _builderField = _typeBuilder.DefineField( "_builder", typeof( AsyncTaskMethodBuilder<> ).MakeGenericType( typeof( TResult ) ), FieldAttributes.Private ); + _finalResultField = _typeBuilder.DefineField( "_finalResult", typeof( TResult ), FieldAttributes.Private ); + + // Define a constructor for the state machine type + var constructor = _typeBuilder.DefineConstructor( MethodAttributes.Public, CallingConventions.Standard, Type.EmptyTypes ); + var ilGenerator = constructor.GetILGenerator(); + ilGenerator.Emit( OpCodes.Ldarg_0 ); + ilGenerator.Emit( OpCodes.Call, typeof( object ).GetConstructor( Type.EmptyTypes )! ); + ilGenerator.Emit( OpCodes.Ret ); + + // Define the MoveNext method that will contain the state machine logic + _moveNextMethod = _typeBuilder.DefineMethod( nameof( IAsyncStateMachine.MoveNext ), MethodAttributes.Public | MethodAttributes.Virtual, typeof( void ), Type.EmptyTypes ); + + // Define a field to store a proxy for the state machine, used for continuation + _proxyField = _typeBuilder.DefineField( "_stateMachineProxy", typeof( StateMachineProxy ), FieldAttributes.Private ); } - public StateMachineBuilder AddBlock( Expression blockExpression ) + public void GenerateMoveNextMethod( BlockExpression reducedBlock ) { - _blocks.Add( blockExpression ); - return this; - } + // Section: Initialization + // Use variables from reducedBlock to ensure variable scope is managed correctly. + var variables = reducedBlock.Variables; - public StateMachineBuilder AddTaskBlock( Expression taskExpression ) - { - // Create the necessary awaiter and assign it to a variable - var awaiterType = typeof(TaskAwaiter); - var awaiter = Expression.Variable( awaiterType, "awaiter" ); - _stateVariables.Add( awaiter ); - - // Generate the block that awaits the task - var assignAwaiter = Expression.Assign( awaiter, Expression.Call( Expression.Convert( taskExpression, typeof(Task) ), nameof(Task.GetAwaiter), null ) ); - var isCompleted = Expression.Property( awaiter, nameof(TaskAwaiter.IsCompleted) ); - - var setState = Expression.Assign( _state, Expression.Constant( _blocks.Count ) ); - - var onCompleted = Expression.Call( - _builder, - nameof(AsyncTaskMethodBuilder.AwaitUnsafeOnCompleted), - [awaiter.Type, typeof(StateMachineBuilder)], - awaiter, - Expression.Constant( this ) - ); - - var block = Expression.Block( - assignAwaiter, - Expression.IfThenElse( - isCompleted, - Expression.Empty(), - Expression.Block( setState, onCompleted, Expression.Return( Expression.Label( typeof(void) ) ) ) - ) - ); + // List to hold all expressions that make up the state machine's body + var bodyExpressions = new List + { + // Initialize the AsyncTaskMethodBuilder field + Expression.Assign(Expression.Field(Expression.Constant(this), _builderField), + Expression.Call(typeof(AsyncTaskMethodBuilder), nameof(AsyncTaskMethodBuilder.Create), null)) + }; - _blocks.Add( block ); - return this; - } + // Section: State Handling + // Iterate through each block to define state transitions + var blocks = reducedBlock.Expressions; + for ( var i = 0; i < blocks.Count; i++ ) + { + var blockExpr = blocks[i]; + + // Define the types for the ConfiguredTaskAwaitable and its awaiter + var configuredTaskAwaitableType = typeof( ConfiguredTaskAwaitable<> ).MakeGenericType( typeof( TResult ) ); + var configuredTaskAwaiterType = configuredTaskAwaitableType.GetNestedType( "ConfiguredTaskAwaiter" ); + + // Define a field to hold the awaiter for this state + var awaiterField = _typeBuilder.DefineField( $"_awaiter_{i}", configuredTaskAwaiterType!, FieldAttributes.Private ); + + // Check if the current state matches + var stateCheck = Expression.Equal( Expression.Field( Expression.Constant( this ), _stateField ), Expression.Constant( i ) ); + + // Assign the awaiter + var assignAwaiter = Expression.Assign( + Expression.Field( Expression.Constant( this ), awaiterField ), + Expression.Call( + Expression.Call( blockExpr, nameof( Task.ConfigureAwait ), null, Expression.Constant( false ) ), + nameof( ConfiguredTaskAwaitable.GetAwaiter ), null ) + ); + + // Create the StateMachineProxy instance and assign it + var stateMachineProxy = Expression.New( typeof( StateMachineProxy ).GetConstructor( [typeof( IAsyncStateMachine )] )!, Expression.Constant( this ) ); + var assignProxy = Expression.Assign( Expression.Field( Expression.Constant( this ), _proxyField ), stateMachineProxy ); + + // Setup continuation with the builder, using the awaiter and the state machine proxy + var setupContinuation = Expression.Call( + Expression.Field( Expression.Constant( this ), _builderField ), + nameof( AsyncTaskMethodBuilder.AwaitUnsafeOnCompleted ), + [configuredTaskAwaiterType, typeof( IAsyncStateMachine )], + Expression.Field( Expression.Constant( this ), awaiterField ), + Expression.Field( Expression.Constant( this ), _proxyField ) + ); + + // Move to the next state + var moveToNextState = Expression.Assign( Expression.Field( Expression.Constant( this ), _stateField ), Expression.Constant( i + 1 ) ); + + // Section: State Execution Logic + // Check if the task is completed or needs to await + var ifNotCompleted = Expression.IfThenElse( + Expression.IsFalse( Expression.Property( Expression.Field( Expression.Constant( this ), awaiterField ), nameof( TaskAwaiter.IsCompleted ) ) ), + Expression.Block( assignAwaiter, assignProxy, setupContinuation, Expression.Return( Expression.Label( typeof( void ) ) ) ), + Expression.Block( assignAwaiter, moveToNextState ) + ); + + // Add the state check and logic to the body expressions + bodyExpressions.Add( Expression.IfThen( stateCheck, ifNotCompleted ) ); + } - public StateMachineBuilder AddTaskResultBlock( Expression taskExpression ) - { - // Determine the result type of the task - var resultType = taskExpression.Type.GetGenericArguments()[0]; - var awaiterType = typeof(TaskAwaiter<>).MakeGenericType( resultType ); - var awaiter = Expression.Variable( awaiterType, "awaiter" ); - _stateVariables.Add( awaiter ); - - // Generate the block that awaits the task result - var assignAwaiter = Expression.Assign( awaiter, Expression.Call( Expression.Convert( taskExpression, typeof(Task<>).MakeGenericType( resultType ) ), nameof(Task.GetAwaiter), null ) ); - var isCompleted = Expression.Property( awaiter, nameof(TaskAwaiter.IsCompleted) ); - - var setState = Expression.Assign( _state, Expression.Constant( _blocks.Count ) ); - var onCompleted = Expression.Call( - _builder, - nameof(AsyncTaskMethodBuilder.AwaitUnsafeOnCompleted), - [awaiter.Type, typeof(StateMachineBuilder)], - awaiter, - Expression.Constant( this ) + // Section: Final State + // Set the final result of the async operation + var setResult = Expression.Call( + Expression.Field( Expression.Constant( this ), _builderField ), + nameof( AsyncTaskMethodBuilder.SetResult ), + null, + Expression.Field( Expression.Constant( this ), _finalResultField ) ); - var block = Expression.Block( - assignAwaiter, - Expression.IfThenElse( - isCompleted, - Expression.Empty(), - Expression.Block( setState, onCompleted, Expression.Return( Expression.Label( typeof(void) ) ) ) - ) - ); + bodyExpressions.Add( setResult ); - _blocks.Add( block ); - return this; + // Section: Emit and Compile + // Include the variables in the block to maintain their scope and compile the MoveNext method + var stateMachineBody = Expression.Block( variables, bodyExpressions ); + EmitCompileToMethod( stateMachineBody, _moveNextMethod ); } - public Expression>> Build() + private void EmitCompileToMethod( Expression stateMachineBody, MethodBuilder methodBuilder ) { - // Final result variable to hold the outcome of the last block - var finalResult = Expression.Variable( typeof(TResult), "finalResult" ); - _stateVariables.Add( finalResult ); - - // Generate the state machine body - var body = new List - { - // Add the initial state of the builder - Expression.Assign( _builder, Expression.Call( typeof(AsyncTaskMethodBuilder), nameof(AsyncTaskMethodBuilder.Create), null ) ) - }; + // compile the generated state machine into a method - // Add each state machine block - for ( var i = 0; i < _blocks.Count; i++ ) - { - var block = _blocks[i]; + var lambda = Expression.Lambda( stateMachineBody ); + var compiledLambda = lambda.Compile(); + var ilGenerator = methodBuilder.GetILGenerator(); - // Check the current state - var condition = Expression.Equal( _state, Expression.Constant( i ) ); - var ifStateMatches = Expression.IfThen( condition, block ); + ilGenerator.Emit( OpCodes.Ldarg_0 ); + ilGenerator.Emit( OpCodes.Call, compiledLambda.Method ); + ilGenerator.Emit( OpCodes.Ret ); + } - body.Add( ifStateMatches ); - } + public Type CreateStateMachineType() + { + // finalize and create the state machine type + _typeBuilder.DefineMethodOverride( _moveNextMethod, typeof( IAsyncStateMachine ).GetMethod( nameof( IAsyncStateMachine.MoveNext ) )! ); + return _typeBuilder.CreateTypeInfo().AsType(); + } +} - // Set the final result and return - body.Add( Expression.Assign( finalResult, Expression.Default( typeof(TResult) ) ) ); - body.Add( Expression.Call( _builder, nameof(AsyncTaskMethodBuilder.SetResult), null, finalResult ) ); +// Proxy class to delegate state transitions in the state machine +public sealed class StateMachineProxy( IAsyncStateMachine stateMachine ) : IAsyncStateMachine +{ + private IAsyncStateMachine _innerStateMachine = stateMachine ?? throw new ArgumentNullException( nameof( stateMachine ) ); - var stateMachineBody = Expression.Block( _stateVariables, body ); + public void MoveNext() => _innerStateMachine.MoveNext(); - // Return the lambda expression representing the state machine - return Expression.Lambda>>( stateMachineBody, [] ); - } + public void SetStateMachine( IAsyncStateMachine stateMachine ) => _innerStateMachine = stateMachine; } diff --git a/test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs b/test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs index 205047a..ff9a9ad 100644 --- a/test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs +++ b/test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs @@ -71,7 +71,6 @@ public void TestAsyncBlock_WithoutAwait_ThrowsException() var asyncBlock = AsyncExpression.BlockAsync( assignExpr1, assignExpr2, assertExpr ); // Act - // This should throw an InvalidOperationException due to lack of AwaitExpression asyncBlock.Reduce(); }