Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
bfarmer67 committed Aug 31, 2024
1 parent b0a45ea commit 18ccd85
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 52 deletions.
36 changes: 18 additions & 18 deletions src/Hyperbee.AsyncExpressions/AsyncBaseExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,34 @@ public override Expression Reduce()
if ( _isReduced )
return _reducedBody;

_reducedBody = ReducedBody( _body );
_isReduced = true;

_reducedBody = _body.Type.IsGenericType switch
{
true => GetReducedBody( _body.Type.GetGenericArguments()[0], _body ),
false => GetReducedBody( typeof( VoidResult ), Block( _body, TaskVoidResult ) )
};

return _reducedBody;

static Expression GetReducedBody( Type type, Expression body )
static Expression ReducedBody( Expression body )
{
var methodInfo = MakeExecuteAsyncExpressionMethod.MakeGenericMethod( type );
return (Expression) methodInfo!.Invoke( null, [body] );
var returnType = ReturnType( body );
var bodyToUse = BodyToUse( body );

var methodInfo = MakeExecuteAsyncExpressionMethod.MakeGenericMethod( returnType );
return (Expression) methodInfo.Invoke( null, [bodyToUse] );
}

static Type ReturnType( Expression body ) => body.Type.IsGenericType ? body.Type.GetGenericArguments()[0] : typeof(VoidResult);
static Expression BodyToUse( Expression body ) => body.Type.IsGenericType ? body : Block( body, TaskVoidResult );
}

