Last active
July 5, 2019 21:26
-
-
Save aelij/5d046b86bfca13fb682c411852d08cfd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.ComponentModel; | |
using System.IO; | |
using System.Linq; | |
using System.Runtime.CompilerServices; | |
using System.Security; | |
using System.Threading; | |
using System.Threading.Tasks; | |
namespace AsyncEnumeratorGenerator | |
{ | |
/// <summary>
/// Demo driver: enumerates an async-enumerator method from several consumers in parallel,
/// then runs a chained <c>Where</c> pipeline that is cancelled shortly after it starts.
/// </summary>
class Program
{
    /// <summary>
    /// Runs the demo to completion and, if it fails, prints the final task status followed by the exception.
    /// </summary>
    static void Main(string[] args)
    {
        var task = Run();
        try
        {
            // Main is synchronous, so block here until the demo finishes.
            task.Wait();
        }
        catch (Exception ex)
        {
            Console.WriteLine(task.Status);
            Console.WriteLine(ex);
        }
    }

    /// <summary>
    /// Executes the two demo scenarios: parallel enumeration of one enumerable, then a cancelled pipeline.
    /// </summary>
    private static async Task Run()
    {
        // run in parallel: every Print call asks the same enumerable for its own enumerator
        var enumerable = GetValuesAsync();
        await Task.WhenAll(Enumerable.Range(0, 10).Select(i => Print(enumerable)));

        // cancellation: the token is cancelled 100 ms in, while the pipeline is still running
        // (the source was never disposed before; the using block releases it once Print finishes or throws)
        using (var cts = new CancellationTokenSource())
        {
            cts.CancelAfter(100);
            await Print(Where(Where(AsyncEnumerable.Range(0, 1000000), i => i > 100, cts.Token), i => i % 2 == 0, cts.Token));
        }
    }

    /// <summary>
    /// Writes every element of <paramref name="enumerable"/> to the console, one per line.
    /// </summary>
    /// <param name="enumerable">The sequence to print.</param>
    private static async Task Print(AsyncEnumerable<int> enumerable)
    {
        // Dispose the enumerator even when MoveNext throws (e.g. on cancellation); it was previously leaked.
        using (var enumerator = enumerable.GetEnumerator())
        {
            while (await enumerator.MoveNext())
            {
                Console.WriteLine(enumerator.Current);
            }
        }
    }

    /// <summary>
    /// Factory for the demo sequence; each enumerator request invokes <see cref="GetValuesAsyncEnumerator"/> again.
    /// </summary>
    public static AsyncEnumerable<int> GetValuesAsync()
    {
        return new AsyncEnumerable<int>(GetValuesAsyncEnumerator);
    }

    /// <summary>
    /// Async-enumerator method that yields 0 through 4, waiting 200 ms before each element.
    /// </summary>
    private static async AsyncEnumerator<int> GetValuesAsyncEnumerator()
    {
        for (int i = 0; i < 5; i++)
        {
            await Task.Delay(200);
            // Awaiting a YieldReturn() awaitable is how this method hands an element to the consumer.
            await Task.Run(() => i).YieldReturn();
        }

        return default(int); // dummy value - the builder's SetResult ignores it
    }

    /// <summary>
    /// Returns a sequence containing only the elements of <paramref name="enumerable"/> that satisfy
    /// <paramref name="predicate"/>.
    /// </summary>
    /// <param name="enumerable">The source sequence.</param>
    /// <param name="predicate">Filter applied to each source element.</param>
    /// <param name="cancellationToken">Token passed to every <c>MoveNext</c> call on the source.</param>
    public static AsyncEnumerable<T> Where<T>(IAsyncEnumerable<T> enumerable, Func<T, bool> predicate, CancellationToken cancellationToken)
    {
        return new AsyncEnumerable<T>(() => WhereEnumerator<T>(enumerable, predicate, cancellationToken));
    }

