Skip to content

Commit

Permalink
add special naming
Browse files Browse the repository at this point in the history
  • Loading branch information
bfarmer67 committed Sep 11, 2024
1 parent b8270ab commit 550df03
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 49 deletions.
91 changes: 49 additions & 42 deletions src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ public class StateMachineBuilder<TResult>
private List<FieldBuilder> _variableFields;
private List<FieldBuilder> _awaiterFields;

private static class FieldName
{
// use special names to prevent collisions with user fields
public const string Builder = "__builder<>";
public const string FinalResult = "__finalResult<>";
public const string MoveNextLambda = "__moveNextLambda<>";
public const string State = "__state<>";
private const string AwaiterTemplate = "__awaiter<{0}>";

public static string Awaiter( int i ) => string.Format( AwaiterTemplate, i );
}

public StateMachineBuilder( ModuleBuilder moduleBuilder, string typeName )
{
_moduleBuilder = moduleBuilder;
Expand All @@ -61,7 +73,7 @@ public Expression CreateStateMachine( bool createRunner = true )
var moveNextLambda = CreateMoveNextExpression( _blockSource );

var stateMachineVariable = Expression.Variable( _stateMachineType, "stateMachine" );
var builderFieldInfo = _stateMachineType.GetField( "_builder" )!;
var builderFieldInfo = _stateMachineType.GetField( FieldName.Builder )!;
var setLambdaMethod = _stateMachineType.GetMethod( "SetMoveNext" )!;

var constructor = _stateMachineType.GetConstructor( Type.EmptyTypes )!;
Expand Down Expand Up @@ -92,7 +104,7 @@ public Expression CreateRunStateMachine( Expression stateMachineExpression )

var stateMachineVariable = Expression.Variable( _stateMachineType, "stateMachine" );

var builderFieldInfo = _stateMachineType.GetField( "_builder" )!;
var builderFieldInfo = _stateMachineType.GetField( FieldName.Builder )!;
var taskFieldInfo = builderFieldInfo.FieldType.GetProperty( "Task" )!;

var builderField = Expression.Field( stateMachineVariable, builderFieldInfo );
Expand Down Expand Up @@ -122,39 +134,34 @@ private void CreateStateMachineType( BlockExpression block )
//
// public class StateMachineType : IAsyncStateMachine
// {
// public int _state;
// public AsyncTaskMethodBuilder<TResult> _builder;
// public TResult _finalResult;
// public Action _moveNextLambda;
// public int __state<>;
// public AsyncTaskMethodBuilder<TResult> __builder<>;
// public TResult __finalResult<>;
// public Action __moveNextLambda<>;
//
// // Variables (example)
// public int _variable1;
// public int _variable2;
//
// // Awaiters (example)
// public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter _awaiter1;
// public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter _awaiter2;
// public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter __awaiter<1>;
// public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter __awaiter<2>;
//
// public StateMachineType()
// {
// }
//
// public void SetLambda<T>(Action<T> moveNextLambda)
// {
// Action<object> moveNext = obj => moveNextLambda( (StateMachineType) obj );
// moveNext(this);
// }
//
// public void MoveNext() => _moveNextLambda(this);
// public void SetStateMachine(IAsyncStateMachine stateMachine) => _builder.SetStateMachine( stateMachine );
// public void MoveNext() => __moveNextLambda<>(this);
// public void SetLambda<T>(Action<T> moveNextLambda) => __moveNextLambda<> = obj => moveNextLambda( (StateMachineType) obj );
// public void SetStateMachine(IAsyncStateMachine stateMachine) => __builder<>.SetStateMachine( stateMachine );
// }

_typeBuilder = _moduleBuilder.DefineType( _typeName, TypeAttributes.Public, typeof( object ), [typeof( IAsyncStateMachine )] );

_typeBuilder.DefineField( "_state", typeof( int ), FieldAttributes.Public );
_builderField = _typeBuilder.DefineField( "_builder", typeof( AsyncTaskMethodBuilder<> ).MakeGenericType( typeof( TResult ) ), FieldAttributes.Public );
_finalResultField = _typeBuilder.DefineField( "_finalResult", typeof( TResult ), FieldAttributes.Public );
_moveNextLambdaField = _typeBuilder.DefineField( "_moveNextLambda", typeof( Action<> ).MakeGenericType( _typeBuilder ), FieldAttributes.Private );
_typeBuilder.DefineField( FieldName.State, typeof( int ), FieldAttributes.Public );
_builderField = _typeBuilder.DefineField( FieldName.Builder, typeof( AsyncTaskMethodBuilder<> ).MakeGenericType( typeof( TResult ) ), FieldAttributes.Public );
_finalResultField = _typeBuilder.DefineField( FieldName.FinalResult, typeof( TResult ), FieldAttributes.Public );
_moveNextLambdaField = _typeBuilder.DefineField( FieldName.MoveNextLambda, typeof( Action<> ).MakeGenericType( _typeBuilder ), FieldAttributes.Private );

EmitBlockFields( block );
EmitConstructor();
Expand Down Expand Up @@ -196,9 +203,9 @@ private void EmitBlockFields( BlockExpression block )
if ( !TryMakeAwaiterType( expr, out Type awaiterType ) )
continue; // Not an awaitable expression

var fieldName = $"_awaiter_{i}"; // `i` should match the index of the expression to align with state machine logic
// `i` should match the index of the expression to align with state machine logic
var awaiterField = _typeBuilder.DefineField( FieldName.Awaiter( i ), awaiterType, FieldAttributes.Public );

var awaiterField = _typeBuilder.DefineField( fieldName, awaiterType, FieldAttributes.Public );
_awaiterFields.Add( awaiterField );
}
}
Expand Down Expand Up @@ -299,39 +306,39 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block )
// {
// try
// {
// if (_state == 0)
// if (__state<> == 0)
// {
// _awaiter1 = task1.ConfigureAwait(false).GetAwaiter();
// _state = 1;
// __awaiter<1> = task1.ConfigureAwait(false).GetAwaiter();
// __state<> = 1;
//
// if (!_awaiter1.IsCompleted == false)
// if (!__awaiter<1>.IsCompleted == false)
// {
// _builder.AwaitUnsafeOnCompleted(ref _awaiter1, this);
// __builder<>.AwaitUnsafeOnCompleted(ref __awaiter<1>, this);
// return;
// }
// }
//
// if (_state == 1)
// if (__state<> == 1)
// {
// _awaiter1.GetResult();
// _awaiter2 = task2.ConfigureAwait(false).GetAwaiter();
// _state = 2;
// __awaiter<1>.GetResult();
// __awaiter<2> = task2.ConfigureAwait(false).GetAwaiter();
// __state<> = 2;
//
// if (!_awaiter2.IsCompleted)
// if (!__awaiter<2>.IsCompleted)
// {
// _builder.AwaitUnsafeOnCompleted(ref _awaiter2, this);
// __builder<>.AwaitUnsafeOnCompleted(ref __awaiter<2>, this);
// return;
// }
// }
//
// if (_state == 2)
// if (__state<> == 2)
// {
// _builder.Task.SetResult( _awaiter2.GetResult() );
// __builder<>.Task.SetResult( __awaiter<2>.GetResult() );
// }
// }
// catch (Exception ex)
// {
// _builder.SetException(ex);
// __builder<>.SetException(ex);
// }
// }
var stateMachineInstance = Expression.Parameter( _stateMachineType, "stateMachine" );
Expand Down Expand Up @@ -378,7 +385,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block )
);

