Created
May 7, 2020 11:13
-
-
Save neuecc/9cef2b79a0828796306fd835a0641fbb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Threading; | |
namespace Cysharp.Threading.Tasks | |
{ | |
/// <summary>
/// A sequence of <typeparamref name="T"/> values that is consumed through an
/// asynchronous enumerator.
/// </summary>
/// <typeparam name="T">Type of the elements produced (covariant).</typeparam>
public interface IUniTaskAsyncEnumerable<out T>
{
    /// <summary>
    /// Creates an enumerator used to walk the sequence.
    /// </summary>
    /// <param name="cancellationToken">
    /// Token handed to the enumerator; defaults to <c>default(CancellationToken)</c>.
    /// </param>
    /// <returns>An enumerator over the elements of this sequence.</returns>
    IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default);
}
/// <summary>
/// Asynchronous cursor over a sequence of <typeparamref name="T"/> values.
/// Also asynchronously disposable via <see cref="IUniTaskAsyncDisposable"/>.
/// </summary>
/// <typeparam name="T">Type of the elements produced (covariant).</typeparam>
public interface IUniTaskAsyncEnumerator<out T> : IUniTaskAsyncDisposable
{
    /// <summary>
    /// The element at the enumerator's current position.
    /// </summary>
    T Current { get; }

    /// <summary>
    /// Asynchronously advances the enumerator to the next element.
    /// </summary>
    /// <returns>
    /// A task yielding <c>true</c> when an element is available through
    /// <see cref="Current"/>, or <c>false</c> when the sequence has ended.
    /// </returns>
    UniTask<bool> MoveNextAsync();
}
/// <summary>
/// Contract for objects whose cleanup is performed asynchronously and
/// reported through a <see cref="UniTask"/>.
/// </summary>
public interface IUniTaskAsyncDisposable
{
    /// <summary>
    /// Starts asynchronous cleanup of this instance.
    /// </summary>
    /// <returns>A task representing the dispose operation.</returns>
    UniTask DisposeAsync();
}
/// <summary>
/// LINQ-style extension operators for <see cref="IUniTaskAsyncEnumerable{T}"/>.
/// </summary>
public static class UniTaskAsyncEnumerable
{
    /// <summary>
    /// Projects each element of <paramref name="source"/> into a new form.
    /// The projection is applied as the returned sequence is enumerated.
    /// </summary>
    /// <typeparam name="TSource">Element type of the input sequence.</typeparam>
    /// <typeparam name="TResult">Element type produced by <paramref name="selector"/>.</typeparam>
    /// <param name="source">Sequence whose elements are transformed.</param>
    /// <param name="selector">Transform applied to each source element.</param>
    /// <returns>A sequence whose elements are the results of <paramref name="selector"/>.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="source"/> or <paramref name="selector"/> is <c>null</c>.
    /// </exception>
    public static IUniTaskAsyncEnumerable<TResult> Select<TSource, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, TResult> selector)
    {
        // Fail at the call site rather than with a NullReferenceException
        // on the first MoveNextAsync of the returned sequence.
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }

        if (selector == null)
        {
            throw new ArgumentNullException(nameof(selector));
        }

        return new Cysharp.Threading.Tasks.Linq.Select<TSource, TResult>(source, selector);
    }
}
} | |
namespace Cysharp.Threading.Tasks.Linq | |
{ | |
/// <summary>
/// Base class for enumerators that wrap a source sequence and produce one
/// <typeparamref name="TResult"/> decision per source step. The instance is
/// its own <see cref="IUniTaskSource{T}"/>: each <see cref="MoveNextAsync"/>
/// call resets the embedded completion source and returns a task bound to it.
/// Derived classes implement <see cref="MoveNextCore"/> only.
/// </summary>
/// <typeparam name="TSource">Element type of the wrapped source sequence.</typeparam>
/// <typeparam name="TResult">Element type exposed through <see cref="Current"/>.</typeparam>
public abstract class AsyncEnumeratorBase<TSource, TResult> : IUniTaskAsyncEnumerator<TResult>, IUniTaskSource<bool>
{
    // Created once per closed generic type so that registering the
    // continuation does not convert the method group to a new delegate each time.
    static readonly Action<object> moveNextCallbackDelegate = MoveNextCallBack;

    readonly IUniTaskAsyncEnumerable<TSource> source;
    readonly CancellationToken cancellationToken;

    // Never assigned in this class, i.e. used as a value type whose methods
    // mutate it in place; it must therefore stay non-readonly.
    UniTaskCompletionSourceCore<bool> completionSource;

    // Created lazily on the first MoveNextAsync call.
    IUniTaskAsyncEnumerator<TSource> enumerator;

    // Awaiter of the source step currently in flight; read back when it completes.
    UniTask<bool>.Awaiter sourceMoveNext;