    /// <summary>
    /// Async-enumerator method backing <see cref="Where{T}"/>: pulls from the source and yields matching elements.
    /// </summary>
    private static async AsyncEnumerator<T> WhereEnumerator<T>(IAsyncEnumerable<T> enumerable, Func<T, bool> predicate, CancellationToken cancellationToken)
    {
        // Dispose the source enumerator when this method completes or faults; it was previously leaked.
        using (var enumerator = enumerable.GetEnumerator())
        {
            while (await enumerator.MoveNext(cancellationToken))
            {
                if (predicate(enumerator.Current))
                {
                    // we can use a value task or a simple value awaitable here instead
                    await Task.FromResult(enumerator.Current).YieldReturn();
                }
            }
        }

        return default(T); // dummy value - the builder's SetResult ignores it
    }
}
/// <summary>
/// Extension methods for turning a task into a "yield return" awaitable.
/// </summary>
public static class YieldReturnExtensions
{
    /// <summary>
    /// Wraps <paramref name="task"/> in a <see cref="YieldReturnAwaitable{TResult}"/>.
    /// </summary>
    /// <param name="task">The task whose result is to be yielded.</param>
    /// <returns>An awaitable built around <paramref name="task"/>.</returns>
    public static YieldReturnAwaitable<T> YieldReturn<T>(this Task<T> task)
        => new YieldReturnAwaitable<T>(task);
}
/// <summary>
/// Awaitable wrapper around a <see cref="Task{TResult}"/> whose awaiter always reports itself as
/// not completed, so that awaiting it always goes through the method builder's await hooks.
/// </summary>
public struct YieldReturnAwaitable<TResult>
{
    private readonly Task<TResult> _source;

    /// <summary>Creates the awaitable over <paramref name="task"/>.</summary>
    public YieldReturnAwaitable(Task<TResult> task)
    {
        _source = task;
    }

    /// <summary>Returns a new awaiter over the wrapped task.</summary>
    public YieldReturnAwaiter GetAwaiter()
    {
        return new YieldReturnAwaiter(_source);
    }

    /// <summary>
    /// Awaiter that forwards continuations and results to the wrapped task's
    /// <see cref="TaskAwaiter{TResult}"/> but hides its completion state from the state machine.
    /// </summary>
    public struct YieldReturnAwaiter : ICriticalNotifyCompletion, INotifyCompletion
    {
        private readonly TaskAwaiter<TResult> _taskAwaiter;

        internal YieldReturnAwaiter(Task<TResult> task)
        {
            _taskAwaiter = task.GetAwaiter();
        }

        // if we return true, the state machine would skip calling AwaitOnCompleted
        // and we won't be able to yield
        // instead we check IsCompletedInternal in the Await method to optimize the continuation
        public bool IsCompleted
        {
            get { return false; }
        }

        /// <summary>The real completion state of the wrapped task.</summary>
        internal bool IsCompletedInternal
        {
            get { return _taskAwaiter.IsCompleted; }
        }

        [SecuritySafeCritical]
        public void OnCompleted(Action continuation)
        {
            _taskAwaiter.OnCompleted(continuation);
        }

        [SecurityCritical]
        public void UnsafeOnCompleted(Action continuation)
        {
            _taskAwaiter.UnsafeOnCompleted(continuation);
        }

        /// <summary>Returns the wrapped task's result (delegates to the task awaiter's GetResult).</summary>
        public TResult GetResult()
        {
            return _taskAwaiter.GetResult();
        }
    }
}
/// <summary>
/// An <see cref="IAsyncEnumerable{T}"/> that obtains its enumerators from a caller-supplied factory delegate.
/// </summary>
public sealed class AsyncEnumerable<T> : IAsyncEnumerable<T>
{
    private readonly Func<IAsyncEnumerator<T>> _enumeratorFactory;

    /// <summary>Creates the enumerable over the given factory.</summary>
    /// <param name="getEnumerator">Delegate invoked on every <see cref="GetEnumerator"/> call.</param>
    public AsyncEnumerable(Func<IAsyncEnumerator<T>> getEnumerator)
    {
        _enumeratorFactory = getEnumerator;
    }

    /// <summary>Invokes the factory and returns whatever enumerator it produces.</summary>
    public IAsyncEnumerator<T> GetEnumerator()
    {
        return _enumeratorFactory();
    }
}
/// <summary>
/// Enumerator facade over an <see cref="AsyncEnumeratorTaskMethodBuilder{T}"/>: every
/// <see cref="MoveNext"/> call steps the builder's state machine and reports the outcome
/// through a fresh <see cref="TaskCompletionSource{TResult}"/>.
/// </summary>
public sealed class AsyncEnumerator<T> : IAsyncEnumerator<T>
{
    // Delegate handed to the pre-cancelled task created in MoveNext; that task is never started here.
    private static readonly Func<bool> _alwaysFalse = () => false;