// Increment state
var setStateBeforeAwait = Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) );
var setStateBeforeAwait = Expression.Assign( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( i + 1 ) );

// Check completed
var awaiterCompletedCheck = Expression.IfThen(
Expand All @@ -396,7 +403,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block )
);

var stateCheck = Expression.IfThen(
Expression.Equal( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i ) ),
Expression.Equal( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( i ) ),
Expression.Block( assignAwaiter, setStateBeforeAwait, awaiterCompletedCheck )
);
bodyExpressions.Add( stateCheck );
Expand All @@ -406,14 +413,14 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block )
if ( typeof(TResult) != typeof(IVoidTaskResult) )
{
var assignFinalResult = Expression.Assign( Expression.Field( stateMachineInstance, finalResultFieldInfo ), blockExpr );
var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) );
var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( i + 1 ) );
bodyExpressions.Add( Expression.Block( assignFinalResult, incrementState ) );
handledFinalBlock = true;
}
else
{
// IVoidTaskResult (no result)
var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( i + 1 ) );
var incrementState = Expression.Assign( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( i + 1 ) );
bodyExpressions.Add( incrementState );
}
}
Expand All @@ -425,7 +432,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block )

// Generate the final state
var finalState = Expression.IfThen(
Expression.Equal( Expression.Field( stateMachineInstance, "_state" ), Expression.Constant( lastBlockIndex + 1 ) ),
Expression.Equal( Expression.Field( stateMachineInstance, FieldName.State ), Expression.Constant( lastBlockIndex + 1 ) ),
Expression.Block(
// Handle the final result for Task and Task<T>
!handledFinalBlock && typeof( TResult) != typeof(IVoidTaskResult)
Expand All @@ -452,7 +459,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block )

// Mark as completed after the final state logic is executed
var markCompletedState = Expression.Assign(
Expression.Field( stateMachineInstance, "_state" ),
Expression.Field( stateMachineInstance, FieldName.State ),
Expression.Constant( -2 ) // Mark the state machine as completed
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public enum ExpressionKind
}

[TestClass]
public class AsyncExpressionUnitTests
public class AsyncMethodTests
{
private static async Task Delay()
{
Expand All @@ -19,25 +19,25 @@ private static async Task Delay()

private static async Task<int> GetNumberAsync()
{
await Task.Delay( 100 );
await Task.Delay( 10 );
return 42;
}

private static async Task<int> AddTwoNumbersAsync( int a, int b )
{
await Task.Delay( 100 );
await Task.Delay( 10 );
return a + b;
}

private static async Task<int> AddThreeNumbersAsync( int a, int b, int c )
{
await Task.Delay( 100 );
await Task.Delay( 10 );
return a + b + c;
}

private static async Task<string> SayHelloAsync( int a )
{
await Task.Delay( 100 );
await Task.Delay( 10 );
return $"Hello {a}";
}

Expand All @@ -48,13 +48,13 @@ private static int IncrementValue( int a )

private static async Task<int> ThrowExceptionAsync()
{
await Task.Delay( 100 );
await Task.Delay( 10 );
throw new InvalidOperationException( "Simulated exception." );
}

private static MethodInfo GetMethodInfo( string name )
{
return typeof( AsyncExpressionUnitTests ).GetMethod( name, BindingFlags.Static | BindingFlags.NonPublic )!;
return typeof( AsyncMethodTests ).GetMethod( name, BindingFlags.Static | BindingFlags.NonPublic )!;
}

private static AsyncBaseExpression GetAsyncExpression( ExpressionKind kind, MethodInfo methodInfo, params Expression[] arguments )
Expand Down

0 comments on commit 550df03

Please sign in to comment.