    /// <summary>
    /// Stores the source sequence and the token used when the source
    /// enumerator is created and when each step is completed.
    /// </summary>
    /// <param name="source">Sequence to enumerate.</param>
    /// <param name="cancellationToken">Token passed to the source enumerator and checked after every step.</param>
    protected AsyncEnumeratorBase(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
    {
        this.source = source;
        this.cancellationToken = cancellationToken;
    }

    // abstract

    /// <summary>
    /// Processes the outcome of one source step. Implementations read
    /// <see cref="SourceCurrent"/> and set <see cref="Current"/> as needed.
    /// </summary>
    /// <param name="sourceHasCurrent">Result of the source enumerator's MoveNextAsync.</param>
    /// <returns>The value to report as the result of <see cref="MoveNextAsync"/>.</returns>
    protected abstract bool MoveNextCore(bool sourceHasCurrent);

    // Util

    /// <summary>Current element of the wrapped source enumerator.</summary>
    protected TSource SourceCurrent => enumerator.Current;

    // IUniTaskAsyncEnumerator<T>

    /// <summary>Element most recently produced by <see cref="MoveNextCore"/>.</summary>
    public TResult Current { get; protected set; }

    /// <summary>
    /// Advances the source enumerator by one step and completes the returned
    /// task with the value of <see cref="MoveNextCore"/>, with the exception it
    /// (or the source step) threw, or as canceled when the token is signaled.
    /// </summary>
    /// <returns>A task backed by this instance and the current completion-source version.</returns>
    public UniTask<bool> MoveNextAsync()
    {
        completionSource.Reset();

        if (enumerator == null)
        {
            enumerator = source.GetAsyncEnumerator(cancellationToken);
        }

        sourceMoveNext = enumerator.MoveNextAsync().GetAwaiter();
        if (sourceMoveNext.IsCompleted)
        {
            // Source finished synchronously: complete inline.
            CompleteMoveNext();
        }
        else
        {
            // Source is pending: finish in MoveNextCallBack once it completes.
            sourceMoveNext.SourceOnCompleted(moveNextCallbackDelegate, this);
        }

        return new UniTask<bool>(this, completionSource.Version);
    }

    // Continuation registered on the pending source step; state is this enumerator.
    static void MoveNextCallBack(object state)
    {
        var self = (AsyncEnumeratorBase<TSource, TResult>)state;
        self.CompleteMoveNext();
    }

    // Shared by the synchronous and callback paths: observes the finished
    // source step, runs MoveNextCore, and transitions the completion source.
    void CompleteMoveNext()
    {
        bool result;
        try
        {
            result = MoveNextCore(sourceMoveNext.GetResult());
        }
        catch (Exception ex)
        {
            // Failures from the source step or from MoveNextCore fault the task.
            completionSource.TrySetException(ex);
            return;
        }

        // Cancellation is checked after the step and takes precedence over its result.
        if (cancellationToken.IsCancellationRequested)
        {
            completionSource.TrySetCanceled(cancellationToken);
        }
        else
        {
            completionSource.TrySetResult(result);
        }
    }

    // if require additional resource to dispose, override and call base.DisposeAsync.

    /// <summary>
    /// Disposes the source enumerator if one was created; otherwise returns
    /// a default <see cref="UniTask"/>.
    /// </summary>
    /// <returns>The source enumerator's dispose task, or <c>default</c>.</returns>
    public virtual UniTask DisposeAsync()
    {
        if (enumerator != null)
        {
            return enumerator.DisposeAsync();
        }

        return default;
    }

    // IUniTaskSource<bool> -- every member forwards to the embedded completion source.

    /// <summary>Returns the result of the operation identified by <paramref name="token"/>.</summary>
    public bool GetResult(short token)
    {
        return completionSource.GetResult(token);
    }

    /// <summary>Returns the status of the operation identified by <paramref name="token"/>.</summary>
    public UniTaskStatus GetStatus(short token)
    {
        return completionSource.GetStatus(token);
    }

    /// <summary>Registers <paramref name="continuation"/> for the operation identified by <paramref name="token"/>.</summary>
    public void OnCompleted(Action<object> continuation, object state, short token)
    {
        completionSource.OnCompleted(continuation, state, token);
    }

    /// <summary>Returns the completion source's status without supplying a token.</summary>
    public UniTaskStatus UnsafeGetStatus()
    {
        return completionSource.UnsafeGetStatus();
    }

    // Non-generic IUniTaskSource view: same call, result discarded.
    void IUniTaskSource.GetResult(short token)
    {
        completionSource.GetResult(token);
    }
}
/// <summary>
/// Sequence that applies a projection to every element of a source sequence.
/// Instantiated by <c>UniTaskAsyncEnumerable.Select</c>.
/// </summary>
/// <typeparam name="TSource">Element type of the source sequence.</typeparam>
/// <typeparam name="TResult">Element type produced by the projection.</typeparam>
internal class Select<TSource, TResult> : IUniTaskAsyncEnumerable<TResult>
{
    readonly IUniTaskAsyncEnumerable<TSource> source;
    readonly Func<TSource, TResult> selector;

    /// <summary>Captures the source sequence and the projection to apply.</summary>
    public Select(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, TResult> selector)
    {
        this.source = source;
        this.selector = selector;
    }

    /// <summary>Creates a new projecting enumerator over the source sequence.</summary>
    public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        => new Enumerator(source, selector, cancellationToken);

    // Enumerator that maps each source element through the selector.
    class Enumerator : AsyncEnumeratorBase<TSource, TResult>
    {
        readonly Func<TSource, TResult> selector;

        public Enumerator(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, TResult> selector, CancellationToken cancellationToken)
            : base(source, cancellationToken)
        {
            this.selector = selector;
        }

        // Publishes the projected element when the source produced one;
        // otherwise reports the end of the sequence and leaves Current untouched.
        protected override bool MoveNextCore(bool sourceHasCurrent)
        {
            if (!sourceHasCurrent)
            {
                return false;
            }

            Current = selector(SourceCurrent);
            return true;
        }
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment