From 18ccd85d5eda985c4b842b4fcb0717526cfc786e Mon Sep 17 00:00:00 2001 From: Brenton Farmer Date: Fri, 30 Aug 2024 21:05:23 -0700 Subject: [PATCH] Cleanup --- .../AsyncBaseExpression.cs | 36 +++++----- .../UnitTests.cs | 69 ++++++++++--------- 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs b/src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs index d962363..aa6c1ed 100644 --- a/src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs +++ b/src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs @@ -40,34 +40,34 @@ public override Expression Reduce() if ( _isReduced ) return _reducedBody; + _reducedBody = ReducedBody( _body ); _isReduced = true; - _reducedBody = _body.Type.IsGenericType switch - { - true => GetReducedBody( _body.Type.GetGenericArguments()[0], _body ), - false => GetReducedBody( typeof( VoidResult ), Block( _body, TaskVoidResult ) ) - }; - return _reducedBody; - static Expression GetReducedBody( Type type, Expression body ) + static Expression ReducedBody( Expression body ) { - var methodInfo = MakeExecuteAsyncExpressionMethod.MakeGenericMethod( type ); - return (Expression) methodInfo!.Invoke( null, [body] ); + var returnType = ReturnType( body ); + var bodyToUse = BodyToUse( body ); + + var methodInfo = MakeExecuteAsyncExpressionMethod.MakeGenericMethod( returnType ); + return (Expression) methodInfo.Invoke( null, [bodyToUse] ); } + + static Type ReturnType( Expression body ) => body.Type.IsGenericType ? body.Type.GetGenericArguments()[0] : typeof(VoidResult); + static Expression BodyToUse( Expression body ) => body.Type.IsGenericType ? body : Block( body, TaskVoidResult ); } private static BlockExpression MakeExecuteAsyncExpression( Expression task ) { - /* Generate code block: - - internal static Task ExecuteAsync(Task task) - { - var stateMachine = new StateMachine(task); - stateMachine.MoveNext(); - return stateMachine.Task; - } - */ + // Generate code block: + // + // internal static Task ExecuteAsync(Task task) + // { + // var stateMachine = new StateMachine(task); + // stateMachine.MoveNext(); + // return stateMachine.Task; + // } // Create unique variable names to avoid conflicts var id = Interlocked.Increment( ref __stateMachineCounter ); diff --git a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs b/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs index 100dd30..9220037 100644 --- a/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs +++ b/test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs @@ -62,8 +62,8 @@ private static AsyncBaseExpression GetAsyncExpression( ExpressionKind kind, Meth switch ( kind ) { case ExpressionKind.Lambda: - var (lambdaExpression, parameters) = GetLambdaExpression( methodInfo, arguments ); - return AsyncExpression.InvokeAsync( lambdaExpression, parameters ); + var (lambdaExpression, lambdaArguments) = GetLambdaExpression( methodInfo, arguments ); + return AsyncExpression.InvokeAsync( lambdaExpression, lambdaArguments ); case ExpressionKind.Method: return AsyncExpression.CallAsync( methodInfo, arguments ); @@ -73,19 +73,20 @@ private static AsyncBaseExpression GetAsyncExpression( ExpressionKind kind, Meth } } - private static (LambdaExpression Lambda, Expression[] Parameters) GetLambdaExpression( MethodInfo methodInfo, params Expression[] arguments ) + private static (LambdaExpression Lambda, Expression[] Arguments) GetLambdaExpression( MethodInfo methodInfo, params Expression[] arguments ) { if ( methodInfo.GetParameters().Length != arguments.Length ) { throw new ArgumentException( "Number of arguments does not match the number of method parameters." ); } - var parameterExpressions = arguments.OfType(); + var parameterExpressions = arguments.OfType().ToArray(); + var lambdaArguments = parameterExpressions.Cast().ToArray(); var callExpression = Expression.Call( methodInfo, arguments ); - var lambdaExpression = Expression.Lambda( callExpression, arguments.OfType() ); + var lambdaExpression = Expression.Lambda( callExpression, parameterExpressions ); - return (lambdaExpression, parameterExpressions.Cast().ToArray()); + return (lambdaExpression, lambdaArguments); } [DataTestMethod] @@ -166,55 +167,55 @@ public void TestAsyncExpression_WithConstants( ExpressionKind kind ) [DataTestMethod] [DataRow( ExpressionKind.Lambda )] [DataRow( ExpressionKind.Method )] - public void TestAsyncExpression_WithMethodParameters( ExpressionKind kind ) + public void TestAsyncExpression_WithAsyncParameter( ExpressionKind kind ) { - // var result0 = IncrementValue( 11 ); - // var result1 = await AddThreeNumbersAsync( 10, 20, result0 ); + // var result = await SayHelloAsync( await AddTwoNumbersAsync( 10, 32 ) ); - var incrementMethodInfo = GetMethodInfo( nameof( IncrementValue ) ); - var methodInfo = GetMethodInfo( nameof( AddThreeNumbersAsync ) ); + var addTwoNumbersMethod = GetMethodInfo( nameof(AddTwoNumbersAsync) ); + var sayHelloMethod = GetMethodInfo( nameof(SayHelloAsync) ); - var paramExpr1 = Expression.Parameter( typeof( int ), "a" ); - var paramExpr2 = Expression.Parameter( typeof( int ), "b" ); - var paramExpr3 = Expression.Parameter( typeof( int ), "c" ); + var paramA = Expression.Parameter( typeof(int), "a" ); + var paramB = Expression.Parameter( typeof(int), "b" ); - var incrementValueCall = Expression.Call( incrementMethodInfo!, paramExpr3 ); + var asyncExpressionAdd = GetAsyncExpression( kind, addTwoNumbersMethod, paramA, paramB ); + var awaitExpressionAdd = AsyncExpression.Await( asyncExpressionAdd, configureAwait: false ); - var asyncExpression = GetAsyncExpression( kind, methodInfo!, paramExpr1, paramExpr2, incrementValueCall ); - var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false ); + var asyncExpressionSayHello = GetAsyncExpression( kind, sayHelloMethod, awaitExpressionAdd ); + var awaitExpressionSayHello = AsyncExpression.Await( asyncExpressionSayHello, configureAwait: false ); - var lambda = Expression.Lambda>( awaitExpression, paramExpr1, paramExpr2, paramExpr3 ); + var lambda = Expression.Lambda>( awaitExpressionSayHello, paramA, paramB ); var compiledLambda = lambda.Compile(); - var result = compiledLambda( 10, 20, 11 ); // Pass 10, 20, and 11 as parameters; IncrementValue will increment 11 - Assert.AreEqual( 42, result, "The result should be 42." ); + var result = compiledLambda( 10, 32 ); + + Assert.AreEqual( "Hello 42", result, "The result should be 'Hello 42'." ); } [DataTestMethod] [DataRow( ExpressionKind.Lambda )] [DataRow( ExpressionKind.Method )] - public void TestAsyncExpression_AsParameter( ExpressionKind kind ) + public void TestAsyncExpression_WithMethodCallParameters( ExpressionKind kind ) { - // var result = await SayHelloAsync( await AddTwoNumbersAsync( 10, 32 ) ); + // var result0 = IncrementValue( 11 ); + // var result1 = await AddThreeNumbersAsync( 10, 20, result0 ); - var addTwoNumbersMethod = GetMethodInfo( nameof( AddTwoNumbersAsync ) ); - var sayHelloMethod = GetMethodInfo( nameof( SayHelloAsync ) ); + var incrementMethodInfo = GetMethodInfo( nameof( IncrementValue ) ); + var methodInfo = GetMethodInfo( nameof( AddThreeNumbersAsync ) ); - var paramA = Expression.Parameter( typeof( int ), "a" ); - var paramB = Expression.Parameter( typeof( int ), "b" ); + var paramExpr1 = Expression.Parameter( typeof( int ), "a" ); + var paramExpr2 = Expression.Parameter( typeof( int ), "b" ); + var paramExpr3 = Expression.Parameter( typeof( int ), "c" ); - var asyncExpressionAdd = GetAsyncExpression( kind, addTwoNumbersMethod, paramA, paramB ); - var awaitExpressionAdd = AsyncExpression.Await( asyncExpressionAdd, configureAwait: false ); + var incrementValueCall = Expression.Call( incrementMethodInfo!, paramExpr3 ); - var asyncExpressionSayHello = GetAsyncExpression( kind, sayHelloMethod, awaitExpressionAdd ); - var awaitExpressionSayHello = AsyncExpression.Await( asyncExpressionSayHello, configureAwait: false ); + var asyncExpression = GetAsyncExpression( kind, methodInfo!, paramExpr1, paramExpr2, incrementValueCall ); + var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false ); - var lambda = Expression.Lambda>( awaitExpressionSayHello, paramA, paramB ); + var lambda = Expression.Lambda>( awaitExpression, paramExpr1, paramExpr2, paramExpr3 ); var compiledLambda = lambda.Compile(); - var result = compiledLambda( 10, 32 ); - - Assert.AreEqual( "Hello 42", result, "The result should be 'Hello 42'." ); + var result = compiledLambda( 10, 20, 11 ); // Pass 10, 20, and 11 as parameters; IncrementValue will increment 11 + Assert.AreEqual( 42, result, "The result should be 42." ); } [DataTestMethod]