private static BlockExpression MakeExecuteAsyncExpression<T>( Expression task )
{
/* Generate code block:
internal static Task<T> ExecuteAsync<T>(Task<T> task)
{
var stateMachine = new StateMachine<T>(task);
stateMachine.MoveNext();
return stateMachine.Task;
}
*/
// Generate code block:
//
// internal static Task<T> ExecuteAsync<T>(Task<T> task)
// {
// var stateMachine = new StateMachine<T>(task);
// stateMachine.MoveNext();
// return stateMachine.Task;
// }

// Create unique variable names to avoid conflicts
var id = Interlocked.Increment( ref __stateMachineCounter );
Expand Down
69 changes: 35 additions & 34 deletions test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ private static AsyncBaseExpression GetAsyncExpression( ExpressionKind kind, Meth
switch ( kind )
{
case ExpressionKind.Lambda:
var (lambdaExpression, parameters) = GetLambdaExpression( methodInfo, arguments );
return AsyncExpression.InvokeAsync( lambdaExpression, parameters );
var (lambdaExpression, lambdaArguments) = GetLambdaExpression( methodInfo, arguments );
return AsyncExpression.InvokeAsync( lambdaExpression, lambdaArguments );

case ExpressionKind.Method:
return AsyncExpression.CallAsync( methodInfo, arguments );
Expand All @@ -73,19 +73,20 @@ private static AsyncBaseExpression GetAsyncExpression( ExpressionKind kind, Meth
}
}

private static (LambdaExpression Lambda, Expression[] Parameters) GetLambdaExpression( MethodInfo methodInfo, params Expression[] arguments )
private static (LambdaExpression Lambda, Expression[] Arguments) GetLambdaExpression( MethodInfo methodInfo, params Expression[] arguments )
{
if ( methodInfo.GetParameters().Length != arguments.Length )
{
throw new ArgumentException( "Number of arguments does not match the number of method parameters." );
}

var parameterExpressions = arguments.OfType<ParameterExpression>();
var parameterExpressions = arguments.OfType<ParameterExpression>().ToArray();
var lambdaArguments = parameterExpressions.Cast<Expression>().ToArray();

var callExpression = Expression.Call( methodInfo, arguments );
var lambdaExpression = Expression.Lambda( callExpression, arguments.OfType<ParameterExpression>() );
var lambdaExpression = Expression.Lambda( callExpression, parameterExpressions );

return (lambdaExpression, parameterExpressions.Cast<Expression>().ToArray());
return (lambdaExpression, lambdaArguments);
}

[DataTestMethod]
Expand Down Expand Up @@ -166,55 +167,55 @@ public void TestAsyncExpression_WithConstants( ExpressionKind kind )
[DataTestMethod]
[DataRow( ExpressionKind.Lambda )]
[DataRow( ExpressionKind.Method )]
public void TestAsyncExpression_WithMethodParameters( ExpressionKind kind )
public void TestAsyncExpression_WithAsyncParameter( ExpressionKind kind )
{
// var result0 = IncrementValue( 11 );
// var result1 = await AddThreeNumbersAsync( 10, 20, result0 );
// var result = await SayHelloAsync( await AddTwoNumbersAsync( 10, 32 ) );

var incrementMethodInfo = GetMethodInfo( nameof( IncrementValue ) );
var methodInfo = GetMethodInfo( nameof( AddThreeNumbersAsync ) );
var addTwoNumbersMethod = GetMethodInfo( nameof(AddTwoNumbersAsync) );
var sayHelloMethod = GetMethodInfo( nameof(SayHelloAsync) );

var paramExpr1 = Expression.Parameter( typeof( int ), "a" );
var paramExpr2 = Expression.Parameter( typeof( int ), "b" );
var paramExpr3 = Expression.Parameter( typeof( int ), "c" );
var paramA = Expression.Parameter( typeof(int), "a" );
var paramB = Expression.Parameter( typeof(int), "b" );

var incrementValueCall = Expression.Call( incrementMethodInfo!, paramExpr3 );
var asyncExpressionAdd = GetAsyncExpression( kind, addTwoNumbersMethod, paramA, paramB );
var awaitExpressionAdd = AsyncExpression.Await( asyncExpressionAdd, configureAwait: false );

var asyncExpression = GetAsyncExpression( kind, methodInfo!, paramExpr1, paramExpr2, incrementValueCall );
var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false );
var asyncExpressionSayHello = GetAsyncExpression( kind, sayHelloMethod, awaitExpressionAdd );
var awaitExpressionSayHello = AsyncExpression.Await( asyncExpressionSayHello, configureAwait: false );

var lambda = Expression.Lambda<Func<int, int, int, int>>( awaitExpression, paramExpr1, paramExpr2, paramExpr3 );
var lambda = Expression.Lambda<Func<int, int, string>>( awaitExpressionSayHello, paramA, paramB );
var compiledLambda = lambda.Compile();

var result = compiledLambda( 10, 20, 11 ); // Pass 10, 20, and 11 as parameters; IncrementValue will increment 11
Assert.AreEqual( 42, result, "The result should be 42." );
var result = compiledLambda( 10, 32 );

Assert.AreEqual( "Hello 42", result, "The result should be 'Hello 42'." );
}

[DataTestMethod]
[DataRow( ExpressionKind.Lambda )]
[DataRow( ExpressionKind.Method )]
public void TestAsyncExpression_AsParameter( ExpressionKind kind )
public void TestAsyncExpression_WithMethodCallParameters( ExpressionKind kind )
{
// var result = await SayHelloAsync( await AddTwoNumbersAsync( 10, 32 ) );
// var result0 = IncrementValue( 11 );
// var result1 = await AddThreeNumbersAsync( 10, 20, result0 );

var addTwoNumbersMethod = GetMethodInfo( nameof( AddTwoNumbersAsync ) );
var sayHelloMethod = GetMethodInfo( nameof( SayHelloAsync ) );
var incrementMethodInfo = GetMethodInfo( nameof( IncrementValue ) );
var methodInfo = GetMethodInfo( nameof( AddThreeNumbersAsync ) );

var paramA = Expression.Parameter( typeof( int ), "a" );
var paramB = Expression.Parameter( typeof( int ), "b" );
var paramExpr1 = Expression.Parameter( typeof( int ), "a" );
var paramExpr2 = Expression.Parameter( typeof( int ), "b" );
var paramExpr3 = Expression.Parameter( typeof( int ), "c" );

var asyncExpressionAdd = GetAsyncExpression( kind, addTwoNumbersMethod, paramA, paramB );
var awaitExpressionAdd = AsyncExpression.Await( asyncExpressionAdd, configureAwait: false );
var incrementValueCall = Expression.Call( incrementMethodInfo!, paramExpr3 );

var asyncExpressionSayHello = GetAsyncExpression( kind, sayHelloMethod, awaitExpressionAdd );
var awaitExpressionSayHello = AsyncExpression.Await( asyncExpressionSayHello, configureAwait: false );
var asyncExpression = GetAsyncExpression( kind, methodInfo!, paramExpr1, paramExpr2, incrementValueCall );
var awaitExpression = AsyncExpression.Await( asyncExpression, configureAwait: false );

var lambda = Expression.Lambda<Func<int, int, string>>( awaitExpressionSayHello, paramA, paramB );
var lambda = Expression.Lambda<Func<int, int, int, int>>( awaitExpression, paramExpr1, paramExpr2, paramExpr3 );
var compiledLambda = lambda.Compile();

var result = compiledLambda( 10, 32 );

Assert.AreEqual( "Hello 42", result, "The result should be 'Hello 42'." );
var result = compiledLambda( 10, 20, 11 ); // Pass 10, 20, and 11 as parameters; IncrementValue will increment 11
Assert.AreEqual( 42, result, "The result should be 42." );
}

[DataTestMethod]
Expand Down

0 comments on commit 18ccd85

Please sign in to comment.