    private AsyncEnumeratorTaskMethodBuilder<T> _builder;

    internal AsyncEnumerator(AsyncEnumeratorTaskMethodBuilder<T> builder)
    {
        _builder = builder;
    }

    /// <summary>The value most recently stored in the builder's <c>_current</c> field.</summary>
    public T Current
    {
        get { return _builder._current; }
    }

    /// <summary>Does nothing.</summary>
    public void Dispose()
    {
    }

    /// <summary>
    /// Steps the state machine once, unless <paramref name="cancellationToken"/> is already cancelled.
    /// </summary>
    /// <param name="cancellationToken">Checked once, before the state machine is stepped.</param>
    /// <returns>
    /// The task of the completion source installed for this step, or a task constructed with the
    /// already-cancelled token when cancellation was requested up front.
    /// </returns>
    public Task<bool> MoveNext(CancellationToken cancellationToken)
    {
        if (!cancellationToken.IsCancellationRequested)
        {
            // The builder completes this source from SetResult, SetException or InvokeMoveNext.
            _builder._tcs = new TaskCompletionSource<bool>();
            _builder._stateMachine.MoveNext();
            return _builder._tcs.Task;
        }

        // same as Task.FromCancelled
        return new Task<bool>(_alwaysFalse, cancellationToken);
    }

    /// <summary>Factory hook that returns a new builder via <c>AsyncEnumeratorTaskMethodBuilder&lt;T&gt;.Create()</c>.</summary>
    [EditorBrowsable(EditorBrowsableState.Never)]
    public static AsyncEnumeratorTaskMethodBuilder<T> CreateAsyncMethodBuilder()
    {
        return AsyncEnumeratorTaskMethodBuilder<T>.Create();
    }
}
/// <summary>
/// Async method builder for methods returning <see cref="AsyncEnumerator{T}"/>. It does not run the
/// state machine on <see cref="Start{TStateMachine}"/>; the machine is stepped by
/// <see cref="AsyncEnumerator{T}.MoveNext"/>. Awaits on a
/// <c>YieldReturnAwaitable&lt;T&gt;</c> publish a value and pause; all other awaits resume the machine as usual.
/// </summary>
public sealed class AsyncEnumeratorTaskMethodBuilder<T>
{
    // Awaiter of the await currently in flight when it is a YieldReturn of exactly T; null for any other awaiter.
    private YieldReturnAwaitable<T>.YieldReturnAwaiter? _yieldReturnAwaiter;

    internal IAsyncStateMachine _stateMachine;
    internal T _current;
    internal TaskCompletionSource<bool> _tcs;

    /// <summary>Creates a new builder instance.</summary>
    public static AsyncEnumeratorTaskMethodBuilder<T> Create()
    {
        return new AsyncEnumeratorTaskMethodBuilder<T>();
    }

    /// <summary>
    /// Stores the state machine without stepping it. If <typeparamref name="TStateMachine"/> is a
    /// struct, the assignment to the interface-typed field stores a boxed copy.
    /// </summary>
    public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine
    {
        _stateMachine = stateMachine;
    }

    /// <summary>Intentionally empty.</summary>
    public void SetStateMachine(IAsyncStateMachine stateMachine)
    {
    }

    /// <summary>
    /// Completes the pending step, if any, with <c>false</c>.
    /// </summary>
    /// <param name="result">Ignored.</param>
    public void SetResult(T result)
    {
        // ignore the result value
        _tcs?.TrySetResult(false);
    }

    /// <summary>
    /// Completes the pending step, if any, as cancelled (for <see cref="OperationCanceledException"/>)
    /// or faulted (for any other exception).
    /// </summary>
    public void SetException(Exception exception)
    {
        if (_tcs == null)
        {
            return;
        }

        if (exception is OperationCanceledException)
        {
            _tcs.TrySetCanceled();
        }
        else
        {
            _tcs.TrySetException(exception);
        }
    }

    /// <summary>Returns a new <see cref="AsyncEnumerator{T}"/> over this builder on every access.</summary>
    public AsyncEnumerator<T> Task
    {
        get { return new AsyncEnumerator<T>(this); }
    }

