From 9a90f4d72193ff4f917987450cad8f06c5dbf53d Mon Sep 17 00:00:00 2001 From: Brenton Farmer Date: Tue, 10 Sep 2024 16:33:29 -0700 Subject: [PATCH] More cleanup --- .../AwaitExpression.cs | 9 ++++----- .../StateMachineBuilder.cs | 18 +++++++++--------- .../AsyncBlockTests.cs | 4 ++-- .../UnitTests.cs | 12 ++++++------ 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/Hyperbee.AsyncExpressions/AwaitExpression.cs b/src/Hyperbee.AsyncExpressions/AwaitExpression.cs index b1fecfa..e36b3b5 100644 --- a/src/Hyperbee.AsyncExpressions/AwaitExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AwaitExpression.cs @@ -27,7 +27,7 @@ internal AwaitExpression( Expression asyncExpression, bool configureAwait ) // TODO: Review with BF (fix caching the type) public override Type Type => ResultType( _asyncExpression.Type ); //_resultType; - public Expression AsyncExpression => _asyncExpression; + public Expression Target => _asyncExpression; public bool ReturnTask { get; set; } @@ -39,11 +39,10 @@ public override Expression Reduce() return _asyncExpression; var resultType = ResultType( _asyncExpression.Type ); - var awaitExpression = Call( resultType == typeof(void) || resultType == typeof( IVoidTaskResult ) + + return Call( resultType == typeof(void) || resultType == typeof( IVoidTaskResult ) ? AwaitMethod : AwaitResultMethod.MakeGenericMethod( resultType ), _asyncExpression, Constant( _configureAwait ) ); - - return awaitExpression; } private Type ResultType( Type taskType ) @@ -79,7 +78,7 @@ private class AwaitExpressionProxy( AwaitExpression node ) public static partial class AsyncExpression { - public static AwaitExpression Await( Expression expression, bool configureAwait ) + public static AwaitExpression Await( Expression expression, bool configureAwait = false ) { return new AwaitExpression( expression, configureAwait ); } diff --git a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs index dad11b8..dbda750 100644 --- a/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs +++ b/src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs @@ -193,7 +193,7 @@ private void EmitBlockFields( BlockExpression block ) { var expr = block.Expressions[i]; - if ( !TryGetAwaiterType( expr, out Type awaiterType ) ) + if ( !TryMakeAwaiterType( expr, out Type awaiterType ) ) continue; // Not an awaitable expression var fieldName = $"_awaiter_{i}"; // `i` should match the index of the expression to align with state machine logic @@ -433,7 +433,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block ) ? Expression.Assign( Expression.Field( stateMachineInstance, finalResultFieldInfo ), Expression.Call( - Expression.Field( stateMachineInstance, lastAwaitField ), + Expression.Field( stateMachineInstance, lastAwaitField! ), "GetResult", Type.EmptyTypes ) ) @@ -488,35 +488,35 @@ private static FieldInfo GetFieldInfo( Type runtimeType, FieldBuilder field ) return runtimeType.GetField( field.Name, BindingFlags.Instance | BindingFlags.Public )!; } - private static bool TryGetAwaiterType( Expression expr, out Type awaiterType ) + private static bool TryMakeAwaiterType( Expression expr, out Type awaiterType ) { awaiterType = null; switch ( expr ) { case MethodCallExpression methodCall when typeof( Task ).IsAssignableFrom( methodCall.Type ): - awaiterType = GetAwaiterType( methodCall.Type ); + awaiterType = MakeAwaiterType( methodCall.Type ); return true; case InvocationExpression invocation when typeof( Task ).IsAssignableFrom( invocation.Type ): - awaiterType = GetAwaiterType( invocation.Type ); + awaiterType = MakeAwaiterType( invocation.Type ); return true; case BlockExpression block: - return TryGetAwaiterType( block.Expressions.Last(), out awaiterType ); + return TryMakeAwaiterType( block.Expressions.Last(), out awaiterType ); case AwaitExpression await: - awaiterType = GetAwaiterType( await.AsyncExpression.Type ); + awaiterType = MakeAwaiterType( await.Target.Type ); return true; case not null when typeof( Task ).IsAssignableFrom( expr.Type ): - awaiterType = GetAwaiterType( expr.Type ); + awaiterType = MakeAwaiterType( expr.Type ); return true; } return false; - static Type GetAwaiterType( Type taskType ) + static Type MakeAwaiterType( Type taskType ) { if ( !taskType.IsGenericType ) return typeof( ConfiguredTaskAwaitable.ConfiguredTaskAwaiter ); diff --git a/test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs b/test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs index fa4f536..6f1a47f 100644 --- a/test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs +++ b/test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs @@ -95,7 +95,7 @@ public void TestAsyncBlock_SimpleBlockSplitting() } [TestMethod] - public async Task TestAsyncBlock_StartStateMachine() + public async Task TestAsyncBlock_WithoutParameters_ReturnsResult() { // Arrange var expr1 = Expression.Constant( 1 ); @@ -119,7 +119,7 @@ public async Task TestAsyncBlock_StartStateMachine() } [TestMethod] - public async Task TestAsyncBlock_StartStateMachineWithVariables() + public async Task TestAsyncBlock_WithParameters_ReturnsResult() { // Arrange var param1 = Expression.Parameter( typeof( int ), "param1" ); diff --git a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs b/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs index df75ffb..218d3fe 100644 --- a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs +++ b/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs @@ -25,7 +25,7 @@ private static async Task GetNumberAsync() private static async Task AddTwoNumbersAsync( int a, int b ) { - await Task.Delay( 10 ); + await Task.Delay( 100 ); return a + b; } @@ -37,7 +37,7 @@ private static async Task AddThreeNumbersAsync( int a, int b, int c ) private static async Task SayHelloAsync( int a ) { - await Task.Delay( 10 ); + await Task.Delay( 100 ); return $"Hello {a}"; } @@ -48,7 +48,7 @@ private static int IncrementValue( int a ) private static async Task ThrowExceptionAsync() { - await Task.Delay( 50 ); + await Task.Delay( 100 ); throw new InvalidOperationException( "Simulated exception." ); } @@ -336,8 +336,8 @@ public async Task TestMultipleAsyncExpressions_WithDeepNestingAsyncAwait() var paramA = Expression.Parameter( typeof( int ), "a" ); - var l1 = AsyncExpression.Await( AsyncExpression.InvokeAsync( incrementExpression, paramA ), false ); - var l2 = AsyncExpression.Await( AsyncExpression.InvokeAsync( incrementExpression, l1 ), false ); + var l1 = AsyncExpression.Await( AsyncExpression.InvokeAsync( incrementExpression, paramA ), configureAwait: false ); + var l2 = AsyncExpression.Await( AsyncExpression.InvokeAsync( incrementExpression, l1 ), configureAwait: false ); var l3 = AsyncExpression.InvokeAsync( incrementExpression, l2 ); var compiled = Expression.Lambda>>( l3, paramA ).Compile(); @@ -412,7 +412,7 @@ public void TestAsyncExpression_ExceptionHandling( ExpressionKind kind ) try { _ = compiledLambda(); // Directly get the unwrapped result - Assert.Fail( "Expected exception was not thrown." ); + Assert.Fail( "An exception was not thrown." ); } catch ( InvalidOperationException ex ) {