From 1f057e759f3da35a66e4888406142932137f300f Mon Sep 17 00:00:00 2001 From: Brenton Farmer Date: Mon, 9 Sep 2024 12:25:23 -0700 Subject: [PATCH] Correct MoveNext state machine generation --- .../AsyncInvocationExpression.cs | 18 +- .../AsyncMethodCallExpression.cs | 20 +- .../AwaitExpression.cs | 35 +- .../Hyperbee.AsyncExpressions.csproj | 1 + .../RoslynStateMachineBuilder.cs | 587 ++++++++++++++++++ .../StateMachineBuilder.cs | 269 ++++---- .../UnitTests.cs | 4 +- 7 files changed, 788 insertions(+), 146 deletions(-) create mode 100644 src/Hyperbee.AsyncExpressions/RoslynStateMachineBuilder.cs diff --git a/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs index 5181220..86b84fb 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs @@ -20,21 +20,29 @@ public override Expression Reduce() if ( _isReduced ) return _stateMachine; - _stateMachine = StateMachineBuilder.Create( Block( _invocationExpression ), Type, createRunner: true ); + var resultType = ResultType( _invocationExpression.Type ); + + _stateMachine = StateMachineBuilder.Create( Block( _invocationExpression ), resultType, createRunner: true ); _isReduced = true; return _stateMachine; + + static Type ResultType( Type returnType ) + { + return returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>) + ? returnType.GetGenericArguments()[0] + : typeof(void); + } } public override Type Type { get { - var returnType = _invocationExpression.Type; + if ( !_isReduced ) + Reduce(); - return IsTask( returnType ) && returnType.IsGenericType - ? returnType.GetGenericArguments()[0] - : typeof(void); + return _stateMachine.Type; } } } diff --git a/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs index 05064eb..ba8d669 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs @@ -21,21 +21,29 @@ public override Expression Reduce() if ( _isReduced ) return _stateMachine; - _stateMachine = StateMachineBuilder.Create( Block( _methodCallExpression ), Type, createRunner: true ); + var resultType = ResultType( _methodCallExpression.Type ); + + _stateMachine = StateMachineBuilder.Create( Block( _methodCallExpression ), resultType, createRunner: true ); _isReduced = true; return _stateMachine; - } + static Type ResultType( Type returnType ) + { + return returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>) + ? returnType.GetGenericArguments()[0] + : typeof(void); + } + } + public override Type Type { get { - var returnType = _methodCallExpression.Type; + if ( !_isReduced ) + Reduce(); - return IsTask( returnType ) && returnType.IsGenericType - ? returnType.GetGenericArguments()[0] - : typeof(void); + return _stateMachine.Type; } } } diff --git a/src/Hyperbee.AsyncExpressions/AwaitExpression.cs b/src/Hyperbee.AsyncExpressions/AwaitExpression.cs index 02e8689..b70a6a5 100644 --- a/src/Hyperbee.AsyncExpressions/AwaitExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AwaitExpression.cs @@ -10,6 +10,7 @@ public class AwaitExpression : Expression { private readonly Expression _asyncExpression; private readonly bool _configureAwait; + private readonly Type _resultType; private static readonly MethodInfo AwaitMethod = typeof(AwaitExpression).GetMethod( nameof(Await), BindingFlags.NonPublic | BindingFlags.Static ); private static readonly MethodInfo AwaitResultMethod = typeof(AwaitExpression).GetMethod( nameof(AwaitResult), BindingFlags.NonPublic | BindingFlags.Static ); @@ -18,25 +19,25 @@ internal AwaitExpression( Expression asyncExpression, bool configureAwait ) { _asyncExpression = asyncExpression ?? throw new ArgumentNullException( nameof( asyncExpression ) ); _configureAwait = configureAwait; + _resultType = ResultType( asyncExpression.Type ); } public override ExpressionType NodeType => ExpressionType.Extension; - public override Type Type + public override Type Type => _resultType; + + private Type ResultType( Type taskType ) { - get + if ( ReturnTask ) + return taskType; + + return taskType.IsGenericType switch { - if ( ReturnTask ) - return _asyncExpression.Type; - - return _asyncExpression.Type.IsGenericType switch - { - true when _asyncExpression.Type.GetGenericTypeDefinition() == typeof(Task<>) => _asyncExpression.Type.GetGenericArguments()[0], - false => typeof(void), - _ => throw new InvalidOperationException( $"Unsupported type in {nameof(AwaitExpression)}." ) - }; - } + true when taskType.GetGenericTypeDefinition() == typeof(Task<>) => taskType.GetGenericArguments()[0], + false => typeof(void), + _ => throw new InvalidOperationException( $"Unsupported type in {nameof(AwaitExpression)}." ) + }; } public bool ReturnTask { get; set; } @@ -48,13 +49,11 @@ public override Expression Reduce() if ( ReturnTask ) return _asyncExpression; - // BF - state machine is not being started (code was lost) - - var awaitResult = Call( Type == typeof( void ) + var awaitExpression = Call( _resultType == typeof( void ) ? AwaitMethod - : AwaitResultMethod.MakeGenericMethod( Type ), _asyncExpression, Constant( _configureAwait ) ); + : AwaitResultMethod.MakeGenericMethod( _resultType ), _asyncExpression, Constant( _configureAwait ) ); - return awaitResult; + return awaitExpression; } private static void Await( Task task, bool configureAwait ) @@ -70,7 +69,7 @@ private static T AwaitResult( Task task, bool configureAwait ) private class AwaitExpressionProxy( AwaitExpression node ) { public Expression Target => node._asyncExpression; - public Type ReturnType => node.Type; + public Type ReturnType => node._resultType; } } diff --git a/src/Hyperbee.AsyncExpressions/Hyperbee.AsyncExpressions.csproj b/src/Hyperbee.AsyncExpressions/Hyperbee.AsyncExpressions.csproj index fc9b0c2..d54a004 100644 --- a/src/Hyperbee.AsyncExpressions/Hyperbee.AsyncExpressions.csproj +++ b/src/Hyperbee.AsyncExpressions/Hyperbee.AsyncExpressions.csproj @@ -37,6 +37,7 @@ + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/src/Hyperbee.AsyncExpressions/RoslynStateMachineBuilder.cs b/src/Hyperbee.AsyncExpressions/RoslynStateMachineBuilder.cs new file mode 100644 index 0000000..0daecb9 --- /dev/null +++ b/src/Hyperbee.AsyncExpressions/RoslynStateMachineBuilder.cs @@ -0,0 +1,587 @@ +using System.Linq.Expressions; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Hyperbee.AsyncExpressions; + +public class RoslynStateMachineBuilder +{ + private BlockExpression _blockSource; + + public void SetSource( BlockExpression blockSource ) + { + _blockSource = blockSource; + } + + // CreateStateMachine method: Generates the state machine struct + public CompilationUnitSyntax CreateStateMachine( string machineName, Type resultType ) + { + // Create namespace for the state machine + var namespaceDeclaration = SyntaxFactory.NamespaceDeclaration( SyntaxFactory.ParseName( "Hyperbee.AsyncExpressions" ) ) + .NormalizeWhitespace(); + + // Create struct declaration: public struct StateMachineType : IAsyncStateMachine + var structDeclaration = SyntaxFactory.StructDeclaration( machineName ) + .AddModifiers( SyntaxFactory.Token( SyntaxKind.PublicKeyword ) ) + .AddBaseListTypes( + SyntaxFactory.SimpleBaseType( SyntaxFactory.ParseTypeName( "IAsyncStateMachine" ) ) + ); + + // Add fields: public int _state; public TResult _finalResult; + var stateField = CreateField( "_state", "int", SyntaxKind.PublicKeyword ); + var finalResultField = CreateField( "_finalResult", "TResult", SyntaxKind.PublicKeyword ); + + structDeclaration = structDeclaration.AddMembers( stateField, finalResultField ); + + // Add builder field with initialization + var resultTypeSyntax = BuildTypeSyntax( resultType ); + var builderField = CreateBuilderField( resultTypeSyntax ); + structDeclaration = structDeclaration.AddMembers( builderField ); + + // Add state fields + var memberFields = CreateStateFields( _blockSource ); + structDeclaration = structDeclaration.AddMembers( memberFields ); + + // Add constructor + var constructor = CreateConstructor( machineName ); + structDeclaration = structDeclaration.AddMembers( constructor ); + + // Add MoveNext method + var moveNextMethod = CreateMoveNextMethod( _blockSource ); + structDeclaration = structDeclaration.AddMembers( moveNextMethod ); + + // Add SetStateMachine method + var setStateMachineMethod = CreateSetStateMachineMethod(); + structDeclaration = structDeclaration.AddMembers( setStateMachineMethod ); + + // Add CreateStateMachineRunnerMethod method + var stateMachineRunnerMethod = CreateStateMachineRunnerMethod(); + structDeclaration = structDeclaration.AddMembers( stateMachineRunnerMethod ); + + + // Return the final compilation unit + return SyntaxFactory.CompilationUnit() + .AddUsings( + SyntaxFactory.UsingDirective( SyntaxFactory.ParseName( "System" ) ), + SyntaxFactory.UsingDirective( SyntaxFactory.ParseName( "System.Runtime.CompilerServices" ) ), + SyntaxFactory.UsingDirective( SyntaxFactory.ParseName( "System.Threading.Tasks" ) ) + ) + .AddMembers( namespaceDeclaration.AddMembers( structDeclaration ) ) + .NormalizeWhitespace(); + } + + // Build method stub for later compilation + public void Build() + { + // This method will be responsible for compiling the generated syntax tree + // to produce the final assembly. + } + + // Helper to create fields for the state machine struct + private FieldDeclarationSyntax CreateField( string fieldName, string fieldType, SyntaxKind accessibility ) + { + return SyntaxFactory.FieldDeclaration( + SyntaxFactory.VariableDeclaration( SyntaxFactory.ParseTypeName( fieldType ) ) + .AddVariables( SyntaxFactory.VariableDeclarator( fieldName ) ) + ) + .AddModifiers( SyntaxFactory.Token( accessibility ) ); + } + + // Constructor: public StateMachineType() { } + private ConstructorDeclarationSyntax CreateConstructor( string typeName ) + { + return SyntaxFactory.ConstructorDeclaration( typeName ) + .AddModifiers( SyntaxFactory.Token( SyntaxKind.PublicKeyword ) ) + .WithBody( SyntaxFactory.Block( + // Initialize state + SyntaxFactory.ExpressionStatement( + SyntaxFactory.AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.ThisExpression(), + SyntaxFactory.IdentifierName( "_state" ) + ), + SyntaxFactory.LiteralExpression( SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal( 0 ) ) + ) + ) + ) ); + } + + // CreateFieldDeclaration method: Initializes _builder directly in the field declaration + private FieldDeclarationSyntax CreateBuilderField( TypeSyntax resultTypeSyntax ) + { + /* + _builder initialized in the field declaration + private AsyncTaskMethodBuilder _builder = AsyncTaskMethodBuilder.Create(); // Example where TResult is an int + */ + + return SyntaxFactory.FieldDeclaration( + SyntaxFactory.VariableDeclaration( + SyntaxFactory.GenericName( "AsyncTaskMethodBuilder" ) + .WithTypeArgumentList( + SyntaxFactory.TypeArgumentList( + SyntaxFactory.SingletonSeparatedList( resultTypeSyntax ) + ) + ) + ).AddVariables( + SyntaxFactory.VariableDeclarator( "_builder" ) + .WithInitializer( + SyntaxFactory.EqualsValueClause( + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.GenericName( "AsyncTaskMethodBuilder" ) + .WithTypeArgumentList( + SyntaxFactory.TypeArgumentList( + SyntaxFactory.SingletonSeparatedList( resultTypeSyntax ) + ) + ), + SyntaxFactory.IdentifierName( "Create" ) + ) + ) + ) + ) + ) + ).AddModifiers( SyntaxFactory.Token( SyntaxKind.PrivateKeyword ) ); + } + + private MemberDeclarationSyntax[] CreateStateFields( BlockExpression block ) + { + /* + This method generates fields equivalent to the following C# code: + + // Fields for variables (e.g., _result_0, _result_2) + private int _result_0; // Stores result from state 0 + private int _result_2; // Stores result from state 2 + + // Fields for awaiters (e.g., _awaiter_0, _awaiter_1, _awaiter_2) + private TaskAwaiter _awaiter_0; // Awaiter for state 0 (Task) + private TaskAwaiter _awaiter_1; // Awaiter for state 1 (Task) + private TaskAwaiter _awaiter_2; // Awaiter for state 2 (Task) + */ + + var fields = new List(); + + // Create fields for variables + foreach ( var variable in block.Variables ) + { + var field = SyntaxFactory.FieldDeclaration( + SyntaxFactory.VariableDeclaration( SyntaxFactory.ParseTypeName( variable.Type.Name ) ) + .AddVariables( SyntaxFactory.VariableDeclarator( $"_{variable.Name}" ) ) + ) + .AddModifiers( SyntaxFactory.Token( SyntaxKind.PublicKeyword ) ); + + fields.Add( field ); + } + + // Create fields for awaiters, one for each child block + for ( var i = 0; i < block.Expressions.Count; i++ ) + { + var awaiterField = SyntaxFactory.FieldDeclaration( + SyntaxFactory.VariableDeclaration( SyntaxFactory.ParseTypeName( "System.Runtime.CompilerServices.TaskAwaiter" ) ) + .AddVariables( SyntaxFactory.VariableDeclarator( $"_awaiter_{i}" ) ) + ) + .AddModifiers( SyntaxFactory.Token( SyntaxKind.PublicKeyword ) ); + + fields.Add( awaiterField ); + } + + return fields.Cast().ToArray(); // Return the list of field declarations + } + + // MoveNext method: Generates logic for state transitions based on the expression block + private MethodDeclarationSyntax CreateMoveNextMethod( BlockExpression block ) + { + /* + This method generates the equivalent of the following C# code for handling state transitions: + + public void MoveNext() + { + try + { + if (_state == 0) + { + var task0 = ChildBlock0(); // Call child block 0 which returns Task + _awaiter_0 = task0.ConfigureAwait(false).GetAwaiter(); + + if (!_awaiter_0.IsCompleted) + { + _builder.AwaitUnsafeOnCompleted(ref _awaiter_0, ref this); + return; + } + + // Store the result of state 0 + _result_0 = _awaiter_0.GetResult(); + _state = 1; + } + + if (_state == 1) + { + var task1 = ChildBlock1(); // Call child block 1 which returns Task + _awaiter_1 = task1.ConfigureAwait(false).GetAwaiter(); + + if (!_awaiter_1.IsCompleted) + { + _builder.AwaitUnsafeOnCompleted(ref _awaiter_1, ref this); + return; + } + + _awaiter_1.GetResult(); + _state = 2; + } + + if (_state == 2) + { + var task2 = ChildBlock2(_result_0); // Pass result from state 0 to child block 2 + _awaiter_2 = task2.ConfigureAwait(false).GetAwaiter(); + + if (!_awaiter_2.IsCompleted) + { + _builder.AwaitUnsafeOnCompleted(ref _awaiter_2, ref this); + return; + } + + _result_2 = _awaiter_2.GetResult(); // Store result from state 2 + _state = 3; + } + + if (_state == 3) + { + _finalResult = _result_0 + _result_2; // Sum results of state 0 and state 2 + _builder.SetResult(_finalResult); + } + } + catch (Exception ex) + { + _builder.SetException(ex); + } + } + */ + var statements = new List(); + + // Loop over child blocks and create states + for ( var i = 0; i < block.Expressions.Count; i++ ) + { + // Compile the child block into a delegate that can be invoked at runtime + var compiledBlock = CompileSubBlock( block.Expressions[i] ); + + // Check current state + var stateCheck = SyntaxFactory.IfStatement( + SyntaxFactory.BinaryExpression( SyntaxKind.EqualsExpression, + SyntaxFactory.IdentifierName( "_state" ), + SyntaxFactory.LiteralExpression( SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal( i ) ) + ), + SyntaxFactory.Block( + // Invoke the compiled child block to get the task for this state + GenerateTaskInvocation( i, compiledBlock ), + + // Generate state block logic + GenerateStateBlock( i ) + ) + ); + + statements.Add( stateCheck ); + } + + // Wrap all states in a try-catch block and return the method + return SyntaxFactory.MethodDeclaration( + SyntaxFactory.PredefinedType( SyntaxFactory.Token( SyntaxKind.VoidKeyword ) ), + "MoveNext" + ) + .AddModifiers( SyntaxFactory.Token( SyntaxKind.PublicKeyword ) ) + .WithBody( SyntaxFactory.Block( GenerateTryCatchBlock( statements ) ) ); + + // ---- Local Functions ---- + + // Generates the Try-Catch block that wraps the state transitions + static TryStatementSyntax GenerateTryCatchBlock( List stateStatements ) + { + return SyntaxFactory.TryStatement( + SyntaxFactory.Block( stateStatements ), // Add all state transitions + SyntaxFactory.SingletonList( + SyntaxFactory.CatchClause() + .WithCatchKeyword( SyntaxFactory.Token( SyntaxKind.CatchKeyword ) ) + .WithDeclaration( + SyntaxFactory.CatchDeclaration( + SyntaxFactory.ParseTypeName( "Exception" ) + ).WithIdentifier( SyntaxFactory.Identifier( "ex" ) ) + ).WithBlock( + SyntaxFactory.Block( + SyntaxFactory.ExpressionStatement( + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName( "_builder" ), + SyntaxFactory.IdentifierName( "SetException" ) + ) + ).WithArgumentList( + SyntaxFactory.ArgumentList( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.Argument( SyntaxFactory.IdentifierName( "ex" ) ) + ) + ) + ) + ) + ) + ) + ), + null // no Finally clause + ); + } + + // Local function to compile a sub-block into a delegate + static Delegate CompileSubBlock( Expression expression ) + { + return Expression.Lambda( expression ).Compile(); + } + + // Local function to generate the logic for invoking a compiled sub-block + static StatementSyntax GenerateTaskInvocation( int i, Delegate compiledBlock ) + { + return SyntaxFactory.LocalDeclarationStatement( + SyntaxFactory.VariableDeclaration( SyntaxFactory.IdentifierName( "var" ) ) + .AddVariables( + SyntaxFactory.VariableDeclarator( $"task{i}" ) + .WithInitializer( + SyntaxFactory.EqualsValueClause( + SyntaxFactory.InvocationExpression( SyntaxFactory.IdentifierName( compiledBlock.Method.Name ) ) + ) + ) + ) + ); + } + + // Local function to handle the state check and task awaiting + static BlockSyntax GenerateStateBlock( int i ) + { + return SyntaxFactory.Block( + // _awaiter# = task.ConfigureAwait(false).GetAwaiter(); + SyntaxFactory.ExpressionStatement( + SyntaxFactory.AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + SyntaxFactory.IdentifierName( $"_awaiter_{i}" ), + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName( $"task{i}" ), + SyntaxFactory.IdentifierName( "ConfigureAwait" ) + ) + ).WithArgumentList( + SyntaxFactory.ArgumentList( + SyntaxFactory.SingletonSeparatedList( SyntaxFactory.Argument( + SyntaxFactory.LiteralExpression( SyntaxKind.FalseLiteralExpression ) + ) ) + ) + ), + SyntaxFactory.IdentifierName( "GetAwaiter" ) + ) + ) + ) + ), + + // if (!_awaiter#.IsCompleted) + SyntaxFactory.IfStatement( + SyntaxFactory.PrefixUnaryExpression( SyntaxKind.LogicalNotExpression, + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName( $"_awaiter_{i}" ), + SyntaxFactory.IdentifierName( "IsCompleted" ) + ) + ), + SyntaxFactory.Block( + // _builder.AwaitUnsafeOnCompleted(ref _awaiter#, ref this) + SyntaxFactory.ExpressionStatement( + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName( "_builder" ), + SyntaxFactory.IdentifierName( "AwaitUnsafeOnCompleted" ) + ) + ).WithArgumentList( + SyntaxFactory.ArgumentList( SyntaxFactory.SeparatedList( new[] + { + SyntaxFactory.Argument( SyntaxFactory.IdentifierName( $"_awaiter_{i}" ) ) + .WithRefOrOutKeyword( SyntaxFactory.Token( SyntaxKind.RefKeyword ) ), + SyntaxFactory.Argument( SyntaxFactory.IdentifierName( "this" ) ) + .WithRefOrOutKeyword( SyntaxFactory.Token( SyntaxKind.RefKeyword ) ) + } ) ) + ) + ), + SyntaxFactory.ReturnStatement() // Return to await completion + ) + ), + + // _awaiter#.GetResult(); + SyntaxFactory.ExpressionStatement( + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName( $"_awaiter_{i}" ), + SyntaxFactory.IdentifierName( "GetResult" ) + ) + ) + ), + + // _state = i + 1; + SyntaxFactory.ExpressionStatement( + SyntaxFactory.AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + SyntaxFactory.IdentifierName( "_state" ), + SyntaxFactory.LiteralExpression( SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal( i + 1 ) ) + ) + ) + ); + } + } + + // SetStateMachine method: public void SetStateMachine(IAsyncStateMachine stateMachine) + private MethodDeclarationSyntax CreateSetStateMachineMethod() + { + /* + public void SetStateMachine(IAsyncStateMachine stateMachine) + { + _builder.SetStateMachine(stateMachine); + } + */ + return SyntaxFactory.MethodDeclaration( SyntaxFactory.PredefinedType( SyntaxFactory.Token( SyntaxKind.VoidKeyword ) ), "SetStateMachine" ) + .AddModifiers( SyntaxFactory.Token( SyntaxKind.PublicKeyword ) ) + .AddParameterListParameters( + SyntaxFactory.Parameter( SyntaxFactory.Identifier( "stateMachine" ) ) + .WithType( SyntaxFactory.ParseTypeName( "IAsyncStateMachine" ) ) + ) + .WithBody( SyntaxFactory.Block( + SyntaxFactory.ExpressionStatement( + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName( "_builder" ), + SyntaxFactory.IdentifierName( "SetStateMachine" ) + ) + ) + .WithArgumentList( + SyntaxFactory.ArgumentList( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.Argument( SyntaxFactory.IdentifierName( "stateMachine" ) ) + ) + ) + ) + ) + ) ); + } + + private MethodDeclarationSyntax CreateStateMachineRunnerMethod() + { + // Create the variable declaration: var stateMachine = new StateMachine(); + var stateMachineDeclaration = SyntaxFactory.LocalDeclarationStatement( + SyntaxFactory.VariableDeclaration( SyntaxFactory.IdentifierName( "var" ) ) + .AddVariables( + SyntaxFactory.VariableDeclarator( SyntaxFactory.Identifier( "stateMachine" ) ) + .WithInitializer( + SyntaxFactory.EqualsValueClause( + SyntaxFactory.ObjectCreationExpression( SyntaxFactory.IdentifierName( "StateMachine" ) ) + .WithArgumentList( SyntaxFactory.ArgumentList() ) + ) + ) + ) + ); + + // Create the builder.Start(ref stateMachine) statement + var startMethodCall = SyntaxFactory.ExpressionStatement( + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName( "stateMachine" ), + SyntaxFactory.IdentifierName( "_builder" ) + ), + SyntaxFactory.IdentifierName( "Start" ) + ) + ) + .WithArgumentList( + SyntaxFactory.ArgumentList( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.Argument( SyntaxFactory.IdentifierName( "stateMachine" ) ) + .WithRefOrOutKeyword( SyntaxFactory.Token( SyntaxKind.RefKeyword ) ) + ) + ) + ) + ); + + // Create the return statement: return _builder.Task; + var returnStatement = SyntaxFactory.ReturnStatement( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName( "_builder" ), + SyntaxFactory.IdentifierName( "Task" ) + ) + ); + + // Combine all statements into a method body + return SyntaxFactory.MethodDeclaration( + SyntaxFactory.PredefinedType( SyntaxFactory.Token( SyntaxKind.VoidKeyword ) ), + SyntaxFactory.Identifier( "CreateStateMachineRunner" ) + ) + .AddModifiers( SyntaxFactory.Token( SyntaxKind.PrivateKeyword ) ) + .WithBody( + SyntaxFactory.Block( + stateMachineDeclaration, // Declare stateMachine variable + startMethodCall, // Call _builder.Start(ref stateMachine) + returnStatement // Return _builder.Task + ) + ); + } + + + // Convert a Type to TypeSyntax and typeName + private static TypeSyntax BuildTypeSyntax( Type type ) + { + TypeSyntax typeSyntax; + + if ( type.IsGenericType ) + { + // Handle generic types like Task or List + var genericTypeName = type.GetGenericTypeDefinition().Name.Split( '`' )[0]; // Get the base generic name + var genericArgs = type.GetGenericArguments(); + var genericArgsSyntax = SyntaxFactory.SeparatedList(); + + // Recursively build TypeSyntax for each generic argument + foreach ( var arg in genericArgs ) + { + var argSyntax = BuildTypeSyntax( arg ); + genericArgsSyntax = genericArgsSyntax.Add( argSyntax ); + } + + typeSyntax = SyntaxFactory.GenericName( genericTypeName ) + .WithTypeArgumentList( + SyntaxFactory.TypeArgumentList( genericArgsSyntax ) + ); + } + else if ( type.IsArray ) + { + // Handle array types (e.g., int[], string[]) + var elementType = BuildTypeSyntax( type.GetElementType() ); + typeSyntax = SyntaxFactory.ArrayType( elementType ) + .WithRankSpecifiers( SyntaxFactory.SingletonList( + SyntaxFactory.ArrayRankSpecifier( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.OmittedArraySizeExpression() ) + ) + ) ); + } + else + { + // Handle non-generic, non-array types (e.g., int, string) + typeSyntax = SyntaxFactory.IdentifierName( type.Name ); + } + + return typeSyntax; + } +} + + + diff --git a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs index a76efef..3f25be8 100644 --- a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs +++ b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs @@ -5,6 +5,19 @@ namespace Hyperbee.AsyncExpressions; +public static class Debug +{ + public static void Log( string message ) + { + Console.WriteLine( message ); + } + + public static MethodCallExpression LogCall( string message ) + { + return Expression.Call( typeof( Debug ).GetMethod( "Log" )!, Expression.Constant( message ) ); + } +} + public class StateMachineBuilder { private BlockExpression _blockSource; @@ -45,7 +58,7 @@ public Expression CreateStateMachine( bool createRunner = true ) var stateMachineVariable = Expression.Variable( _stateMachineType, "stateMachine" ); var builderFieldInfo = _stateMachineType.GetField( "_builder" )!; - var setLambdaMethod = _stateMachineType.GetMethod("SetMoveNext")!; + var setLambdaMethod = _stateMachineType.GetMethod( "SetMoveNext" )!; var constructor = _stateMachineType.GetConstructor( Type.EmptyTypes )!; @@ -65,14 +78,34 @@ public Expression CreateStateMachine( bool createRunner = true ) public Expression CreateStateMachineRunner( Expression stateMachineExpression ) { - var stateMachineVariable = Expression.Variable( stateMachineExpression.Type, "stateMachineVariable" ); - var builderFieldInfo = stateMachineExpression.Type.GetField( "_builder" )!; + /* + public Task RunStateMachine( StateMachineType stateMachine ) + { + stateMachine._builder.Start( ref stateMachine ); + return stateMachine._builder.Task; + } + */ + var stateMachineVariable = Expression.Variable( _stateMachineType, "stateMachine" ); + var builderFieldInfo = _stateMachineType.GetField( "_builder" )!; var taskFieldInfo = builderFieldInfo.FieldType.GetProperty( "Task" )!; + var builderField = Expression.Field( stateMachineVariable, builderFieldInfo ); + + var startMethod = typeof( AsyncTaskMethodBuilder<> ) + .MakeGenericType( typeof( TResult ) ) + .GetMethod( "Start" )! + .MakeGenericMethod( _stateMachineType ); + + var callBuilderStart = Expression.Call( + builderField, + startMethod, + stateMachineVariable // expects ref + ); + return Expression.Block( [stateMachineVariable], Expression.Assign( stateMachineVariable, stateMachineExpression ), - Expression.Call( stateMachineVariable, "MoveNext", Type.EmptyTypes ), + callBuilderStart, Expression.Property( Expression.Field( stateMachineVariable, builderFieldInfo ), taskFieldInfo ) ); } @@ -110,7 +143,7 @@ private void CreateStateMachineType( BlockExpression block ) // public void SetStateMachine(IAsyncStateMachine stateMachine) => _builder.SetStateMachine( stateMachine ); // } - _typeBuilder = _moduleBuilder.DefineType( _typeName, TypeAttributes.Public, typeof( object ), [typeof( IAsyncStateMachine )]); + _typeBuilder = _moduleBuilder.DefineType( _typeName, TypeAttributes.Public, typeof( object ), [typeof( IAsyncStateMachine )] ); _typeBuilder.DefineField( "_state", typeof( int ), FieldAttributes.Public ); _builderField = _typeBuilder.DefineField( "_builder", typeof( AsyncTaskMethodBuilder<> ).MakeGenericType( typeof( TResult ) ), FieldAttributes.Public ); @@ -134,7 +167,7 @@ private void EmitConstructor() // Call the base constructor (object) ilGenerator.Emit( OpCodes.Ldarg_0 ); // this - ilGenerator.Emit( OpCodes.Call, typeof(object).GetConstructor( Type.EmptyTypes )! ); // base() + ilGenerator.Emit( OpCodes.Call, typeof( object ).GetConstructor( Type.EmptyTypes )! ); // base() ilGenerator.Emit( OpCodes.Ret ); // return } @@ -152,7 +185,7 @@ private void EmitBlockFields( BlockExpression block ) _awaiterFields = []; for ( var i = 0; i < block.Expressions.Count; i++ ) { - var expr = block.Expressions[i]; + var expr = block.Expressions[i]; if ( !TryGetAwaiterType( expr, out Type awaiterType ) ) continue; // Not an awaitable expression @@ -173,11 +206,11 @@ private void EmitSetMoveNextMethod() // _moveNextLambda = moveNext; // } - var setMoveNextMethod = _typeBuilder.DefineMethod( - "SetMoveNext", - MethodAttributes.Public | MethodAttributes.Virtual, - typeof(void), - [typeof(Action<>).MakeGenericType( _typeBuilder )] + var setMoveNextMethod = _typeBuilder.DefineMethod( + "SetMoveNext", + MethodAttributes.Public | MethodAttributes.Virtual, + typeof( void ), + [typeof( Action<> ).MakeGenericType( _typeBuilder )] ); var ilGenerator = setMoveNextMethod.GetILGenerator(); @@ -201,7 +234,7 @@ private void EmitMoveNextMethod() var moveNextMethod = _typeBuilder.DefineMethod( "MoveNext", MethodAttributes.Public | MethodAttributes.Virtual, - typeof(void), + typeof( void ), Type.EmptyTypes ); @@ -211,7 +244,7 @@ private void EmitMoveNextMethod() ilGenerator.Emit( OpCodes.Ldfld, _moveNextLambdaField ); // load `_moveNextLambda` ilGenerator.Emit( OpCodes.Ldarg_0 ); // load `this` as lambda argument - var actionObjectType = typeof(Action); + var actionObjectType = typeof( Action ); var invokeMethod = actionObjectType.GetMethod( "Invoke" ); ilGenerator.Emit( OpCodes.Callvirt, invokeMethod! ); // Call Action.Invoke(this) @@ -226,12 +259,12 @@ private void EmitSetStateMachineMethod() // { // _builder.SetStateMachine( stateMachine ); // } - + var setStateMachineMethod = _typeBuilder.DefineMethod( "SetStateMachine", MethodAttributes.Public | MethodAttributes.Virtual, - typeof(void), - [typeof(IAsyncStateMachine)] + typeof( void ), + [typeof( IAsyncStateMachine )] ); var ilGenerator = setStateMachineMethod.GetILGenerator(); @@ -240,53 +273,18 @@ private void EmitSetStateMachineMethod() ilGenerator.Emit( OpCodes.Ldfld, _builderField ); // load `_builder` ilGenerator.Emit( OpCodes.Ldarg_1 ); // Load the `stateMachine` parameter - var setStateMachineOnBuilder = typeof(AsyncTaskMethodBuilder<>) - .MakeGenericType( typeof(TResult) ) - .GetMethod( "SetStateMachine", [typeof(IAsyncStateMachine)] ); + var setStateMachineOnBuilder = typeof( AsyncTaskMethodBuilder<> ) + .MakeGenericType( typeof( TResult ) ) + .GetMethod( "SetStateMachine", [typeof( IAsyncStateMachine )] ); ilGenerator.Emit( OpCodes.Callvirt, setStateMachineOnBuilder! ); ilGenerator.Emit( OpCodes.Ret ); - _typeBuilder.DefineMethodOverride( setStateMachineMethod, - typeof(IAsyncStateMachine).GetMethod( "SetStateMachine" )! + _typeBuilder.DefineMethodOverride( setStateMachineMethod, + typeof( IAsyncStateMachine ).GetMethod( "SetStateMachine" )! ); } - private static bool TryGetAwaiterType( Expression expr, out Type awaiterType ) - { - awaiterType = null; - - switch ( expr ) - { - case MethodCallExpression methodCall when typeof(Task).IsAssignableFrom( methodCall.Type ): - awaiterType = GetAwaiterType( methodCall.Type ); - return true; - - case InvocationExpression invocation when typeof(Task).IsAssignableFrom( invocation.Type ): - awaiterType = GetAwaiterType( invocation.Type ); - return true; - - case not null when typeof(Task).IsAssignableFrom( expr.Type ): - awaiterType = GetAwaiterType( expr.Type ); - return true; - } - - return false; - - static Type GetAwaiterType( Type taskType ) - { - if ( !taskType.IsGenericType ) - return typeof(ConfiguredTaskAwaitable.ConfiguredTaskAwaiter); - - var genericArgument = taskType.GetGenericArguments()[0]; - - if ( genericArgument.FullName == "System.Threading.Tasks.VoidTaskResult" ) - throw new InvalidOperationException( "Task is not supported, are you missing a cast to Task?" ); - - return typeof(ConfiguredTaskAwaitable<>.ConfiguredTaskAwaiter).MakeGenericType( genericArgument ); - } - } - private LambdaExpression CreateMoveNextExpression( BlockExpression block ) { // Example of a typical state-machine: @@ -298,28 +296,31 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) // if (_state == 0) // { // _awaiter1 = task1.ConfigureAwait(false).GetAwaiter(); + // _state = 1; // // if (!_awaiter1.IsCompleted == false) // { // _builder.AwaitUnsafeOnCompleted(ref _awaiter1, this); // return; // } - // - // _awaiter1.GetResult(); - // _state = 1; // } // // if (_state == 1) // { + // _awaiter1.GetResult(); // _awaiter2 = task2.ConfigureAwait(false).GetAwaiter(); - // + // _state = 2; + // if (!_awaiter2.IsCompleted) // { // _builder.AwaitUnsafeOnCompleted(ref _awaiter2, this); // return; // } + // } // - // _awaiter2.GetResult(); + // if (_state == 2) + // { + // _builder.Task.SetResult( _awaiter2.GetResult() ); // } // } // catch (Exception ex) @@ -339,35 +340,32 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) var blocks = block.Expressions; int lastBlockIndex = blocks.Count - 1; - LabelTarget returnLabel = Expression.Label( "ExitMoveNext" ); - // Each block is a state in the state machine - for ( var i = 0; i < blocks.Count; i++ ) + // Iterate through the blocks and handle task-based or non-task states + for ( var i = 0; i <= lastBlockIndex; i++ ) { - // Fix BlockExpression parameters to use fields from state machine var blockExpr = parameterVisitor.Visit( blocks[i] ); var blockReturnType = blockExpr.Type; if ( AsyncBaseExpression.IsTask( blockReturnType ) ) { + // Task-based state generation var awaiterField = GetFieldInfo( _stateMachineType, _awaiterFields[i] ); - var configureAwaitMethod = blockExpr.Type.GetMethod( "ConfigureAwait", [typeof(bool)] )!; var getAwaiterMethod = configureAwaitMethod.ReturnType.GetMethod( "GetAwaiter" ); - // Evaluate the block expression to produce the task - var evaluateTask = blockExpr; - - // Assign the awaiter field (e.g., _awaiterX = task.ConfigureAwait(false).GetAwaiter()) var assignAwaiter = Expression.Assign( Expression.Field( stateMachineInstance, awaiterField ), Expression.Call( - Expression.Call( evaluateTask, configureAwaitMethod, Expression.Constant( false ) ), - getAwaiterMethod! ) + Expression.Call( blockExpr, configureAwaitMethod, Expression.Constant( false ) ), + getAwaiterMethod! + ) ); - // Call AwaitUnsafeOnCompleted when awaiter is not completed + // Set the state before checking IsCompleted + var setStateBeforeAwait = Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) ); + var awaiterCompletedCheck = Expression.IfThen( Expression.IsFalse( Expression.Property( Expression.Field( stateMachineInstance, awaiterField ), "IsCompleted" ) ), Expression.Block( @@ -378,60 +376,66 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) Expression.Field( stateMachineInstance, awaiterField ), stateMachineInstance ), - Expression.Return( returnLabel ) // Return from MoveNext + Expression.Return( returnLabel ) ) ); - // Handle case when awaiter is completed (i.e., proceed to next state) - - var getResultMethod = awaiterField.FieldType.GetMethod( "GetResult" ); - var getResult = Expression.Call( Expression.Field( stateMachineInstance, awaiterField ), getResultMethod! ); - - var handleCompletedAwaiter = i == lastBlockIndex - ? Expression.Block( - Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) ), - Expression.Assign( Expression.Field( stateMachineInstance, finalResultFieldInfo ), getResult ) - ) - : Expression.Block( - getResult, - Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) ) - ); - - // Full block for `if ( state == X )` - var stateCheck = Expression.IfThen( - Expression.Equal( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i ) ), - Expression.Block( assignAwaiter, awaiterCompletedCheck, handleCompletedAwaiter ) // Execute task handling logic + var stateCheck = Expression.IfThen( + Expression.Equal( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i ) ), + Expression.Block( assignAwaiter, setStateBeforeAwait, awaiterCompletedCheck ) ); bodyExpressions.Add( stateCheck ); } - else if ( i == lastBlockIndex ) // final block: non-task + else if ( i == lastBlockIndex ) { - var assignFinalResult = Expression.Assign( - Expression.Field( stateMachineInstance, finalResultFieldInfo ), blockExpr! - ); - - var finalStateCheck = Expression.IfThen( - Expression.Equal( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i ) ), - assignFinalResult - ); - - bodyExpressions.Add( finalStateCheck ); + // Handle the last block when it's not a task + var assignFinalResult = Expression.Assign( Expression.Field( stateMachineInstance, finalResultFieldInfo ), blockExpr ); + var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) ); + bodyExpressions.Add( Expression.Block( assignFinalResult, incrementState ) ); + } + else + { + throw new InvalidOperationException( $"Non-final block {i} must be a Task." ); } } - // Set the final result - var setResult = Expression.Call( - Expression.Field( stateMachineInstance, buildFieldInfo ), - nameof(AsyncTaskMethodBuilder.SetResult), - null, - Expression.Field( stateMachineInstance, finalResultFieldInfo ) + // Generate the final state + var finalState = Expression.IfThen( + Expression.Equal( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( lastBlockIndex + 1 ) ), + Expression.Block( + // If the final block was a task, GetResult from the last awaiter + AsyncBaseExpression.IsTask( blocks[lastBlockIndex].Type ) + ? Expression.Assign( + Expression.Field( stateMachineInstance, finalResultFieldInfo ), + Expression.Call( + Expression.Field( stateMachineInstance, GetFieldInfo( _stateMachineType, _awaiterFields[lastBlockIndex] ) ), + "GetResult", Type.EmptyTypes + ) + ) + : Expression.Empty(), // No-op if not a task + + // Set the final result on the builder + Expression.Call( + Expression.Field( stateMachineInstance, buildFieldInfo ), + nameof(AsyncTaskMethodBuilder.SetResult), + null, + Expression.Field( stateMachineInstance, finalResultFieldInfo ) + ) + ) ); - bodyExpressions.Add( setResult ); + // Mark as completed after the final state logic is executed + var markCompletedState = Expression.Assign( + Expression.Field( stateMachineInstance, "_state" ), + Expression.Constant( -2 ) // Mark the state machine as completed + ); + + bodyExpressions.Add( finalState ); + bodyExpressions.Add( markCompletedState ); bodyExpressions.Add( Expression.Label( returnLabel ) ); - // Return the lambda expression for the method + // Return the generated Lambda Expression representing MoveNext return Expression.Lambda( Expression.Block( bodyExpressions ), stateMachineInstance ); // Helper method to retrieve FieldInfo from the created type @@ -440,14 +444,49 @@ static FieldInfo GetFieldInfo( Type runtimeType, FieldBuilder field ) return runtimeType.GetField( field.Name, BindingFlags.Instance | BindingFlags.Public )!; } } + + private static bool TryGetAwaiterType( Expression expr, out Type awaiterType ) + { + awaiterType = null; + + switch ( expr ) + { + case MethodCallExpression methodCall when typeof( Task ).IsAssignableFrom( methodCall.Type ): + awaiterType = GetAwaiterType( methodCall.Type ); + return true; + + case InvocationExpression invocation when typeof( Task ).IsAssignableFrom( invocation.Type ): + awaiterType = GetAwaiterType( invocation.Type ); + return true; + + case not null when typeof( Task ).IsAssignableFrom( expr.Type ): + awaiterType = GetAwaiterType( expr.Type ); + return true; + } + + return false; + + static Type GetAwaiterType( Type taskType ) + { + if ( !taskType.IsGenericType ) + return typeof( ConfiguredTaskAwaitable.ConfiguredTaskAwaiter ); + + var genericArgument = taskType.GetGenericArguments()[0]; + + if ( genericArgument.FullName == "System.Threading.Tasks.VoidTaskResult" ) + throw new InvalidOperationException( "Task is not supported, are you missing a cast to Task?" ); + + return typeof( ConfiguredTaskAwaitable<>.ConfiguredTaskAwaiter ).MakeGenericType( genericArgument ); + } + } } public static class StateMachineBuilder { private static readonly MethodInfo BuildStateMachineMethod = - typeof(StateMachineBuilder) + typeof( StateMachineBuilder ) .GetMethods( BindingFlags.Public | BindingFlags.Static ) - .First( x => x.Name == nameof(Create) && x.IsGenericMethod ); + .First( x => x.Name == nameof( Create ) && x.IsGenericMethod ); public static Expression Create( BlockExpression source, Type resultType, bool createRunner ) { diff --git a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs b/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs index 5ad74b3..1f3cabb 100644 --- a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs +++ b/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs @@ -49,7 +49,7 @@ private static int IncrementValue( int a ) private static async Task ThrowExceptionAsync() { await Task.Delay( 50 ); - throw new InvalidOperationException( "Simulated exception" ); + throw new InvalidOperationException( "Simulated exception." ); } private static MethodInfo GetMethodInfo( string name ) @@ -416,7 +416,7 @@ public void TestAsyncExpression_ExceptionHandling( ExpressionKind kind ) } catch ( InvalidOperationException ex ) { - Assert.AreEqual( "Simulated exception", ex.Message, "The exception message should match." ); + Assert.AreEqual( "Simulated exception.", ex.Message, "The exception message should match." ); } } }