From 4498a5a4ac991b13866edfc81f29b2cc6d7d88e9 Mon Sep 17 00:00:00 2001 From: Brenton Farmer Date: Sun, 22 Sep 2024 08:26:33 -0700 Subject: [PATCH] Refactored StateNode to support Transition types --- .../GotoTransformerVisitor.cs | 617 ++++++++---------- .../GotoTransformerVisitor0.cs | 483 ++++++++++++++ .../ParameterMappingVisitor.cs | 2 +- .../StateMachineBuilder.cs | 1 + src/Hyperbee.AsyncExpressions/StateNode.cs | 37 +- src/Hyperbee.AsyncExpressions/Transition.cs | 95 +++ .../GotoTransformerVisitorTests.cs | 18 +- 7 files changed, 867 insertions(+), 386 deletions(-) create mode 100644 src/Hyperbee.AsyncExpressions/GotoTransformerVisitor0.cs create mode 100644 src/Hyperbee.AsyncExpressions/Transition.cs diff --git a/src/Hyperbee.AsyncExpressions/GotoTransformerVisitor.cs b/src/Hyperbee.AsyncExpressions/GotoTransformerVisitor.cs index 0fc7d18..2f29f5c 100644 --- a/src/Hyperbee.AsyncExpressions/GotoTransformerVisitor.cs +++ b/src/Hyperbee.AsyncExpressions/GotoTransformerVisitor.cs @@ -1,434 +1,371 @@ using System.Linq.Expressions; -using static System.Linq.Expressions.Expression; -namespace Hyperbee.AsyncExpressions; - -public class GotoTransformerVisitor : ExpressionVisitor +namespace Hyperbee.AsyncExpressions { - private readonly List _states = []; - private int _continuationCounter; - private int _labelCounter; - private StateNode _currentState; - private readonly Stack _finalNodes = []; - private readonly Dictionary _labelMappings = []; - - public List Transform( Expression expression ) + public class GotoTransformerVisitor : ExpressionVisitor { - // Initialize the first state (n0) - _currentState = new StateNode( _labelCounter++ ); - _states.Add( _currentState ); + private readonly List _states = []; + private readonly Stack _continueToIndexes = new(); + private readonly Dictionary _labelMappings = new(); - Visit( expression ); + private int _continuationCounter; + private int _labelCounter; - return _states; - } + private int _currentStateIndex; + private StateNode CurrentState => _states[_currentStateIndex]; - protected override Expression VisitBlock( BlockExpression node ) - { - foreach ( var expr in node.Expressions ) + public List Transform( Expression expression ) { - Visit( expr ); - } + _currentStateIndex = InsertState(); - return node; - } + Visit( expression ); - protected override Expression VisitConditional( ConditionalExpression node ) - { - // Always lift Condition to current state - Visit( node.Test ); + return _states; + } - // Push the final node to stack for later convergence - var hasFalse = node.IfFalse is not DefaultExpression; - var ifTrueNode = new StateNode( _labelCounter++ ); - var ifFalseNode = hasFalse ? new StateNode( _labelCounter++ ) : null; - var finalNode = new StateNode( _labelCounter++ ); + private int InsertState() + { + _states.Add( new StateNode( _labelCounter++ ) ); + return _states.Count - 1; + } - _currentState.IfTrue = ifTrueNode; - _currentState.IfFalse = ifFalseNode; + private int InsertState( Expression expression ) + { + var stateIndex = InsertState(); + _currentStateIndex = stateIndex; - _states.Add( finalNode ); - _finalNodes.Push( finalNode ); + Visit( expression ); // Visit may mutate _currentStateIndex + + return stateIndex; + } - // Process IfTrue branch - ProcessBranch( node.IfTrue, ifTrueNode, finalNode ); + private void PushContinueTo( int index ) => _continueToIndexes.Push( index ); - // Process IfFalse branch - if ( hasFalse ) - ProcessBranch( node.IfFalse, ifFalseNode, finalNode ); + private int PopContinueTo() => _continueToIndexes.Pop(); - // Pop the final node and set it as current state - _currentState = _finalNodes.Pop(); + protected override Expression VisitBlock( BlockExpression node ) + { + foreach ( var expr in node.Expressions ) + { + Visit( expr ); + } - return node; + return node; + } - } + protected override Expression VisitConditional( ConditionalExpression node ) + { + Visit( node.Test ); - protected override Expression VisitSwitch( SwitchExpression node ) - { - // Always lift SwitchValue to current state - Visit( node.SwitchValue ); + var currentStateIndex = _currentStateIndex; - var switchNode = _currentState; - switchNode.Cases = []; + var ifTrueIndex = InsertState( node.IfTrue ); + var ifFalseIndex = (node.IfFalse is not DefaultExpression) ? InsertState( node.IfFalse ) : -1; - // Create the final node where all cases will converge - var finalNode = new StateNode( _labelCounter++ ); - _states.Add( finalNode ); - _finalNodes.Push( finalNode ); + var continueToIndex = InsertState(); + PushContinueTo( continueToIndex ); - // Process each case - List cases = []; - foreach ( var switchCase in node.Cases ) - { - var caseNode = new StateNode( _labelCounter++ ); - switchNode.Cases.Add( caseNode ); + var conditionalTransition = new ConditionalTransition + { + IfTrue = _states[ifTrueIndex], + IfFalse = ifFalseIndex >= -1 ? _states[ifFalseIndex] : null, + ContinueTo = _states[continueToIndex] + }; - // Add case label to the state - cases.Add( SwitchCase( Goto( caseNode.Label ), switchCase.TestValues ) ); + _states[currentStateIndex].Transition = conditionalTransition; + _currentStateIndex = PopContinueTo(); - ProcessBranch( switchCase.Body, caseNode, finalNode ); + return node; } - // Handle default case if present - Expression defaultBody = null; - if ( node.DefaultBody != null ) + protected override Expression VisitSwitch( SwitchExpression node ) { - var defaultNode = new StateNode( _labelCounter++ ); + Visit( node.SwitchValue ); - // TODO: Can't use `ProcessBranch` because GoTos are add differently + var currentStateIndex = _currentStateIndex; + var switchTransition = new SwitchTransition(); - _states.Add( defaultNode ); - _currentState = defaultNode; + var continueToIndex = InsertState(); + PushContinueTo( continueToIndex ); - Visit( node.DefaultBody ); - - //_currentState.Expressions.Add( Goto( finalNode.Label ) ); - defaultBody = Goto( finalNode.Label ); - _currentState.Final = finalNode; - } + foreach ( var switchCase in node.Cases ) + { + var caseIndex = InsertState( switchCase.Body ); - var gotoSwitch = Switch( - node.SwitchValue, - defaultBody, - [.. cases] ); - switchNode.Expressions.Add( gotoSwitch ); + _states[caseIndex].Transition = new DefaultTransition + { + ContinueTo = _states[continueToIndex] + }; + switchTransition.CaseNodes.Add( _states[caseIndex] ); + } - // Pop the final node and set it as current state - _currentState = _finalNodes.Pop(); + if ( node.DefaultBody != null ) + { + var defaultIndex = InsertState( node.DefaultBody ); - return node; - } + _states[defaultIndex].Transition = new DefaultTransition + { + ContinueTo = _states[continueToIndex] + }; - protected override Expression VisitTry( TryExpression node ) - { - // Always lift body to current state - Visit( node.Body ); + switchTransition.DefaultNode = _states[defaultIndex]; + } - var tryNode = _currentState; - tryNode.Catches = []; + continueToIndex = PopContinueTo(); - var hasFinally = node.Finally != null; + switchTransition.ContinueTo = _states[continueToIndex]; - // TODO: fault block - //var hasFault = node.Fault != null; - //var faultNode = hasFault ? new StateNode( _labelCounter++ ) : null; - var finallyNode = hasFinally ? new StateNode( _labelCounter++ ) : null; - var finalNode = new StateNode( _labelCounter++ ); + _states[currentStateIndex].Transition = switchTransition; + _currentStateIndex = continueToIndex; - _states.Add( finalNode ); - _finalNodes.Push( finalNode ); + return node; + } - // Process each case - List catches = []; - foreach ( var catchBlock in node.Handlers ) + protected override Expression VisitTry( TryExpression node ) //BF awaits aren't allowed in try-catch-finally. Are we doing too much? { - var catchNode = new StateNode( _labelCounter++ ); - tryNode.Catches.Add( catchNode ); + var currentStateIndex = _currentStateIndex; - // TODO: catchBlock.Filter - // Add case label to the state - // TODO: verify node.Body.Type as the correct type - catches.Add( Catch( catchBlock.Test, Goto( catchNode.Label, node.Body.Type ) ) ); + var tryCatchTransition = new TryCatchTransition(); - ProcessBranch( catchBlock.Body, catchNode, finalNode ); - } + var continueToIndex = InsertState(); + PushContinueTo( continueToIndex ); - // Visit the finally-block, if it exists - Expression finallyBody = null; - if ( finallyNode != null ) - { - var defaultNode = new StateNode( _labelCounter++ ); + var tryIndex = InsertState( node.Body ); + tryCatchTransition.TryNode = _states[tryIndex]; + + foreach ( var catchBlock in node.Handlers ) + { + var catchIndex = InsertState( catchBlock.Body ); + tryCatchTransition.CatchNodes.Add( _states[catchIndex] ); + } - // TODO: Can't use `ProcessBranch` because GoTos are add differently + if ( node.Finally != null ) + { + var finallyIndex = InsertState( node.Finally ); + tryCatchTransition.FinallyNode = _states[finallyIndex]; + } - _states.Add( defaultNode ); - _currentState = defaultNode; + continueToIndex = PopContinueTo(); + + tryCatchTransition.ContinueTo = _states[continueToIndex]; - Visit( node.Finally ); + _states[currentStateIndex].Transition = tryCatchTransition; + _currentStateIndex = continueToIndex; - finallyBody = Goto( finalNode.Label ); - _currentState.Final = finalNode; + return node; } - // Visit the fault-block, if it exists - // Expression faultBody = null; - // if ( faultNode != null ) - // { - // } + protected override Expression VisitExtension( Expression node ) + { + if ( node is not AwaitExpression awaitExpression ) + { + CurrentState.Expressions.Add( node ); + return node; + } - // TODO replace? - var newTry = TryCatchFinally( - node.Body, - finallyBody, - [..catches] - ); - tryNode.Expressions.Add( newTry ); + var currentStateIndex = _currentStateIndex; - // Pop the final node and set it as current state - _currentState = _finalNodes.Pop(); + // awaiter-finally + var continueToIndex = InsertState(); + PushContinueTo( continueToIndex ); - return node; + // awaiter-continuation + var completionStateIndex = InsertState( awaitExpression.Target ); - } + var gotoTransition = new GotoTransition + { + TargetNode = _states[continueToIndex] + }; - protected override Expression VisitLoop( LoopExpression node ) - { - // var loopNode = _currentState; - // - // var breakNode = new StateNode( _labelCounter++ ); // { Label = node.BreakLabel }; - // var finalNode = new StateNode( _labelCounter++ ); - // _states.Add( finalNode ); - // _finalNodes.Push( finalNode ); - // - // var loopBodyNode = new StateNode( _labelCounter++ ); - // _states.Add( loopBodyNode ); - // _currentState = loopBodyNode; - // - // Visit( node.Body ); - // - // _currentState.Expressions.Add( Goto( finalNode.Label ) ); - // _currentState.Final = finalNode; - // - // loopNode.Continue = loopBodyNode; - // loopNode.Break = breakNode; - // breakNode.Final = finalNode; - // - // - // _currentState = _finalNodes.Pop(); - - return node; - } + _states[completionStateIndex].Transition = gotoTransition; + + // awaiter + var awaitTransition = new AwaitTransition + { + ContinuationId = _continuationCounter++, + CompletionNode = _states[completionStateIndex], + ContinueTo = _states[continueToIndex] + }; + + _states[currentStateIndex].Transition = awaitTransition; + + _currentStateIndex = PopContinueTo(); + + // build awaiter + /* + awaiter8 = GetRandom().GetAwaiter(); + if (!awaiter8.IsCompleted) + { + num = (<>1__state = 0); + <>u__1 = awaiter8; +
d__0 stateMachine = this; + <>t__builder.AwaitUnsafeOnCompleted(ref awaiter8, ref stateMachine); + return; + } + goto IL_00fe; + */ + + // build awaiter continuation + /* + awaiter8 = <>u__1; + <>u__1 = default(TaskAwaiter); + num = (<>1__state = -1); + goto IL_00fe; + */ - protected override Expression VisitExtension( Expression node ) - { - if ( node is not AwaitExpression awaitExpression ) - { - _currentState.Expressions.Add( node ); return node; } - var stateId = _continuationCounter++; - var awaitNode = new StateNode( _labelCounter++ ) { ContinuationId = stateId }; - var finalNode = new StateNode( _labelCounter++ ) { ContinuationId = stateId }; - - _currentState.Await = awaitNode; - - _states.Add( finalNode ); - _finalNodes.Push( finalNode ); - - ProcessBranch( awaitExpression.Target, awaitNode, finalNode ); - - // build awaiter - /* - awaiter8 = GetRandom().GetAwaiter(); - if (!awaiter8.IsCompleted) - { - num = (<>1__state = 0); - <>u__1 = awaiter8; -
d__0 stateMachine = this; - <>t__builder.AwaitUnsafeOnCompleted(ref awaiter8, ref stateMachine); - return; - } - goto IL_00fe; - */ - - // build awaiter continue: - /* - awaiter8 = <>u__1; - <>u__1 = default(TaskAwaiter); - num = (<>1__state = -1); - goto IL_00fe; - */ - - // Pop the final node and set it as current state - _currentState = _finalNodes.Pop(); - - return node; - } - - protected override Expression VisitMethodCall( MethodCallExpression node ) - { - foreach ( var nodeArgument in node.Arguments ) + protected override Expression VisitMethodCall( MethodCallExpression node ) { - Visit( nodeArgument ); - } + foreach ( var nodeArgument in node.Arguments ) + { + Visit( nodeArgument ); + } - _currentState.Expressions.Add( node ); + CurrentState.Expressions.Add( node ); + return node; + } - return node; - } + protected override Expression VisitBinary( BinaryExpression node ) + { + CurrentState.Expressions.Add( node ); + return node; + } - protected override Expression VisitBinary( BinaryExpression node ) - { - _currentState.Expressions.Add( node ); - return node; - } + protected override Expression VisitParameter( ParameterExpression node ) + { + CurrentState.Expressions.Add( node ); + return node; + } - protected override Expression VisitParameter( ParameterExpression node ) - { - _currentState.Expressions.Add( node ); - return node; - } + protected override Expression VisitConstant( ConstantExpression node ) + { + CurrentState.Expressions.Add( node ); + return node; + } - protected override Expression VisitConstant( ConstantExpression node ) - { - _currentState.Expressions.Add( node ); - return node; - } + protected override Expression VisitUnary( UnaryExpression node ) + { + if ( node.NodeType != ExpressionType.Throw ) + { + return base.VisitUnary( node ); + } - protected override Expression VisitGoto( GotoExpression node ) - { - // Handle goto if necessary - _currentState.Expressions.Add( node ); + CurrentState.Expressions.Add( node ); + return node; + } - var gotoNode = new StateNode( _labelCounter++ ); - _states.Add( gotoNode ); - _currentState.Goto = gotoNode; - gotoNode.Final = CreateLabelBlock( node.Target ); + protected override Expression VisitGoto( GotoExpression node ) + { + var currentStateIndex = _currentStateIndex; + + var gotoTransition = new GotoTransition + { + TargetNode = _states[GetOrCreateLabelIndex( node.Target )] + }; - return node; - } + _states[currentStateIndex].Transition = gotoTransition; + return node; + } - protected override Expression VisitUnary( UnaryExpression node ) - { - if(node.NodeType == ExpressionType.Throw) + protected override Expression VisitLabel( LabelExpression node ) { - _currentState.Expressions.Add( node ); + var labelIndex = GetOrCreateLabelIndex( node.Target ); + + _states[labelIndex].Transition ??= new LabelTransition(); + return node; } - return base.VisitUnary( node ); - } + private int GetOrCreateLabelIndex( LabelTarget label ) + { + if ( _labelMappings.TryGetValue( label, out var index ) ) + { + return index; + } - protected override Expression VisitLabel( LabelExpression node ) - { - // Create a label state block and map it to the label target - CreateLabelBlock( node.Target ); - return node; - } + index = InsertState(); + _labelMappings[label] = index; - private StateNode CreateLabelBlock( LabelTarget label ) - { - if ( _labelMappings.TryGetValue( label, out var id ) ) - { - return _states.First( x => x.BlockId == id ); + return index; } - var block = new StateNode( _labelCounter++ ); - _labelMappings[label] = block.BlockId; - _states.Add( block ); - return block; - } + public void PrintStateMachine() + { + PrintStateMachine( _states ); + } - private void ProcessBranch( Expression expression, StateNode stateNode, StateNode final ) - { - _states.Add( stateNode ); - _currentState = stateNode; + public static void PrintStateMachine( List states ) + { + foreach ( var state in states ) + { + if ( state == null ) + continue; - Visit( expression ); + var transitionName = state?.Transition?.GetType().Name ?? "Null"; - // TODO: This Add doesn't work for everyone - _currentState.Expressions.Add( Goto( final.Label ) ); - _currentState.Final = final; - } + Console.WriteLine( $"{state.Label.Name}: [{transitionName}]" ); - public void PrintStateMachine() - { - PrintStateMachine( _states ); - } + foreach ( var expr in state.Expressions ) + { + Console.WriteLine( $"\t{ExpressionToString( expr )}" ); + } - public static void PrintStateMachine( List states ) - { - foreach ( var state in states ) - { - Console.WriteLine( $"{state.Label}: {(state.ContinuationId != null ? $" (state: {state.ContinuationId})" : string.Empty)}" ); - foreach ( var expr in state.Expressions ) - { - Console.WriteLine( $"\t{ExpressionToString( expr )}" ); - } + var transition = state.Transition; - if ( state.Cases != null ) - { - foreach ( var caseNode in state.Cases ) + if ( transition != null ) { - Console.WriteLine( $"\tCase -> {caseNode.Label}" ); + switch ( transition ) + { + case ConditionalTransition condNode: + Console.WriteLine( $"\tIfTrue -> {condNode.IfTrue?.BlockId}" ); + Console.WriteLine( $"\tIfFalse -> {condNode.IfFalse?.BlockId}" ); + break; + case SwitchTransition switchNode: + foreach ( var caseNode in switchNode.CaseNodes ) + { + Console.WriteLine( $"\tCase -> {caseNode?.BlockId}" ); + } + + Console.WriteLine( $"\tDefault -> {switchNode.DefaultNode?.BlockId}" ); + break; + case TryCatchTransition tryNode: + Console.WriteLine( $"\tTry -> {tryNode.TryNode?.BlockId}" ); + foreach ( var catchNode in tryNode.CatchNodes ) + { + Console.WriteLine( $"\tCatch -> {catchNode?.BlockId}" ); + } + + Console.WriteLine( $"\tFinally -> {tryNode.FinallyNode?.BlockId}" ); + break; + case AwaitTransition awaitNode: + Console.WriteLine( $"\tAwait -> {awaitNode.CompletionNode?.BlockId}" ); + break; + case GotoTransition gotoNode: + Console.WriteLine( $"\tGoto -> {gotoNode.TargetNode?.BlockId}" ); + break; + + case LabelTransition: + case DefaultTransition: + break; + } } - } - if ( state.Await != null ) - Console.WriteLine( $"\tAwait -> {state.Await.Label}" ); - if ( state.IfTrue != null ) - Console.WriteLine( $"\tIfTrue -> {state.IfTrue.Label}" ); - if ( state.IfFalse != null ) - Console.WriteLine( $"\tIfFalse -> {state.IfFalse.Label}" ); - if ( state.Final != null ) - Console.WriteLine( $"\tFinal -> {state.Final.Label}" ); - if( state.Goto != null ) - Console.WriteLine( $"\tGoto -> {state.Goto.Label}" ); - if ( state.IsTerminal ) - Console.WriteLine( "\tTerminal" ); - Console.WriteLine(); - } - return; + if ( state.Transition?.ContinueTo != null ) + Console.WriteLine( $"\tContinueTo -> {state.Transition.ContinueTo.BlockId}" ); + if ( state.Transition == null ) + Console.WriteLine( "\tTerminal" ); - static string GetBinaryOperator( ExpressionType nodeType ) - { - return nodeType switch - { - ExpressionType.Assign => "=", - ExpressionType.GreaterThan => ">", - ExpressionType.LessThan => "<", - ExpressionType.Add => "+", - ExpressionType.Subtract => "-", - ExpressionType.Multiply => "*", - ExpressionType.Divide => "/", - _ => nodeType.ToString() - }; + Console.WriteLine(); + } } - static string ExpressionToString( Expression expr ) + private static string ExpressionToString( Expression expr ) { - switch ( expr ) - { - case MethodCallExpression m: - var args = string.Join( ", ", m.Arguments.Select( ExpressionToString ) ); - return $"{m.Method.Name}({args})"; - case BinaryExpression b: - return $"{ExpressionToString( b.Left )} {GetBinaryOperator( b.NodeType )} {ExpressionToString( b.Right )}"; - case ParameterExpression p: - return p.Name; - case ConstantExpression c: - return c.Value?.ToString() ?? "empty"; - case GotoExpression g: - return $"goto {g.Target.Name}"; - case UnaryExpression u: - return $"{u.NodeType} {ExpressionToString( u.Operand )}"; - default: - return expr.ToString(); - } + return expr.ToString(); } } } diff --git a/src/Hyperbee.AsyncExpressions/GotoTransformerVisitor0.cs b/src/Hyperbee.AsyncExpressions/GotoTransformerVisitor0.cs new file mode 100644 index 0000000..0b00b96 --- /dev/null +++ b/src/Hyperbee.AsyncExpressions/GotoTransformerVisitor0.cs @@ -0,0 +1,483 @@ +using System.Linq.Expressions; +using static System.Linq.Expressions.Expression; + +namespace Hyperbee.AsyncExpressions; + +public class GotoTransformerVisitor0 : ExpressionVisitor +{ + private readonly List _states = []; + private int _continuationCounter; + private int _labelCounter; + private StateNode0 _currentState; + private readonly Stack _finalNodes = []; + private readonly Dictionary _labelMappings = []; + + public List Transform( Expression expression ) + { + // Initialize the first state (n0) + _currentState = new StateNode0( _labelCounter++ ); + _states.Add( _currentState ); + + Visit( expression ); + + return _states; + } + + protected override Expression VisitBlock( BlockExpression node ) + { + foreach ( var expr in node.Expressions ) + { + Visit( expr ); + } + + return node; + } + + protected override Expression VisitConditional( ConditionalExpression node ) + { + // Always lift Condition to current state + Visit( node.Test ); + + // Push the final node to stack for later convergence + var hasFalse = node.IfFalse is not DefaultExpression; + var ifTrueNode = new StateNode0( _labelCounter++ ); + var ifFalseNode = hasFalse ? new StateNode0( _labelCounter++ ) : null; + var finalNode = new StateNode0( _labelCounter++ ); + + _currentState.IfTrue = ifTrueNode; + _currentState.IfFalse = ifFalseNode; + + _states.Add( finalNode ); + _finalNodes.Push( finalNode ); + + // Process IfTrue branch + ProcessBranch( node.IfTrue, ifTrueNode, finalNode ); + + // Process IfFalse branch + if ( hasFalse ) + ProcessBranch( node.IfFalse, ifFalseNode, finalNode ); + + // Pop the final node and set it as current state + _currentState = _finalNodes.Pop(); + + return node; + + } + + protected override Expression VisitSwitch( SwitchExpression node ) + { + // Always lift SwitchValue to current state + Visit( node.SwitchValue ); + + var switchNode = _currentState; + switchNode.Cases = []; + + // Create the final node where all cases will converge + var finalNode = new StateNode0( _labelCounter++ ); + _states.Add( finalNode ); + _finalNodes.Push( finalNode ); + + // Process each case + List cases = []; + foreach ( var switchCase in node.Cases ) + { + var caseNode = new StateNode0( _labelCounter++ ); + switchNode.Cases.Add( caseNode ); + + // Add case label to the state + cases.Add( SwitchCase( Goto( caseNode.Label ), switchCase.TestValues ) ); + + ProcessBranch( switchCase.Body, caseNode, finalNode ); + } + + // Handle default case if present + Expression defaultBody = null; + if ( node.DefaultBody != null ) + { + var defaultNode = new StateNode0( _labelCounter++ ); + + // TODO: Can't use `ProcessBranch` because GoTos are add differently + + _states.Add( defaultNode ); + _currentState = defaultNode; + + Visit( node.DefaultBody ); + + //_currentState.Expressions.Add( Goto( finalNode.Label ) ); + defaultBody = Goto( finalNode.Label ); + _currentState.Final = finalNode; + } + + var gotoSwitch = Switch( + node.SwitchValue, + defaultBody, + [.. cases] ); + switchNode.Expressions.Add( gotoSwitch ); + + + // Pop the final node and set it as current state + _currentState = _finalNodes.Pop(); + + return node; + } + + protected override Expression VisitTry( TryExpression node ) + { + // Always lift body to current state + Visit( node.Body ); + + var tryNode = _currentState; + tryNode.Catches = []; + + var hasFinally = node.Finally != null; + + // TODO: fault block + //var hasFault = node.Fault != null; + //var faultNode = hasFault ? new StateNode( _labelCounter++ ) : null; + var finallyNode = hasFinally ? new StateNode0( _labelCounter++ ) : null; + var finalNode = new StateNode0( _labelCounter++ ); + + _states.Add( finalNode ); + _finalNodes.Push( finalNode ); + + // Process each case + List catches = []; + foreach ( var catchBlock in node.Handlers ) + { + var catchNode = new StateNode0( _labelCounter++ ); + tryNode.Catches.Add( catchNode ); + + // TODO: catchBlock.Filter + // Add case label to the state + // TODO: verify node.Body.Type as the correct type + catches.Add( Catch( catchBlock.Test, Goto( catchNode.Label, node.Body.Type ) ) ); + + ProcessBranch( catchBlock.Body, catchNode, finalNode ); + } + + // Visit the finally-block, if it exists + Expression finallyBody = null; + if ( finallyNode != null ) + { + var defaultNode = new StateNode0( _labelCounter++ ); + + // TODO: Can't use `ProcessBranch` because GoTos are add differently + + _states.Add( defaultNode ); + _currentState = defaultNode; + + Visit( node.Finally ); + + finallyBody = Goto( finalNode.Label ); + _currentState.Final = finalNode; + } + + // Visit the fault-block, if it exists + // Expression faultBody = null; + // if ( faultNode != null ) + // { + // } + + // TODO replace? + var newTry = TryCatchFinally( + node.Body, + finallyBody, + [..catches] + ); + tryNode.Expressions.Add( newTry ); + + // Pop the final node and set it as current state + _currentState = _finalNodes.Pop(); + + return node; + + } + + protected override Expression VisitLoop( LoopExpression node ) + { + // var loopNode = _currentState; + // + // var breakNode = new StateNode( _labelCounter++ ); // { Label = node.BreakLabel }; + // var finalNode = new StateNode( _labelCounter++ ); + // _states.Add( finalNode ); + // _finalNodes.Push( finalNode ); + // + // var loopBodyNode = new StateNode( _labelCounter++ ); + // _states.Add( loopBodyNode ); + // _currentState = loopBodyNode; + // + // Visit( node.Body ); + // + // _currentState.Expressions.Add( Goto( finalNode.Label ) ); + // _currentState.Final = finalNode; + // + // loopNode.Continue = loopBodyNode; + // loopNode.Break = breakNode; + // breakNode.Final = finalNode; + // + // + // _currentState = _finalNodes.Pop(); + + return node; + } + + protected override Expression VisitExtension( Expression node ) + { + if ( node is not AwaitExpression awaitExpression ) + { + _currentState.Expressions.Add( node ); + return node; + } + + var stateId = _continuationCounter++; + var awaitNode = new StateNode0( _labelCounter++ ) { ContinuationId = stateId }; + var finalNode = new StateNode0( _labelCounter++ ) { ContinuationId = stateId }; + + _currentState.Await = awaitNode; + + _states.Add( finalNode ); + _finalNodes.Push( finalNode ); + + ProcessBranch( awaitExpression.Target, awaitNode, finalNode ); + + // build awaiter + /* + awaiter8 = GetRandom().GetAwaiter(); + if (!awaiter8.IsCompleted) + { + num = (<>1__state = 0); + <>u__1 = awaiter8; +
d__0 stateMachine = this; + <>t__builder.AwaitUnsafeOnCompleted(ref awaiter8, ref stateMachine); + return; + } + goto IL_00fe; + */ + + // build awaiter continue: + /* + awaiter8 = <>u__1; + <>u__1 = default(TaskAwaiter); + num = (<>1__state = -1); + goto IL_00fe; + */ + + // Pop the final node and set it as current state + _currentState = _finalNodes.Pop(); + + return node; + } + + protected override Expression VisitMethodCall( MethodCallExpression node ) + { + foreach ( var nodeArgument in node.Arguments ) + { + Visit( nodeArgument ); + } + + _currentState.Expressions.Add( node ); + + return node; + } + + protected override Expression VisitBinary( BinaryExpression node ) + { + _currentState.Expressions.Add( node ); + return node; + } + + protected override Expression VisitParameter( ParameterExpression node ) + { + _currentState.Expressions.Add( node ); + return node; + } + + protected override Expression VisitConstant( ConstantExpression node ) + { + _currentState.Expressions.Add( node ); + return node; + } + + protected override Expression VisitGoto( GotoExpression node ) + { + // Handle goto if necessary + _currentState.Expressions.Add( node ); + + var gotoNode = new StateNode0( _labelCounter++ ); + _states.Add( gotoNode ); + _currentState.Goto = gotoNode; + gotoNode.Final = CreateLabelBlock( node.Target ); + + return node; + } + + protected override Expression VisitUnary( UnaryExpression node ) + { + if(node.NodeType == ExpressionType.Throw) + { + _currentState.Expressions.Add( node ); + return node; + } + + return base.VisitUnary( node ); + } + + protected override Expression VisitLabel( LabelExpression node ) + { + // Create a label state block and map it to the label target + CreateLabelBlock( node.Target ); + return node; + } + + private StateNode0 CreateLabelBlock( LabelTarget label ) + { + if ( _labelMappings.TryGetValue( label, out var id ) ) + { + return _states.First( x => x.BlockId == id ); + } + + var block = new StateNode0( _labelCounter++ ); + _labelMappings[label] = block.BlockId; + _states.Add( block ); + return block; + } + + private void ProcessBranch( Expression expression, StateNode0 stateNode, StateNode0 final ) + { + _states.Add( stateNode ); + _currentState = stateNode; + + Visit( expression ); + + // TODO: This Add doesn't work for everyone + _currentState.Expressions.Add( Goto( final.Label ) ); + _currentState.Final = final; + } + + public void PrintStateMachine() + { + PrintStateMachine( _states ); + } + + public static void PrintStateMachine( List states ) + { + foreach ( var state in states ) + { + Console.WriteLine( $"{state.Label}: {(state.ContinuationId != null ? $" (state: {state.ContinuationId})" : string.Empty)}" ); + foreach ( var expr in state.Expressions ) + { + Console.WriteLine( $"\t{ExpressionToString( expr )}" ); + } + + if ( state.Cases != null ) + { + foreach ( var caseNode in state.Cases ) + { + Console.WriteLine( $"\tCase -> {caseNode.Label}" ); + } + } + if ( state.Await != null ) + Console.WriteLine( $"\tAwait -> {state.Await.Label}" ); + if ( state.IfTrue != null ) + Console.WriteLine( $"\tIfTrue -> {state.IfTrue.Label}" ); + if ( state.IfFalse != null ) + Console.WriteLine( $"\tIfFalse -> {state.IfFalse.Label}" ); + if ( state.Final != null ) + Console.WriteLine( $"\tFinal -> {state.Final.Label}" ); + if( state.Goto != null ) + Console.WriteLine( $"\tGoto -> {state.Goto.Label}" ); + if ( state.IsTerminal ) + Console.WriteLine( "\tTerminal" ); + Console.WriteLine(); + } + + return; + + + static string GetBinaryOperator( ExpressionType nodeType ) + { + return nodeType switch + { + ExpressionType.Assign => "=", + ExpressionType.GreaterThan => ">", + ExpressionType.LessThan => "<", + ExpressionType.Add => "+", + ExpressionType.Subtract => "-", + ExpressionType.Multiply => "*", + ExpressionType.Divide => "/", + _ => nodeType.ToString() + }; + } + + static string ExpressionToString( Expression expr ) + { + switch ( expr ) + { + case MethodCallExpression m: + var args = string.Join( ", ", m.Arguments.Select( ExpressionToString ) ); + return $"{m.Method.Name}({args})"; + case BinaryExpression b: + return $"{ExpressionToString( b.Left )} {GetBinaryOperator( b.NodeType )} {ExpressionToString( b.Right )}"; + case ParameterExpression p: + return p.Name; + case ConstantExpression c: + return c.Value?.ToString() ?? "empty"; + case GotoExpression g: + return $"goto {g.Target.Name}"; + case UnaryExpression u: + return $"{u.NodeType} {ExpressionToString( u.Operand )}"; + default: + return expr.ToString(); + } + } + } +} + +public class StateNode0 +{ + public int BlockId { get; } + public LabelTarget Label { get; set; } + public List Expressions { get; } = []; + public StateNode0 Final { get; set; } + + // Condition-specific fields + public StateNode0 IfTrue { get; set; } + public StateNode0 IfFalse { get; set; } + + // Switch-specific fields + public List Cases { get; set; } + + // For Async/Await fields + public int? ContinuationId { get; set; } + public StateNode0 Await { get; set; } + + // Goto-specific fields + public StateNode0 Continue { get; set; } + public StateNode0 Break { get; set; } + public StateNode0 Goto { get; set; } + + // TryCatch-specific fields + public StateNode0 Try { get; set; } + public List Catches { get; set; } + public StateNode0 Finally { get; set; } + public StateNode0 Fault { get; set; } + + public bool IsTerminal + { + get + { + return Final == null && + IfTrue == null && + IfFalse == null && + Cases == null && + Catches == null && + Await == null; + } + } + + public StateNode0( int blockId ) + { + BlockId = blockId; + Label = Expression.Label( $"block_{BlockId}" ); + } +} diff --git a/src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs b/src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs index 19ab7a7..6a67423 100644 --- a/src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs +++ b/src/Hyperbee.AsyncExpressions/ParameterMappingVisitor.cs @@ -59,7 +59,7 @@ protected override Expression VisitExtension( Expression node ) Visit( awaitableBlock.After ); return node; case AwaitExpression awaitExpression: - return Visit( awaitExpression.Target ); + return Visit( awaitExpression.Target )!; default: return base.VisitExtension( node ); } diff --git a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs index c7a0eb4..5b9a5f9 100644 --- a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs +++ b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs @@ -597,6 +597,7 @@ public static Expression Create( BlockExpression source, bool createRun var stateMachineBuilder = new StateMachineBuilder( moduleBuilder, "DynamicStateMachine" ); stateMachineBuilder.SetExpressionSource( source ); + var stateMachineExpression = stateMachineBuilder.CreateStateMachine( createRunner ); return stateMachineExpression; diff --git a/src/Hyperbee.AsyncExpressions/StateNode.cs b/src/Hyperbee.AsyncExpressions/StateNode.cs index 1347fc9..96d6c4f 100644 --- a/src/Hyperbee.AsyncExpressions/StateNode.cs +++ b/src/Hyperbee.AsyncExpressions/StateNode.cs @@ -7,42 +7,7 @@ public class StateNode public int BlockId { get; } public LabelTarget Label { get; set; } public List Expressions { get; } = []; - public StateNode Final { get; set; } - - // Condition-specific fields - public StateNode IfTrue { get; set; } - public StateNode IfFalse { get; set; } - - // Switch-specific fields - public List Cases { get; set; } - - // For Async/Await fields - public int? ContinuationId { get; set; } - public StateNode Await { get; set; } - - // Goto-specific fields - public StateNode Continue { get; set; } - public StateNode Break { get; set; } - public StateNode Goto { get; set; } - - // TryCatch-specific fields - public StateNode Try { get; set; } - public List Catches { get; set; } - public StateNode Finally { get; set; } - public StateNode Fault { get; set; } - - public bool IsTerminal - { - get - { - return Final == null && - IfTrue == null && - IfFalse == null && - Cases == null && - Catches == null && - Await == null; - } - } + public TransitionNode Transition { get; set; } public StateNode( int blockId ) { diff --git a/src/Hyperbee.AsyncExpressions/Transition.cs b/src/Hyperbee.AsyncExpressions/Transition.cs new file mode 100644 index 0000000..51dbc04 --- /dev/null +++ b/src/Hyperbee.AsyncExpressions/Transition.cs @@ -0,0 +1,95 @@ +namespace Hyperbee.AsyncExpressions; + +public enum TransitionType +{ + Default, // Represents a no-op or default transition + Conditional, // Conditional transitions (e.g., if-else) + Switch, // Switch case transitions + TryCatch, // Try-catch-finally transitions + Loop, // Loop transitions + Await, // Await transition + Goto, // Goto transition + Label, // Label transition +} + +public abstract class TransitionNode +{ + public TransitionType TransitionType { get; } + public StateNode ContinueTo { get; set; } + + protected TransitionNode( TransitionType transitionType ) + { + TransitionType = transitionType; + } +} + +public class DefaultTransition : TransitionNode +{ + public DefaultTransition() + : base( TransitionType.Default ) + { + } +} + +public class ConditionalTransition : TransitionNode +{ + public StateNode IfTrue { get; set; } + public StateNode IfFalse { get; set; } + + public ConditionalTransition() + : base( TransitionType.Conditional ) + { + } +} + +public class SwitchTransition : TransitionNode +{ + public List CaseNodes { get; set; } = []; + public StateNode DefaultNode { get; set; } + + public SwitchTransition() + : base(TransitionType.Switch ) + { + } +} + +public class TryCatchTransition : TransitionNode +{ + public StateNode TryNode { get; set; } + public List CatchNodes { get; set; } = []; + public StateNode FinallyNode { get; set; } + + public TryCatchTransition() + : base(TransitionType.TryCatch ) + { + } +} + +public class AwaitTransition : TransitionNode +{ + public int ContinuationId { get; set; } = -1; + public StateNode CompletionNode { get; set; } + + public AwaitTransition() + : base( TransitionType.Await ) + { + } +} + +public class GotoTransition : TransitionNode +{ + public StateNode TargetNode { get; set; } + + public GotoTransition() + : base( TransitionType.Goto ) + { + } +} + +public class LabelTransition : TransitionNode +{ + public LabelTransition() + : base( TransitionType.Label ) + { + } +} diff --git a/test/Hyperbee.AsyncExpressions.Tests/GotoTransformerVisitorTests.cs b/test/Hyperbee.AsyncExpressions.Tests/GotoTransformerVisitorTests.cs index ede163c..8398ebc 100644 --- a/test/Hyperbee.AsyncExpressions.Tests/GotoTransformerVisitorTests.cs +++ b/test/Hyperbee.AsyncExpressions.Tests/GotoTransformerVisitorTests.cs @@ -51,7 +51,7 @@ public void GotoTransformer_WithIfThen() ); // Act - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( ifThenElseExpr ); // Assert @@ -76,7 +76,7 @@ public void GotoTransformer_WithSwitch() ); // Act - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( switchBlock ); // Assert @@ -101,7 +101,7 @@ public void GotoTransformer_WithSwitchAwaits() ); // Act - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( switchBlock ); // Assert @@ -124,7 +124,7 @@ public void GotoTransformer_WithMethodAwaits() ); // Act - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( callExpr ); // Assert @@ -145,7 +145,7 @@ public void GotoTransformer_WithMethodAwaitArguments() ); // Act - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( callExpr ); // Assert @@ -168,7 +168,7 @@ public void GotoTransformer_WithGoto() ); // Act - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( gotoExpr ); // Assert @@ -197,7 +197,7 @@ public void GotoTransformer_WithLoop() ); // Act - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( whileBlockExpr ); // Assert @@ -223,7 +223,7 @@ public void GotoTransformer_WithTryCatch() ); // Act - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( tryCatchExpr ); // Assert @@ -257,7 +257,7 @@ public void GotoTransformer_WithComplexConditions() Constant( 5 ) ); - var transformer = new GotoTransformerVisitor(); + var transformer = new GotoTransformerVisitor0(); transformer.Transform( ifThenElseExpr ); transformer.PrintStateMachine();