Skip to content

Commit

Permalink
Correct MoveNext state machine generation
Browse files Browse the repository at this point in the history
  • Loading branch information
bfarmer67 committed Sep 9, 2024
1 parent 7d29926 commit 1f057e7
Show file tree
Hide file tree
Showing 7 changed files with 788 additions and 146 deletions.
18 changes: 13 additions & 5 deletions src/Hyperbee.AsyncExpressions/AsyncInvocationExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,29 @@ public override Expression Reduce()
if ( _isReduced )
return _stateMachine;

_stateMachine = StateMachineBuilder.Create( Block( _invocationExpression ), Type, createRunner: true );
var resultType = ResultType( _invocationExpression.Type );

_stateMachine = StateMachineBuilder.Create( Block( _invocationExpression ), resultType, createRunner: true );
_isReduced = true;

return _stateMachine;

static Type ResultType( Type returnType )
{
return returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>)
? returnType.GetGenericArguments()[0]
: typeof(void);
}
}

public override Type Type
{
get
{
var returnType = _invocationExpression.Type;
if ( !_isReduced )
Reduce();

return IsTask( returnType ) && returnType.IsGenericType
? returnType.GetGenericArguments()[0]
: typeof(void);
return _stateMachine.Type;
}
}
}
Expand Down
20 changes: 14 additions & 6 deletions src/Hyperbee.AsyncExpressions/AsyncMethodCallExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,29 @@ public override Expression Reduce()
if ( _isReduced )
return _stateMachine;

_stateMachine = StateMachineBuilder.Create( Block( _methodCallExpression ), Type, createRunner: true );
var resultType = ResultType( _methodCallExpression.Type );

_stateMachine = StateMachineBuilder.Create( Block( _methodCallExpression ), resultType, createRunner: true );
_isReduced = true;

return _stateMachine;
}

static Type ResultType( Type returnType )
{
return returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>)
? returnType.GetGenericArguments()[0]
: typeof(void);
}
}

public override Type Type
{
get
{
var returnType = _methodCallExpression.Type;
if ( !_isReduced )
Reduce();

return IsTask( returnType ) && returnType.IsGenericType
? returnType.GetGenericArguments()[0]
: typeof(void);
return _stateMachine.Type;
}
}
}
Expand Down
35 changes: 17 additions & 18 deletions src/Hyperbee.AsyncExpressions/AwaitExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public class AwaitExpression : Expression
{
private readonly Expression _asyncExpression;
private readonly bool _configureAwait;
private readonly Type _resultType;

private static readonly MethodInfo AwaitMethod = typeof(AwaitExpression).GetMethod( nameof(Await), BindingFlags.NonPublic | BindingFlags.Static );
private static readonly MethodInfo AwaitResultMethod = typeof(AwaitExpression).GetMethod( nameof(AwaitResult), BindingFlags.NonPublic | BindingFlags.Static );
Expand All @@ -18,25 +19,25 @@ internal AwaitExpression( Expression asyncExpression, bool configureAwait )
{
_asyncExpression = asyncExpression ?? throw new ArgumentNullException( nameof( asyncExpression ) );
_configureAwait = configureAwait;
_resultType = ResultType( asyncExpression.Type );
}


public override ExpressionType NodeType => ExpressionType.Extension;

public override Type Type
public override Type Type => _resultType;

private Type ResultType( Type taskType )
{
get
if ( ReturnTask )
return taskType;

return taskType.IsGenericType switch
{
if ( ReturnTask )
return _asyncExpression.Type;

return _asyncExpression.Type.IsGenericType switch
{
true when _asyncExpression.Type.GetGenericTypeDefinition() == typeof(Task<>) => _asyncExpression.Type.GetGenericArguments()[0],
false => typeof(void),
_ => throw new InvalidOperationException( $"Unsupported type in {nameof(AwaitExpression)}." )
};
}
true when taskType.GetGenericTypeDefinition() == typeof(Task<>) => taskType.GetGenericArguments()[0],
false => typeof(void),
_ => throw new InvalidOperationException( $"Unsupported type in {nameof(AwaitExpression)}." )
};
}

public bool ReturnTask { get; set; }
Expand All @@ -48,13 +49,11 @@ public override Expression Reduce()
if ( ReturnTask )
return _asyncExpression;

// BF - state machine is not being started (code was lost)

var awaitResult = Call( Type == typeof( void )
var awaitExpression = Call( _resultType == typeof( void )
? AwaitMethod
: AwaitResultMethod.MakeGenericMethod( Type ), _asyncExpression, Constant( _configureAwait ) );
: AwaitResultMethod.MakeGenericMethod( _resultType ), _asyncExpression, Constant( _configureAwait ) );

return awaitResult;
return awaitExpression;
}

private static void Await( Task task, bool configureAwait )
Expand All @@ -70,7 +69,7 @@ private static T AwaitResult<T>( Task<T> task, bool configureAwait )
private class AwaitExpressionProxy( AwaitExpression node )
{
public Expression Target => node._asyncExpression;
public Type ReturnType => node.Type;
public Type ReturnType => node._resultType;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
<None Include="..\..\assets\icon.png" Pack="true" Visible="false" PackagePath="/" />
<None Include="..\..\README.md" Pack="true" Visible="true" PackagePath="/" Link="README.md" />
<None Include="..\..\LICENSE" Pack="true" Visible="false" PackagePath="/" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.11.0" />
<PackageReference Update="Microsoft.SourceLink.GitHub" Version="8.0.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
Loading

0 comments on commit 1f057e7

Please sign in to comment.