Skip to content

Commit

Permalink
More cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
bfarmer67 committed Sep 10, 2024
1 parent cfe4862 commit 9a90f4d
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 22 deletions.
9 changes: 4 additions & 5 deletions src/Hyperbee.AsyncExpressions/AwaitExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ internal AwaitExpression( Expression asyncExpression, bool configureAwait )
// TODO: Review with BF (fix caching the type)
public override Type Type => ResultType( _asyncExpression.Type ); //_resultType;

public Expression AsyncExpression => _asyncExpression;
public Expression Target => _asyncExpression;

public bool ReturnTask { get; set; }

Expand All @@ -39,11 +39,10 @@ public override Expression Reduce()
return _asyncExpression;

var resultType = ResultType( _asyncExpression.Type );
var awaitExpression = Call( resultType == typeof(void) || resultType == typeof( IVoidTaskResult )

return Call( resultType == typeof(void) || resultType == typeof( IVoidTaskResult )
? AwaitMethod
: AwaitResultMethod.MakeGenericMethod( resultType ), _asyncExpression, Constant( _configureAwait ) );

return awaitExpression;
}

private Type ResultType( Type taskType )
Expand Down Expand Up @@ -79,7 +78,7 @@ private class AwaitExpressionProxy( AwaitExpression node )

public static partial class AsyncExpression
{
public static AwaitExpression Await( Expression expression, bool configureAwait )
public static AwaitExpression Await( Expression expression, bool configureAwait = false )
{
return new AwaitExpression( expression, configureAwait );
}
Expand Down
18 changes: 9 additions & 9 deletions src/Hyperbee.AsyncExpressions/StateMachineBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ private void EmitBlockFields( BlockExpression block )
{
var expr = block.Expressions[i];

if ( !TryGetAwaiterType( expr, out Type awaiterType ) )
if ( !TryMakeAwaiterType( expr, out Type awaiterType ) )
continue; // Not an awaitable expression

var fieldName = $"_awaiter_{i}"; // `i` should match the index of the expression to align with state machine logic
Expand Down Expand Up @@ -433,7 +433,7 @@ private LambdaExpression CreateMoveNextExpression( BlockExpression block )
? Expression.Assign(
Expression.Field( stateMachineInstance, finalResultFieldInfo ),
Expression.Call(
Expression.Field( stateMachineInstance, lastAwaitField ),
Expression.Field( stateMachineInstance, lastAwaitField! ),
"GetResult", Type.EmptyTypes
)
)
Expand Down Expand Up @@ -488,35 +488,35 @@ private static FieldInfo GetFieldInfo( Type runtimeType, FieldBuilder field )
return runtimeType.GetField( field.Name, BindingFlags.Instance | BindingFlags.Public )!;
}

private static bool TryGetAwaiterType( Expression expr, out Type awaiterType )
private static bool TryMakeAwaiterType( Expression expr, out Type awaiterType )
{
awaiterType = null;

switch ( expr )
{
case MethodCallExpression methodCall when typeof( Task ).IsAssignableFrom( methodCall.Type ):
awaiterType = GetAwaiterType( methodCall.Type );
awaiterType = MakeAwaiterType( methodCall.Type );
return true;

case InvocationExpression invocation when typeof( Task ).IsAssignableFrom( invocation.Type ):
awaiterType = GetAwaiterType( invocation.Type );
awaiterType = MakeAwaiterType( invocation.Type );
return true;

case BlockExpression block:
return TryGetAwaiterType( block.Expressions.Last(), out awaiterType );
return TryMakeAwaiterType( block.Expressions.Last(), out awaiterType );

case AwaitExpression await:
awaiterType = GetAwaiterType( await.AsyncExpression.Type );
awaiterType = MakeAwaiterType( await.Target.Type );
return true;

case not null when typeof( Task ).IsAssignableFrom( expr.Type ):
awaiterType = GetAwaiterType( expr.Type );
awaiterType = MakeAwaiterType( expr.Type );
return true;
}

return false;

static Type GetAwaiterType( Type taskType )
static Type MakeAwaiterType( Type taskType )
{
if ( !taskType.IsGenericType )
return typeof( ConfiguredTaskAwaitable.ConfiguredTaskAwaiter );
Expand Down
4 changes: 2 additions & 2 deletions test/Hyperbee.AsyncExpressions.Tests/AsyncBlockTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void TestAsyncBlock_SimpleBlockSplitting()
}

[TestMethod]
public async Task TestAsyncBlock_StartStateMachine()
public async Task TestAsyncBlock_WithoutParameters_ReturnsResult()
{
// Arrange
var expr1 = Expression.Constant( 1 );
Expand All @@ -119,7 +119,7 @@ public async Task TestAsyncBlock_StartStateMachine()
}

[TestMethod]
public async Task TestAsyncBlock_StartStateMachineWithVariables()
public async Task TestAsyncBlock_WithParameters_ReturnsResult()
{
// Arrange
var param1 = Expression.Parameter( typeof( int ), "param1" );
Expand Down
12 changes: 6 additions & 6 deletions test/Hyperbee.AsyncExpressions.Tests/UnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ private static async Task<int> GetNumberAsync()

private static async Task<int> AddTwoNumbersAsync( int a, int b )
{
await Task.Delay( 10 );
await Task.Delay( 100 );
return a + b;
}

Expand All @@ -37,7 +37,7 @@ private static async Task<int> AddThreeNumbersAsync( int a, int b, int c )

private static async Task<string> SayHelloAsync( int a )
{
await Task.Delay( 10 );
await Task.Delay( 100 );
return $"Hello {a}";
}

Expand All @@ -48,7 +48,7 @@ private static int IncrementValue( int a )

private static async Task<int> ThrowExceptionAsync()
{
await Task.Delay( 50 );
await Task.Delay( 100 );
throw new InvalidOperationException( "Simulated exception." );
}

Expand Down Expand Up @@ -336,8 +336,8 @@ public async Task TestMultipleAsyncExpressions_WithDeepNestingAsyncAwait()

var paramA = Expression.Parameter( typeof( int ), "a" );

var l1 = AsyncExpression.Await( AsyncExpression.InvokeAsync( incrementExpression, paramA ), false );
var l2 = AsyncExpression.Await( AsyncExpression.InvokeAsync( incrementExpression, l1 ), false );
var l1 = AsyncExpression.Await( AsyncExpression.InvokeAsync( incrementExpression, paramA ), configureAwait: false );
var l2 = AsyncExpression.Await( AsyncExpression.InvokeAsync( incrementExpression, l1 ), configureAwait: false );
var l3 = AsyncExpression.InvokeAsync( incrementExpression, l2 );

var compiled = Expression.Lambda<Func<int, Task<int>>>( l3, paramA ).Compile();
Expand Down Expand Up @@ -412,7 +412,7 @@ public void TestAsyncExpression_ExceptionHandling( ExpressionKind kind )
try
{
_ = compiledLambda(); // Directly get the unwrapped result
Assert.Fail( "Expected exception was not thrown." );
Assert.Fail( "An exception was not thrown." );
}
catch ( InvalidOperationException ex )
{
Expand Down

0 comments on commit 9a90f4d

Please sign in to comment.