    public void AwaitOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        Await(ref awaiter, ref stateMachine);
    }

    [SecuritySafeCritical]
    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        Await(ref awaiter, ref stateMachine);
    }

    /// <summary>
    /// Shared await hook: remembers whether the awaiter is a yield-return awaiter, then either
    /// publishes its value immediately (task already done) or schedules a <see cref="MoveNextRunner"/>
    /// as the awaiter's continuation.
    /// </summary>
    private void Await<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        // NOTE(review): a YieldReturn() over a type other than T fails this cast, yields null and is
        // then handled like an ordinary await - there is no compile-time check for that.
        _yieldReturnAwaiter = awaiter as YieldReturnAwaitable<T>.YieldReturnAwaiter?;

        bool valueAlreadyAvailable = _yieldReturnAwaiter.HasValue && _yieldReturnAwaiter.Value.IsCompletedInternal;
        if (!valueAlreadyAvailable)
        {
            var runner = new MoveNextRunner(ExecutionContext.Capture(), _stateMachine, this);
            awaiter.OnCompleted(() => runner.Run());
            return;
        }

        InvokeMoveNext();
    }

    /// <summary>
    /// Continuation body: for an ordinary await it steps the state machine; for a yield-return
    /// await it stores the result in <c>_current</c> and completes the pending step with <c>true</c>
    /// without stepping the machine.
    /// </summary>
    internal void InvokeMoveNext()
    {
        if (_yieldReturnAwaiter.HasValue)
        {
            try
            {
                // GetResult will be called again by the async state machine (oh, well :)
                _current = _yieldReturnAwaiter.Value.GetResult();
                _tcs.TrySetResult(true);
            }
            catch (Exception ex)
            {
                SetException(ex);
            }
        }
        else
        {
            // this is a "normal" await - just continue async execution
            _stateMachine.MoveNext();
        }
    }

    /// <summary>
    /// Runs the builder's <see cref="AsyncEnumeratorTaskMethodBuilder{T}.InvokeMoveNext"/> under the
    /// execution context captured when the await was registered (or directly when none was captured).
    /// </summary>
    private sealed class MoveNextRunner
    {
        private readonly ExecutionContext _context;
        private readonly IAsyncStateMachine _stateMachine;

        // Lazily created delegate for the static trampoline below.
        [SecurityCritical]
        private static ContextCallback _invokeMoveNext;

        private readonly AsyncEnumeratorTaskMethodBuilder<T> _builder;

        [SecurityCritical]
        internal MoveNextRunner(ExecutionContext context, IAsyncStateMachine stateMachine, AsyncEnumeratorTaskMethodBuilder<T> builder)
        {
            _context = context;
            _stateMachine = stateMachine;
            _builder = builder;
        }

        [SecuritySafeCritical]
        internal void Run()
        {
            if (_context == null)
            {
                // No context was captured: invoke the builder directly.
                _builder.InvokeMoveNext();
                return;
            }

            try
            {
                ContextCallback trampoline = _invokeMoveNext;
                if (trampoline == null)
                {
                    trampoline = (_invokeMoveNext = new ContextCallback(InvokeMoveNext));
                }

                ExecutionContext.Run(_context, trampoline, _builder);
            }
            finally
            {
                _context.Dispose();
            }
        }

        // Static trampoline: the state object passed to ExecutionContext.Run is the builder itself.
        [SecurityCritical]
        private static void InvokeMoveNext(object builder)
        {
            ((AsyncEnumeratorTaskMethodBuilder<T>)builder).InvokeMoveNext();
        }
    }
}
} |
Updated version:
- The async method now generates an `AsyncEnumerator<T>` instead of an enumerable.
- I've come to the conclusion that it's not possible to really create a factory from the async method, since the builder field is instantiated once inside the state machine and cannot be modified (it works as long as you don't try to execute the enumerators in parallel, but that's not an option IMO).
- So now we use two methods, one that the compiler generates - the enumerator, and another "factory" method - for the enumerable.
- Cancellation should now work as expected.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Stuff to review:
- `ExecutionContext` capturing
- `GetEnumerator()`?
- `YieldReturn()` on the wrong type? (there's no compile-time check)