From 550df035441b208bb89cdb6045edc069fff0d596 Mon Sep 17 00:00:00 2001 From: Brenton Farmer Date: Tue, 10 Sep 2024 17:42:15 -0700 Subject: [PATCH] add special naming --- .../StateMachineBuilder.cs | 91 ++++++++++--------- .../{UnitTests.cs => AsyncMethodTests.cs} | 14 +-- 2 files changed, 56 insertions(+), 49 deletions(-) rename test/Hyperbee.AsyncExpressions.Tests/{UnitTests.cs => AsyncMethodTests.cs} (98%) diff --git a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs index be2e311..7d83b5f 100644 --- a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs +++ b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs @@ -36,6 +36,18 @@ public class StateMachineBuilder private List _variableFields; private List _awaiterFields; + private static class FieldName + { + // use special names to prevent collisions with user fields + public const string Builder = "__builder<>"; + public const string FinalResult = "__finalResult<>"; + public const string MoveNextLambda = "__moveNextLambda<>"; + public const string State = "__state<>"; + private const string AwaiterTemplate = "__awaiter<{0}>"; + + public static string Awaiter( int i ) => string.Format( AwaiterTemplate, i ); + } + public StateMachineBuilder( ModuleBuilder moduleBuilder, string typeName ) { _moduleBuilder = moduleBuilder; @@ -61,7 +73,7 @@ public Expression CreateStateMachine( bool createRunner = true ) var moveNextLambda = CreateMoveNextExpression( _blockSource ); var stateMachineVariable = Expression.Variable( _stateMachineType, "stateMachine" ); - var builderFieldInfo = _stateMachineType.GetField( "_builder" )!; + var builderFieldInfo = _stateMachineType.GetField( FieldName.Builder )!; var setLambdaMethod = _stateMachineType.GetMethod( "SetMoveNext" )!; var constructor = _stateMachineType.GetConstructor( Type.EmptyTypes )!; @@ -92,7 +104,7 @@ public Expression CreateRunStateMachine( Expression stateMachineExpression ) var stateMachineVariable = Expression.Variable( _stateMachineType, "stateMachine" ); - var builderFieldInfo = _stateMachineType.GetField( "_builder" )!; + var builderFieldInfo = _stateMachineType.GetField( FieldName.Builder )!; var taskFieldInfo = builderFieldInfo.FieldType.GetProperty( "Task" )!; var builderField = Expression.Field( stateMachineVariable, builderFieldInfo ); @@ -122,39 +134,34 @@ private void CreateStateMachineType( BlockExpression block ) // // public class StateMachineType : IAsyncStateMachine // { - // public int _state; - // public AsyncTaskMethodBuilder _builder; - // public TResult _finalResult; - // public Action _moveNextLambda; + // public int __state<>; + // public AsyncTaskMethodBuilder __builder<>; + // public TResult __finalResult<>; + // public Action __moveNextLambda<>; // // // Variables (example) // public int _variable1; // public int _variable2; // // // Awaiters (example) - // public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter _awaiter1; - // public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter _awaiter2; + // public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter __awaiter<1>; + // public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter __awaiter<2>; // // public StateMachineType() // { // } // - // public void SetLambda(Action moveNextLambda) - // { - // Action moveNext = obj => moveNextLambda( (StateMachineType) obj ); - // moveNext(this); - // } - // - // public void MoveNext() => _moveNextLambda(this); - // public void SetStateMachine(IAsyncStateMachine stateMachine) => _builder.SetStateMachine( stateMachine ); + // public void MoveNext() => __moveNextLambda<>(this); + // public void SetLambda(Action moveNextLambda) => __moveNextLambda<> = obj => moveNextLambda( (StateMachineType) obj ); + // public void SetStateMachine(IAsyncStateMachine stateMachine) => __builder<>.SetStateMachine( stateMachine ); // } _typeBuilder = _moduleBuilder.DefineType( _typeName, TypeAttributes.Public, typeof( object ), [typeof( IAsyncStateMachine )] ); - _typeBuilder.DefineField( "_state", typeof( int ), FieldAttributes.Public ); - _builderField = _typeBuilder.DefineField( "_builder", typeof( AsyncTaskMethodBuilder<> ).MakeGenericType( typeof( TResult ) ), FieldAttributes.Public ); - _finalResultField = _typeBuilder.DefineField( "_finalResult", typeof( TResult ), FieldAttributes.Public ); - _moveNextLambdaField = _typeBuilder.DefineField( "_moveNextLambda", typeof( Action<> ).MakeGenericType( _typeBuilder ), FieldAttributes.Private ); + _typeBuilder.DefineField( FieldName.State, typeof( int ), FieldAttributes.Public ); + _builderField = _typeBuilder.DefineField( FieldName.Builder, typeof( AsyncTaskMethodBuilder<> ).MakeGenericType( typeof( TResult ) ), FieldAttributes.Public ); + _finalResultField = _typeBuilder.DefineField( FieldName.FinalResult, typeof( TResult ), FieldAttributes.Public ); + _moveNextLambdaField = _typeBuilder.DefineField( FieldName.MoveNextLambda, typeof( Action<> ).MakeGenericType( _typeBuilder ), FieldAttributes.Private ); EmitBlockFields( block ); EmitConstructor(); @@ -196,9 +203,9 @@ private void EmitBlockFields( BlockExpression block ) if ( !TryMakeAwaiterType( expr, out Type awaiterType ) ) continue; // Not an awaitable expression - var fieldName = $"_awaiter_{i}"; // `i` should match the index of the expression to align with state machine logic + // `i` should match the index of the expression to align with state machine logic + var awaiterField = _typeBuilder.DefineField( FieldName.Awaiter( i ), awaiterType, FieldAttributes.Public ); - var awaiterField = _typeBuilder.DefineField( fieldName, awaiterType, FieldAttributes.Public ); _awaiterFields.Add( awaiterField ); } } @@ -299,39 +306,39 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) // { // try // { - // if (_state == 0) + // if (__state<> == 0) // { - // _awaiter1 = task1.ConfigureAwait(false).GetAwaiter(); - // _state = 1; + // __awaiter<1> = task1.ConfigureAwait(false).GetAwaiter(); + // __state<> = 1; // - // if (!_awaiter1.IsCompleted == false) + // if (!__awaiter<1>.IsCompleted == false) // { - // _builder.AwaitUnsafeOnCompleted(ref _awaiter1, this); + // __builder<>.AwaitUnsafeOnCompleted(ref __awaiter<1>, this); // return; // } // } // - // if (_state == 1) + // if (__state<> == 1) // { - // _awaiter1.GetResult(); - // _awaiter2 = task2.ConfigureAwait(false).GetAwaiter(); - // _state = 2; + // __awaiter<1>.GetResult(); + // __awaiter<2> = task2.ConfigureAwait(false).GetAwaiter(); + // __state<> = 2; // - // if (!_awaiter2.IsCompleted) + // if (!__awaiter<2>.IsCompleted) // { - // _builder.AwaitUnsafeOnCompleted(ref _awaiter2, this); + // __builder<>.AwaitUnsafeOnCompleted(ref __awaiter<2>, this); // return; // } // } // - // if (_state == 2) + // if (__state<> == 2) // { - // _builder.Task.SetResult( _awaiter2.GetResult() ); + // __builder<>.Task.SetResult( __awaiter<2>.GetResult() ); // } // } // catch (Exception ex) // { - // _builder.SetException(ex); + // __builder<>.SetException(ex); // } // } var stateMachineInstance = Expression.Parameter( _stateMachineType, "stateMachine" ); @@ -378,7 +385,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) ); // Increment state - var setStateBeforeAwait = Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) ); + var setStateBeforeAwait = Expression.Assign( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( i + 1 ) ); // Check completed var awaiterCompletedCheck = Expression.IfThen( @@ -396,7 +403,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) ); var stateCheck = Expression.IfThen( - Expression.Equal( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i ) ), + Expression.Equal( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( i ) ), Expression.Block( assignAwaiter, setStateBeforeAwait, awaiterCompletedCheck ) ); bodyExpressions.Add( stateCheck ); @@ -406,14 +413,14 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) if ( typeof(TResult) != typeof(IVoidTaskResult) ) { var assignFinalResult = Expression.Assign( Expression.Field( stateMachineInstance, finalResultFieldInfo ), blockExpr ); - var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) ); + var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( i + 1 ) ); bodyExpressions.Add( Expression.Block( assignFinalResult, incrementState ) ); handledFinalBlock = true; } else { // IVoidTaskResult (no result) - var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) ); + var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( i + 1 ) ); bodyExpressions.Add( incrementState ); } } @@ -425,7 +432,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) // Generate the final state var finalState = Expression.IfThen( - Expression.Equal( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( lastBlockIndex + 1 ) ), + Expression.Equal( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( lastBlockIndex + 1 ) ), Expression.Block( // Handle the final result for Task and Task !handledFinalBlock && typeof( TResult) != typeof(IVoidTaskResult) @@ -452,7 +459,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) // Mark as completed after the final state logic is executed var markCompletedState = Expression.Assign( - Expression.Field( stateMachineInstance, "_state" ), + Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( -2 ) // Mark the state machine as completed ); diff --git a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs b/test/Hyperbee.AsyncExpressions.Tests/AsyncMethodTests.cs similarity index 98% rename from test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs rename to test/Hyperbee.AsyncExpressions.Tests/AsyncMethodTests.cs index 0d8fd5b..6c95ad5 100644 --- a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs +++ b/test/Hyperbee.AsyncExpressions.Tests/AsyncMethodTests.cs @@ -10,7 +10,7 @@ public enum ExpressionKind } [TestClass] -public class AsyncExpressionUnitTests +public class AsyncMethodTests { private static async Task Delay() { @@ -19,25 +19,25 @@ private static async Task Delay() private static async Task GetNumberAsync() { - await Task.Delay( 100 ); + await Task.Delay( 10 ); return 42; } private static async Task AddTwoNumbersAsync( int a, int b ) { - await Task.Delay( 100 ); + await Task.Delay( 10 ); return a + b; } private static async Task AddThreeNumbersAsync( int a, int b, int c ) { - await Task.Delay( 100 ); + await Task.Delay( 10 ); return a + b + c; } private static async Task SayHelloAsync( int a ) { - await Task.Delay( 100 ); + await Task.Delay( 10 ); return $"Hello {a}"; } @@ -48,13 +48,13 @@ private static int IncrementValue( int a ) private static async Task ThrowExceptionAsync() { - await Task.Delay( 100 ); + await Task.Delay( 10 ); throw new InvalidOperationException( "Simulated exception." ); } private static MethodInfo GetMethodInfo( string name ) { - return typeof( AsyncExpressionUnitTests ).GetMethod( name, BindingFlags.Static | BindingFlags.NonPublic )!; + return typeof( AsyncMethodTests ).GetMethod( name, BindingFlags.Static | BindingFlags.NonPublic )!; } private static AsyncBaseExpression GetAsyncExpression( ExpressionKind kind, MethodInfo methodInfo, params Expression[] arguments )