Skip to content

Commit

Permalink
Clean and add block
Browse files Browse the repository at this point in the history
  • Loading branch information
MattEdwardsWaggleBee committed Aug 30, 2024
1 parent 356b3ac commit 0378ba9
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 50 deletions.
241 changes: 241 additions & 0 deletions src/Hyperbee.AsyncExpressions/AsyncBlockExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
using System.Diagnostics;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace Hyperbee.AsyncExpressions;

[DebuggerDisplay( "{_body}" )]
[DebuggerTypeProxy( typeof( AsyncBlockExpressionProxy ) )]
public class AsyncBlockExpression : Expression
{
private readonly Expression _body;
private Expression _reducedBody;
private bool _isReduced;
private static int _stateMachineCounter;

private static readonly Expression VoidResult = Constant(Task.FromResult(new VoidTaskResult()));

private static MethodInfo GenericGenerateExecuteAsync => typeof( AsyncInvokeExpression )
.GetMethod( nameof( GenerateExecuteAsyncExpression ), BindingFlags.Static | BindingFlags.NonPublic );

internal AsyncBlockExpression( Expression body )
{
ArgumentNullException.ThrowIfNull( body, nameof( body ) );

if ( !IsAsync( body.Type ) )
throw new ArgumentException( $"The specified {nameof( body )} is not an async.", nameof( body ) );

_body = body;
}


public override ExpressionType NodeType => ExpressionType.Extension;

public override Type Type => _body.Type;

public override bool CanReduce => true;

public override Expression Reduce()
{
if (_isReduced)
return _reducedBody;

_isReduced = true;

var (type, result) = GetTypeResult(_body);
var methodInfo = GenericGenerateExecuteAsync?.MakeGenericMethod(type);

_reducedBody = (Expression)methodInfo!.Invoke(null, [result]);

return _reducedBody!;
}

private static (Type Type, Expression Expression) GetTypeResult(Expression expression)
{
return expression.Type == typeof(Task)
? (typeof(VoidTaskResult), Block(expression, VoidResult))
: (expression.Type.GetGenericArguments()[0], expression);
}


private static BlockExpression GenerateExecuteAsyncExpression<T>( Expression task )
{
// Generating code block:
/*
internal static Task<T> ExecuteAsync<T>(Task<T> task)
{
var stateMachine = new StateMachine<T>(task);
stateMachine.MoveNext();
return stateMachine.Task;
}
*/

// Create unique variable names to avoid conflicts
var id = Interlocked.Increment( ref _stateMachineCounter );
var stateMachineVar = Variable( typeof(MultiTaskStateMachine<T> ), $"stateMachine_{id}" );

// Constructor for state machine
var stateMachineCtor = typeof(MultiTaskStateMachine<T> )
.GetConstructor( [typeof( Task<T> )] );

var assignStateMachine = Assign(
stateMachineVar,
New( stateMachineCtor!, task )
);

// Call MoveNext
var moveNextMethod = typeof(MultiTaskStateMachine<T> ).GetMethod( nameof(MultiTaskStateMachine<T>.MoveNext ) );
var moveNextCall = Call( stateMachineVar, moveNextMethod! );

// Return task property
var taskProperty = typeof(MultiTaskStateMachine<T> ).GetProperty( nameof(MultiTaskStateMachine<T>.Task ) );
var returnTask = Property( stateMachineVar, taskProperty! );

// Explicitly use nested blocks to handle variable scoping
var resultBlock = Block(
[stateMachineVar],
assignStateMachine,
moveNextCall,
returnTask
);

return resultBlock;
}

private struct MultiTaskStateMachine<T> : IAsyncStateMachine
{
private readonly Task[] _tasks;
private readonly bool _isLastTaskGeneric;
private AsyncTaskMethodBuilder<T> _builder;
private int _state;

public MultiTaskStateMachine( Task[] tasks )
{
_builder = AsyncTaskMethodBuilder<T>.Create();
_state = -1;
_tasks = tasks;

// Determine if the last task is generic or not
var lastTaskType = tasks[^1].GetType();
_isLastTaskGeneric = lastTaskType.IsGenericType && lastTaskType.GetGenericTypeDefinition() == typeof( Task<> );

SetStateMachine( this );
}

public Task<T> Task => _builder.Task;

public void MoveNext()
{
try
{
if ( _state == -1 )
{
// Initial state:
_state = 0;
}

if ( _state >= 0 && _state < _tasks.Length )
{
var currentTask = _tasks[_state];

if ( _state == _tasks.Length - 1 && _isLastTaskGeneric )
{
// Last task is generic
var genericAwaiter = ((Task<T>) currentTask).ConfigureAwait( false ).GetAwaiter();
if ( !genericAwaiter.IsCompleted )
{
_builder.AwaitUnsafeOnCompleted( ref genericAwaiter, ref this );
return;
}

// Get the result directly if the task is already completed
var result = genericAwaiter.GetResult();
_state = -2;
_builder.SetResult( result );
}
else
{
// Intermediate non-generic task or last non-generic task
var awaiter = currentTask.ConfigureAwait( false ).GetAwaiter();
if ( !awaiter.IsCompleted )
{
_builder.AwaitUnsafeOnCompleted( ref awaiter, ref this );
return;
}

// Continue directly if the task is already completed
awaiter.GetResult();
_state++;
MoveNext();
}
}
else if ( _state == _tasks.Length && !_isLastTaskGeneric )
{
// All tasks completed, last task was non-generic
_state = -2;
_builder.SetResult( default! );
}
}
catch ( Exception ex )
{
// Final state: error
_state = -2;
_builder.SetException( ex );
}
}

public void SetStateMachine( IAsyncStateMachine stateMachine )
{
_builder.SetStateMachine( stateMachine );
}
}

private static bool IsAsync( Type returnType )
{
return returnType == typeof( Task ) ||
(returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof( Task<> )) ||
(returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof( ValueTask<> ));
}

public class AsyncBlockExpressionProxy( AsyncBlockExpression node )
{
public Expression Body => node._body;
}

public static AsyncBlockExpression BlockAsync( BlockExpression expression )
{
//expression.Expressions.Count..

/*
{
var result1 = {
[ex1Task]
expression1, //Task Assign( ex1Task, expression1 )
expression2,
awaitExpression3 ( expression3 /// Expression ),
},
{
[ex1Task, result1]
await( ex1Task,void,T )
}
var result3 = {
[result2]
expression4,
}
...
}
*/

//var d = Task.Delay( 10 );
// ...
//await d;


return new AsyncBlockExpression( expression );
}

}
14 changes: 9 additions & 5 deletions src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Linq.Expressions;
using System.Reflection;

namespace Hyperbee.AsyncExpressions;

Expand All @@ -8,11 +9,14 @@ internal AsyncInvocationExpression( InvocationExpression body ) : base( body )
{
}

public static AsyncInvokeExpression InvokeAsync( LambdaExpression lambdaExpression, params Expression[] arguments )
}
public static partial class AsyncExpression
{
public static AsyncInvokeExpression InvokeAsync(LambdaExpression lambdaExpression, params Expression[] arguments)
{
if ( !IsAsync( lambdaExpression.ReturnType ) )
throw new ArgumentException( "The specified lambda is not an async.", nameof( lambdaExpression ) );
if (!AsyncInvokeExpression.IsAsync(lambdaExpression.ReturnType))
throw new ArgumentException("The specified lambda is not an async.", nameof(lambdaExpression));

return new AsyncInvokeExpression( Invoke( lambdaExpression, arguments ) );
return new AsyncInvokeExpression(Expression.Invoke(lambdaExpression, arguments));
}
}
}
26 changes: 13 additions & 13 deletions src/Hyperbee.AsyncExpressions/AsyncInvokeExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ internal AsyncInvokeExpression( Expression body )

public override ExpressionType NodeType => ExpressionType.Extension;

public override Type Type => _body.Type == typeof( Task ) ? typeof( Task<VoidTaskResult> ) : _body.Type;
public override Type Type => _body.Type;

public override bool CanReduce => true;

Expand All @@ -42,19 +42,19 @@ public override Expression Reduce()

_isReduced = true;

var (type, result) = GetTypeResult( _body );
var methodInfo = GenericGenerateExecuteAsync?.MakeGenericMethod( type );

_reducedBody = (Expression) methodInfo!.Invoke( null, [result] );
_reducedBody = _body.Type.IsGenericType switch
{
true => GetReduceBody(_body.Type.GetGenericArguments()[0], _body),
false => GetReduceBody( typeof(VoidTaskResult), Block( _body, VoidResult ) )
};

return _reducedBody!;
}
return _reducedBody;

private static (Type Type, Expression Expression) GetTypeResult( Expression expression )
{
return expression.Type == typeof( Task )
? (typeof( VoidTaskResult ), Block( expression, VoidResult ))
: (expression.Type.GetGenericArguments()[0], expression);
static Expression GetReduceBody( Type type, Expression body )
{
var methodInfo = GenericGenerateExecuteAsync.MakeGenericMethod( type );
return (Expression) methodInfo!.Invoke( null, [body] );
}
}

private static BlockExpression GenerateExecuteAsyncExpression<T>( Expression task )
Expand Down Expand Up @@ -167,7 +167,7 @@ public void SetStateMachine( IAsyncStateMachine stateMachine )
}
}

protected static bool IsAsync( Type returnType )
internal static bool IsAsync( Type returnType )
{
return returnType == typeof( Task ) ||
(returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof( Task<> )) ||
Expand Down
19 changes: 11 additions & 8 deletions src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,25 @@ public class AsyncMethodCallExpression : AsyncInvokeExpression
internal AsyncMethodCallExpression( MethodCallExpression body ) : base( body )
{
}
}


public static partial class AsyncExpression
{
public static AsyncInvokeExpression CallAsync( MethodInfo methodInfo, params Expression[] arguments )
{
if ( !IsAsync( methodInfo.ReturnType ) )
throw new ArgumentException( "The specified method is not an async.", nameof( methodInfo ) );
if ( !AsyncInvokeExpression.IsAsync( methodInfo.ReturnType ) )
throw new ArgumentException( "The specified method is not an async.", nameof(methodInfo) );

return new AsyncInvokeExpression( Call( methodInfo, arguments ) );
return new AsyncInvokeExpression( Expression.Call( methodInfo, arguments ) );
}

public static AsyncInvokeExpression CallAsync( Expression instance, MethodInfo methodInfo,
params Expression[] arguments )
{
if ( !IsAsync( methodInfo.ReturnType ) )
throw new ArgumentException( "The specified method is not an async.", nameof( methodInfo ) );
if ( !AsyncInvokeExpression.IsAsync( methodInfo.ReturnType ) )
throw new ArgumentException( "The specified method is not an async.", nameof(methodInfo) );

return new AsyncInvokeExpression( Call( instance, methodInfo, arguments ) );
return new AsyncInvokeExpression( Expression.Call( instance, methodInfo, arguments ) );
}
}

}
6 changes: 5 additions & 1 deletion src/Hyperbee.AsyncExpressions/AwaitExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@ private static T Await<T>( Task<T> task, bool configureAwait )
return result;
}

}

public static partial class AsyncExpression
{
public static AwaitExpression Await(Expression expression, bool configureAwait)
{
return new AwaitExpression(expression, configureAwait);
}
}
}
Loading

0 comments on commit 0378ba9

Please sign in to comment.