diff --git a/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.Testing.cs b/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.Testing.cs
new file mode 100644
index 0000000000..b37fa2f456
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.Testing.cs
@@ -0,0 +1,121 @@
+// ------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ------------------------------------------------------------
+namespace Microsoft.Azure.Cosmos.Pagination
+{
+ using System;
+ using System.Buffers;
+ using System.Collections.Generic;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.Azure.Cosmos.Tracing;
+
+ ///
+ /// Holds the "just for testing"-bits of .
+ ///
+ internal static partial class ParallelPrefetch
+ {
+ ///
+ /// For testing purposes, provides ways to instrument .
+ ///
+ /// You shouldn't be using this outside of test projects.
+ ///
+ internal sealed class ParallelPrefetchTestConfig : ITrace
+ {
+ private ITrace innerTrace;
+
+ private int startedTasks;
+ private int awaitedTasks;
+
+ public ArrayPool PrefetcherPool { get; private set; }
+ public ArrayPool TaskPool { get; private set; }
+ public ArrayPool ObjectPool { get; private set; }
+
+ public int StartedTasks
+ => this.startedTasks;
+
+ public int AwaitedTasks
+ => this.awaitedTasks;
+
+ string ITrace.Name => this.innerTrace.Name;
+
+ Guid ITrace.Id => this.innerTrace.Id;
+
+ DateTime ITrace.StartTime => this.innerTrace.StartTime;
+
+ TimeSpan ITrace.Duration => this.innerTrace.Duration;
+
+ TraceLevel ITrace.Level => this.innerTrace.Level;
+
+ TraceComponent ITrace.Component => this.innerTrace.Component;
+
+ TraceSummary ITrace.Summary => this.innerTrace.Summary;
+
+ ITrace ITrace.Parent => this.innerTrace.Parent;
+
+ IReadOnlyList ITrace.Children => this.innerTrace.Children;
+
+ IReadOnlyDictionary ITrace.Data => this.innerTrace.Data;
+
+ public ParallelPrefetchTestConfig(
+ ArrayPool prefetcherPool,
+ ArrayPool taskPool,
+ ArrayPool objectPool)
+ {
+ this.PrefetcherPool = prefetcherPool;
+ this.TaskPool = taskPool;
+ this.ObjectPool = objectPool;
+ }
+
+ public void SetInnerTrace(ITrace trace)
+ {
+ this.innerTrace = trace;
+ }
+
+ public void TaskStarted()
+ {
+ Interlocked.Increment(ref this.startedTasks);
+ }
+
+ public void TaskAwaited()
+ {
+ Interlocked.Increment(ref this.awaitedTasks);
+ }
+
+ ITrace ITrace.StartChild(string name)
+ {
+ return this.innerTrace.StartChild(name);
+ }
+
+ ITrace ITrace.StartChild(string name, TraceComponent component, TraceLevel level)
+ {
+ return this.innerTrace.StartChild(name, component, level);
+ }
+
+ void ITrace.AddDatum(string key, TraceDatum traceDatum)
+ {
+ this.innerTrace.AddDatum(key, traceDatum);
+ }
+
+ void ITrace.AddDatum(string key, object value)
+ {
+ this.innerTrace.AddDatum(key, value);
+ }
+
+ void ITrace.AddOrUpdateDatum(string key, object value)
+ {
+ this.innerTrace.AddOrUpdateDatum(key, value);
+ }
+
+ void ITrace.AddChild(ITrace trace)
+ {
+ this.innerTrace.AddChild(trace);
+ }
+
+ void IDisposable.Dispose()
+ {
+ this.innerTrace.Dispose();
+ }
+ }
+ }
+}
diff --git a/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs b/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs
index 56bc732658..c4782329f7 100644
--- a/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs
+++ b/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs
@@ -5,73 +5,747 @@
namespace Microsoft.Azure.Cosmos.Pagination
{
using System;
+ using System.Buffers;
using System.Collections.Generic;
+ using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;
- internal static class ParallelPrefetch
+ internal static partial class ParallelPrefetch
{
- public static async Task PrefetchInParallelAsync(
+ ///
+ /// Number of tasks started at one time, maximum, when working through prefetchers.
+ ///
+ /// Also used as a the limit between Low and High concurrency implementations.
+ ///
+ /// This number should be reasonable large, but less than the point where a
+ /// Task[BatchLimit] ends up on the LOH (which will be around 8,192).
+ ///
+ private const int BatchLimit = 512;
+
+ ///
+ /// Common state that is needed for all tasks started via , unless
+ /// certain special cases hold.
+ ///
+ /// Also used as a synchronization primitive.
+ ///
+ private sealed class CommonPrefetchState
+ {
+ // we also use this to signal if we're finished enumerating, to save space
+ private IEnumerator enumerator;
+
+ ///
+ /// If this is true, it's a signal that new work should not be queued up.
+ ///
+ public bool FinishedEnumerating
+ => Volatile.Read(ref this.enumerator) == null;
+
+ ///
+ /// Common to be used by all tasks.
+ ///
+ /// When testing, this can also include be a .
+ /// We reuse this to keep allocations down in non-test cases.
+ ///
+ public ITrace PrefetchTrace { get; private set; }
+
+ ///
+ /// The which will produce the next
+ /// to use.
+ ///
+ /// Once at least one Task been started, should only be accessed under a lock.
+ ///
+ /// If == true, this returns null.
+ ///
+ public IEnumerator Enumerator
+ => Volatile.Read(ref this.enumerator);
+
+ ///
+ /// provided via .
+ ///
+ public CancellationToken CancellationToken { get; private set; }
+
+ public CommonPrefetchState(ITrace prefetchTrace, IEnumerator enumerator, CancellationToken cancellationToken)
+ {
+ this.PrefetchTrace = prefetchTrace;
+ this.enumerator = enumerator;
+ this.CancellationToken = cancellationToken;
+ }
+
+ ///
+ /// Cause to return true.
+ ///
+ public void SetFinishedEnumerating()
+ {
+ Volatile.Write(ref this.enumerator, null);
+ }
+ }
+
+ ///
+ /// State passed when we start a Task with an initial .
+ ///
+ /// That started Task will obtain it's next IPrefetchers using the
+ /// that is also provided.
+ ///
+ private sealed class SinglePrefetchState
+ {
+ ///
+ /// State common to the whole call.
+ ///
+ public CommonPrefetchState CommonState { get; private set; }
+
+ ///
+ /// which must be invoked next.
+ ///
+ public IPrefetcher CurrentPrefetcher { get; set; }
+
+ public SinglePrefetchState(CommonPrefetchState commonState, IPrefetcher initialPrefetcher)
+ {
+ this.CommonState = commonState;
+ this.CurrentPrefetcher = initialPrefetcher;
+ }
+ }
+
+ public static Task PrefetchInParallelAsync(
+ IEnumerable prefetchers,
+ int maxConcurrency,
+ ITrace trace,
+ CancellationToken cancellationToken)
+ {
+ prefetchers = prefetchers ?? throw new ArgumentNullException(nameof(prefetchers));
+ trace = trace ?? throw new ArgumentNullException(nameof(trace));
+
+ return PrefetchInParallelCoreAsync(prefetchers, maxConcurrency, trace, null, cancellationToken);
+ }
+
+ ///
+ /// Exposed for testing purposes, do not call directly.
+ ///
+ public static Task PrefetchInParallelCoreAsync(
IEnumerable prefetchers,
int maxConcurrency,
ITrace trace,
+ ParallelPrefetchTestConfig config,
CancellationToken cancellationToken)
{
- if (prefetchers == null)
+ if (maxConcurrency <= 0)
+ {
+ // old code would just... allocate and then do nothing
+ //
+ // so we do nothing here, for compatability purposes
+ return Task.CompletedTask;
+ }
+ else if (maxConcurrency == 1)
+ {
+ return SingleConcurrencyPrefetchInParallelAsync(prefetchers, trace, config, cancellationToken);
+ }
+ else if (maxConcurrency <= BatchLimit)
{
- throw new ArgumentNullException(nameof(prefetchers));
+ return LowConcurrencyPrefetchInParallelAsync(prefetchers, maxConcurrency, trace, config, cancellationToken);
}
+ else
+ {
+ return HighConcurrencyPrefetchInParallelAsync(prefetchers, maxConcurrency, trace, config, cancellationToken);
+ }
+ }
+
+ ///
+ /// Shared code for starting traces while prefetching.
+ ///
+ private static ITrace CommonStartTrace(ITrace trace)
+ {
+ return trace.StartChild(name: "Prefetching", TraceComponent.Pagination, TraceLevel.Info);
+ }
- if (trace == null)
+ ///
+ /// Helper for grabbing a reusable array.
+ ///
+ private static T[] RentArray(ParallelPrefetchTestConfig config, int minSize, bool clear)
+ {
+ T[] result;
+ if (config != null)
{
- throw new ArgumentNullException(nameof(trace));
+#pragma warning disable IDE0045 // Convert to conditional expression - chained else if is clearer
+ if (typeof(T) == typeof(IPrefetcher))
+ {
+ result = (T[])(object)config.PrefetcherPool.Rent(minSize);
+ }
+ else if (typeof(T) == typeof(Task))
+ {
+ result = (T[])(object)config.TaskPool.Rent(minSize);
+ }
+ else
+ {
+ result = (T[])(object)config.ObjectPool.Rent(minSize);
+ }
+#pragma warning restore IDE0045
+ }
+ else
+ {
+ result = ArrayPool.Shared.Rent(minSize);
}
- using (ITrace prefetchTrace = trace.StartChild(name: "Prefetching", TraceComponent.Pagination, TraceLevel.Info))
+ if (clear)
{
- HashSet tasks = new HashSet();
- IEnumerator prefetchersEnumerator = prefetchers.GetEnumerator();
- for (int i = 0; i < maxConcurrency; i++)
+ Array.Clear(result, 0, result.Length);
+ }
+
+ return result;
+ }
+
+ ///
+ /// Helper for returning arrays what were rented via .
+ ///
+ private static void ReturnRentedArray(ParallelPrefetchTestConfig config, T[] array, int clearThrough)
+ {
+ if (array == null)
+ {
+ return;
+ }
+
+ // this is important, otherwise we might leave Tasks and IPrefetchers
+ // rooted long enough to cause problems
+ Array.Clear(array, 0, clearThrough);
+
+ if (config != null)
+ {
+ if (typeof(T) == typeof(IPrefetcher))
{
- if (!prefetchersEnumerator.MoveNext())
+ config.PrefetcherPool.Return((IPrefetcher[])(object)array);
+ }
+ else if (typeof(T) == typeof(Task))
+ {
+ config.TaskPool.Return((Task[])(object)array);
+ }
+ else
+ {
+ config.ObjectPool.Return((object[])(object)array);
+ }
+ }
+ else
+ {
+ ArrayPool.Shared.Return(array);
+ }
+ }
+
+ ///
+ /// Starts a new Task that first calls on the passed
+ /// , and then grabs new ones from and repeats the process
+ /// until either the enumerator finishes or something sets .
+ ///
+ private static Task CommonStartTaskAsync(ParallelPrefetchTestConfig config, CommonPrefetchState commonState, IPrefetcher firstPrefetcher)
+ {
+ config?.TaskStarted();
+
+ SinglePrefetchState state = new (commonState, firstPrefetcher);
+
+ // this is mimicing the behavior of Task.Run(...) (that is, default CancellationToken, default Scheduler, DenyAttachChild, etc.)
+ // but in a way that let's us pass a context object
+ //
+ // this lets us declare a static delegate, and thus let's compiler reuse the delegate allocation
+ Task taskLoop =
+ Task.Factory.StartNew(
+ static async (context) =>
{
- break;
- }
+ // this method is structured a bit oddly to prevent the compiler from putting more data into the
+ // state of the Task - basically, don't have any locals (except context) that survive across an await
+ //
+ // we could go harder here and just not use async/await but that's awful for maintainability
+ try
+ {
+ while (true)
+ {
+ // step up to the initial await
+ {
+ SinglePrefetchState innerState = (SinglePrefetchState)context;
+
+ CommonPrefetchState innerCommonState = innerState.CommonState;
+ (ITrace prefetchTrace, CancellationToken cancellationToken) = (innerCommonState.PrefetchTrace, innerCommonState.CancellationToken);
+
+ // we smuggle a test config in as the common ITrace
+ //
+ // in most code, this will be null - but this pattern
+ // let's use keep CommonPrefetchState small
+ ParallelPrefetchTestConfig config = prefetchTrace as ParallelPrefetchTestConfig;
+
+ config?.TaskStarted();
+ config?.TaskAwaited();
+ await innerState.CurrentPrefetcher.PrefetchAsync(prefetchTrace, cancellationToken);
+ }
+
+ // step for preparing the next prefetch
+ {
+ SinglePrefetchState innerState = (SinglePrefetchState)context;
- IPrefetcher prefetcher = prefetchersEnumerator.Current;
- tasks.Add(Task.Run(async () => await prefetcher.PrefetchAsync(prefetchTrace, cancellationToken)));
+ CommonPrefetchState innerCommonState = innerState.CommonState;
+
+ if (innerCommonState.FinishedEnumerating)
+ {
+ // we're done, bail
+ return;
+ }
+
+ // proceed to the next item
+ //
+ // we need this lock because at this point there
+ // are other Tasks potentially also looking to call
+ // enumerator.MoveNext()
+ lock (innerCommonState)
+ {
+ // this can have transitioned to null since we last checked
+ // so this is basically double-check locking
+ IEnumerator enumerator = innerCommonState.Enumerator;
+ if (enumerator == null)
+ {
+ return;
+ }
+
+ if (!enumerator.MoveNext())
+ {
+ // we're done, signal to every other task to also bail
+ innerCommonState.SetFinishedEnumerating();
+
+ return;
+ }
+
+ // move on to the new IPrefetcher just obtained
+ innerState.CurrentPrefetcher = enumerator.Current;
+ }
+ }
+ }
+ }
+ catch
+ {
+ SinglePrefetchState innerState = (SinglePrefetchState)context;
+
+ // some error was encountered, we should tell other tasks to stop starting new prefetch tasks
+ // because we're about to cancel
+ innerState.CommonState.SetFinishedEnumerating();
+
+ // percolate the error up
+ throw;
+ }
+ },
+ state,
+ default,
+ TaskCreationOptions.DenyChildAttach,
+ TaskScheduler.Default);
+
+ // we _could_ maybe optimize this more... perhaps using a SemaphoreSlim or something
+ // but that complicates error reporting and is also awful for maintability
+ Task unwrapped = taskLoop.Unwrap();
+
+ return unwrapped;
+ }
+
+ ///
+ /// Fills a portion of an IPrefetcher[] using the passed enumerator.
+ ///
+ /// Returns the index that would next be filled.
+ ///
+ /// Updates the passed if the end of the enumerator is reached.
+ ///
+ /// Synchronization is the concern of the caller, not this method.
+ ///
+ private static int FillPrefetcherBuffer(CommonPrefetchState commonState, IPrefetcher[] prefetchers, int startIndex, int endIndex, IEnumerator enumerator)
+ {
+ int curIndex;
+ for (curIndex = startIndex; curIndex < endIndex; curIndex++)
+ {
+ if (!enumerator.MoveNext())
+ {
+ commonState.SetFinishedEnumerating();
+ break;
}
- while (tasks.Count != 0)
+ prefetchers[curIndex] = enumerator.Current;
+ }
+
+ return curIndex;
+ }
+
+ ///
+ /// Special case for when maxConcurrency == 1.
+ ///
+ /// This devolves into a foreach loop.
+ ///
+ private static async Task SingleConcurrencyPrefetchInParallelAsync(IEnumerable prefetchers, ITrace trace, ParallelPrefetchTestConfig config, CancellationToken cancellationToken)
+ {
+ using (ITrace prefetchTrace = CommonStartTrace(trace))
+ {
+ foreach (IPrefetcher prefetcher in prefetchers)
{
- Task completedTask = await Task.WhenAny(tasks);
- tasks.Remove(completedTask);
- try
- {
- await completedTask;
- }
- catch
+ config?.TaskStarted();
+ config?.TaskAwaited();
+ await prefetcher.PrefetchAsync(prefetchTrace, cancellationToken);
+ }
+ }
+ }
+
+ ///
+ /// The case where maxConcurrency is less than or equal to BatchLimit.
+ ///
+ /// This starts up to maxConcurrency simultanous Tasks, doing so in a way that
+ /// requires rented arrays of maxConcurrency size.
+ ///
+ private static async Task LowConcurrencyPrefetchInParallelAsync(
+ IEnumerable prefetchers,
+ int maxConcurrency,
+ ITrace trace,
+ ParallelPrefetchTestConfig config,
+ CancellationToken cancellationToken)
+ {
+ IPrefetcher[] initialPrefetchers = null;
+ Task[] runningTasks = null;
+
+ int nextPrefetcherIndex = 0;
+ int nextRunningTaskIndex = 0;
+
+ try
+ {
+ using (ITrace prefetchTrace = CommonStartTrace(trace))
+ {
+ config?.SetInnerTrace(prefetchTrace);
+
+ using (IEnumerator enumerator = prefetchers.GetEnumerator())
{
- // Observe the remaining tasks
- try
+ if (!enumerator.MoveNext())
{
- await Task.WhenAll(tasks);
+ // literally nothing to prefetch
+ return;
}
- catch
+
+ IPrefetcher first = enumerator.Current;
+
+ if (!enumerator.MoveNext())
{
+ // special case: a single prefetcher... just await it, and skip all the heavy work
+ config?.TaskStarted();
+ config?.TaskAwaited();
+ await first.PrefetchAsync(prefetchTrace, cancellationToken);
+ return;
}
- throw;
+ // need to actually do things to start prefetching in parallel
+ // so grab some state and stash the first two prefetchers off
+
+ initialPrefetchers = RentArray(config, maxConcurrency, clear: false);
+ initialPrefetchers[0] = first;
+ initialPrefetchers[1] = enumerator.Current;
+
+ CommonPrefetchState commonState = new (config ?? prefetchTrace, enumerator, cancellationToken);
+
+ // batch up a bunch of IPrefetchers to kick off
+ //
+ // we do this separately from starting the Tasks so we can avoid a lock
+ // and quicky get to maxConcurrency degrees of parallelism
+ nextPrefetcherIndex = FillPrefetcherBuffer(commonState, initialPrefetchers, 2, maxConcurrency, enumerator);
+
+ // actually start all the tasks, stashing them in a rented Task[]
+ runningTasks = RentArray(config, nextPrefetcherIndex, clear: false);
+
+ for (nextRunningTaskIndex = 0; nextRunningTaskIndex < nextPrefetcherIndex; nextRunningTaskIndex++)
+ {
+ IPrefetcher toStart = initialPrefetchers[nextRunningTaskIndex];
+ Task startedTask = CommonStartTaskAsync(config, commonState, toStart);
+
+ runningTasks[nextRunningTaskIndex] = startedTask;
+ }
+
+ // hand the prefetcher array back early, so other callers can use it
+ ReturnRentedArray(config, initialPrefetchers, nextPrefetcherIndex);
+ initialPrefetchers = null;
+
+ // now await all Tasks in turn
+ for (int toAwaitTaskIndex = 0; toAwaitTaskIndex < nextRunningTaskIndex; toAwaitTaskIndex++)
+ {
+ Task toAwait = runningTasks[toAwaitTaskIndex];
+
+ try
+ {
+ config?.TaskAwaited();
+ await toAwait;
+ }
+ catch
+ {
+ // if we encountered some exception, tell the remaining tasks to bail
+ // the next time they check commonState
+ commonState.SetFinishedEnumerating();
+
+ // we still need to observe all the tasks we haven't yet to avoid an UnobservedTaskException
+ for (int awaitAndIgnoreTaskIndex = toAwaitTaskIndex + 1; awaitAndIgnoreTaskIndex < nextRunningTaskIndex; awaitAndIgnoreTaskIndex++)
+ {
+ try
+ {
+ config?.TaskAwaited();
+ await runningTasks[awaitAndIgnoreTaskIndex];
+ }
+ catch
+ {
+ // intentionally left empty, we swallow all errors after the first
+ }
+ }
+
+ throw;
+ }
+ }
}
+ }
+ }
+ finally
+ {
+ ReturnRentedArray(config, initialPrefetchers, nextPrefetcherIndex);
+ ReturnRentedArray(config, runningTasks, nextRunningTaskIndex);
+ }
+ }
- if (prefetchersEnumerator.MoveNext())
+ ///
+ /// The case where maxConcurrency is greater than BatchLimit.
+ ///
+ /// This starts up to maxConcurrency simultanous Tasks, doing so in batches
+ /// of BatchLimit (or less) size. Active Tasks are tracked in a psuedo-linked-list
+ /// over rented object[].
+ ///
+ /// This is more complicated, less likely to hit maxConcurrency degrees of
+ /// parallelism, and less allocation efficient when compared to LowConcurrencyPrefetchInParallelAsync.
+ ///
+ /// However, it doesn't allocate gigantic arrays and doesn't wait for full enumeration
+ /// before starting to prefetch.
+ ///
+ private static async Task HighConcurrencyPrefetchInParallelAsync(
+ IEnumerable prefetchers,
+ int maxConcurrency,
+ ITrace trace,
+ ParallelPrefetchTestConfig config,
+ CancellationToken cancellationToken)
+ {
+ IPrefetcher[] currentBatch = null;
+
+ // this ends up holding a sort of linked list where
+ // each entry is actually a Task until the very last one
+ // which is an object[]
+ //
+ // as soon as a null is encountered, either where a Task or
+ // an object[] is expected, the linked list is done
+ object[] runningTasks = null;
+
+ try
+ {
+ using (ITrace prefetchTrace = CommonStartTrace(trace))
+ {
+ config?.SetInnerTrace(prefetchTrace);
+
+ using (IEnumerator enumerator = prefetchers.GetEnumerator())
{
- IPrefetcher bufferable = prefetchersEnumerator.Current;
- tasks.Add(Task.Run(async () => await bufferable.PrefetchAsync(prefetchTrace, cancellationToken)));
+ if (!enumerator.MoveNext())
+ {
+ // no prefetchers at all
+ return;
+ }
+
+ IPrefetcher first = enumerator.Current;
+
+ if (!enumerator.MoveNext())
+ {
+ // special case: a single prefetcher... just await it, and skip all the heavy work
+ config?.TaskStarted();
+ config?.TaskAwaited();
+ await first.PrefetchAsync(prefetchTrace, cancellationToken);
+ return;
+ }
+
+ // need to actually do things to start prefetching in parallel
+ // so grab some state and stash the first two prefetchers off
+
+ currentBatch = RentArray(config, BatchLimit, clear: false);
+ currentBatch[0] = first;
+ currentBatch[1] = enumerator.Current;
+
+ // we need this all null because we use null as a stopping condition later
+ runningTasks = RentArray(config, BatchLimit, clear: true);
+
+ CommonPrefetchState commonState = new (config ?? prefetchTrace, enumerator, cancellationToken);
+
+ // what we do here is buffer up to BatchLimit IPrefetchers to start
+ // and then... start them all
+ //
+ // we stagger this so we quickly get a bunch of tasks started without spending too
+ // much time pre-loading everything
+
+ // grab our first bunch of prefetchers outside of the lock
+ //
+ // we know that maxConcurrency > BatchLimit, so can just pass it as our cutoff here
+ int bufferedPrefetchers = FillPrefetcherBuffer(commonState, currentBatch, 2, BatchLimit, enumerator);
+
+ int nextChunkIndex = 0;
+ object[] currentChunk = runningTasks;
+
+ int remainingConcurrency = maxConcurrency;
+
+ // if we encounter any error, we remember it
+ // but as soon as we start a single task we've got
+ // to see most of this code through so we observe them
+ ExceptionDispatchInfo capturedException = null;
+
+ while (true)
+ {
+ // start and store the last set of Tasks we got from FillPrefetcherBuffer
+ for (int toStartIndex = 0; toStartIndex < bufferedPrefetchers; toStartIndex++)
+ {
+ IPrefetcher prefetcher = currentBatch[toStartIndex];
+ Task startedTask = CommonStartTaskAsync(config, commonState, prefetcher);
+
+ currentChunk[nextChunkIndex] = startedTask;
+ nextChunkIndex++;
+
+ // check if we need a new slab to store tasks
+ if (nextChunkIndex == currentChunk.Length - 1)
+ {
+ // we need this all null because we use null as a stopping condition later
+ object[] newChunk = RentArray(config, BatchLimit, clear: true);
+
+ currentChunk[currentChunk.Length - 1] = newChunk;
+
+ currentChunk = newChunk;
+ nextChunkIndex = 0;
+ }
+ }
+
+ remainingConcurrency -= bufferedPrefetchers;
+
+ // check to see if we've started all the concurrent Tasks we can
+ if (remainingConcurrency == 0)
+ {
+ break;
+ }
+
+ int nextBatchSizeLimit = remainingConcurrency < BatchLimit ? remainingConcurrency : BatchLimit;
+
+ // if one of the previously started Tasks exhausted the enumerator
+ // we're done, even if we still have space
+ if (commonState.FinishedEnumerating)
+ {
+ break;
+ }
+
+ // now that Tasks have started, we MUST synchronize access to
+ // the enumerator
+ lock (commonState)
+ {
+ // the answer might have changed, so we double-check
+ // this once we've got the lock
+ if (commonState.FinishedEnumerating)
+ {
+ break;
+ }
+
+ // grab the next set of prefetchers to start
+ try
+ {
+ bufferedPrefetchers = FillPrefetcherBuffer(commonState, currentBatch, 0, nextBatchSizeLimit, enumerator);
+ }
+ catch (Exception exc)
+ {
+ // this can get raised if the enumerator faults
+ //
+ // in this case we might have some tasks started, and so we need to _stop_ starting new tasks but
+ // still move on to observing everything we've already started
+
+ commonState.SetFinishedEnumerating();
+ capturedException = ExceptionDispatchInfo.Capture(exc);
+
+ break;
+ }
+ }
+
+ // if we got nothing back, we can break right here
+ if (bufferedPrefetchers == 0)
+ {
+ break;
+ }
+ }
+
+ // hand the prefetch array back, we're done with it
+ ReturnRentedArray(config, currentBatch, BatchLimit);
+ currentBatch = null;
+
+ // now wait for all the tasks to complete
+ //
+ // we walk through all of them, even if we encounter an error
+ // because we need to walk the whole linked-list and this is
+ // simpler than an explicit error code path
+
+ int toAwaitIndex = 0;
+ while (runningTasks != null)
+ {
+ Task toAwait = (Task)runningTasks[toAwaitIndex];
+
+ // if we see a null, we're done
+ if (toAwait == null)
+ {
+ // hand the last of the arrays back
+ ReturnRentedArray(config, runningTasks, toAwaitIndex);
+ runningTasks = null;
+
+ break;
+ }
+
+ try
+ {
+ config?.TaskAwaited();
+ await toAwait;
+ }
+ catch (Exception ex)
+ {
+ if (capturedException == null)
+ {
+ // if we encountered some exception, tell the remaining tasks to bail
+ // the next time they check commonState
+ commonState.SetFinishedEnumerating();
+
+ // save the exception so we can rethrow it later
+ capturedException = ExceptionDispatchInfo.Capture(ex);
+ }
+ }
+
+ // advance, moving to the next chunk if we've hit that limit
+ toAwaitIndex++;
+
+ if (toAwaitIndex == runningTasks.Length - 1)
+ {
+ object[] oldChunk = runningTasks;
+
+ runningTasks = (object[])runningTasks[runningTasks.Length - 1];
+ toAwaitIndex = 0;
+
+ // we're done with this, let some other caller reuse it immediately
+ ReturnRentedArray(config, oldChunk, oldChunk.Length);
+ }
+ }
+
+ // fault, if any task failed, after we've finished cleaning up
+ capturedException?.Throw();
}
}
}
+ finally
+ {
+ // cleanup if something went wrong while these were still rented
+ //
+ // this can basically only happen if the enumerator itself faults
+ // which is unlikely, but far from impossible
+
+ ReturnRentedArray(config, currentBatch, BatchLimit);
+
+ while (runningTasks != null)
+ {
+ object[] oldChunk = runningTasks;
+
+ runningTasks = (object[])runningTasks[runningTasks.Length - 1];
+
+ ReturnRentedArray(config, oldChunk, oldChunk.Length);
+ }
+ }
}
}
}
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/EndToEndTraceWriterBaselineTests.QueryAsync.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/EndToEndTraceWriterBaselineTests.QueryAsync.xml
index 7577f2ae4f..460eaa593c 100644
--- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/EndToEndTraceWriterBaselineTests.QueryAsync.xml
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/EndToEndTraceWriterBaselineTests.QueryAsync.xml
@@ -41,7 +41,6 @@
│ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ └── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
- │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
@@ -246,10 +245,6 @@
"name": "MoveNextAsync",
"duration in milliseconds": 0,
"children": [
- {
- "name": "Prefetching",
- "duration in milliseconds": 0
- },
{
"name": "[,05C1CFFFFFFFF8) move next",
"duration in milliseconds": 0,
@@ -771,7 +766,6 @@
│ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ ├── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
- │ │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
@@ -980,10 +974,6 @@
"name": "MoveNextAsync",
"duration in milliseconds": 0,
"children": [
- {
- "name": "Prefetching",
- "duration in milliseconds": 0
- },
{
"name": "[,05C1CFFFFFFFF8) move next",
"duration in milliseconds": 0,
@@ -1522,7 +1512,6 @@
│ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ └── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
- │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
@@ -1727,10 +1716,6 @@
"name": "MoveNextAsync",
"duration in milliseconds": 0,
"children": [
- {
- "name": "Prefetching",
- "duration in milliseconds": 0
- },
{
"name": "[,05C1CFFFFFFFF8) move next",
"duration in milliseconds": 0,
@@ -2253,7 +2238,6 @@
│ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ ├── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
- │ │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
@@ -2462,10 +2446,6 @@
"name": "MoveNextAsync",
"duration in milliseconds": 0,
"children": [
- {
- "name": "Prefetching",
- "duration in milliseconds": 0
- },
{
"name": "[,05C1CFFFFFFFF8) move next",
"duration in milliseconds": 0,
@@ -3028,7 +3008,6 @@
│ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ ├── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
- │ │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
@@ -3287,10 +3266,6 @@
"name": "MoveNextAsync",
"duration in milliseconds": 0,
"children": [
- {
- "name": "Prefetching",
- "duration in milliseconds": 0
- },
{
"name": "[,05C1CFFFFFFFF8) move next",
"duration in milliseconds": 0,
@@ -3834,7 +3809,6 @@
│ │ └── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ └── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
- │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
@@ -4057,10 +4031,6 @@
"name": "MoveNextAsync",
"duration in milliseconds": 0,
"children": [
- {
- "name": "Prefetching",
- "duration in milliseconds": 0
- },
{
"name": "[,05C1CFFFFFFFF8) move next",
"duration in milliseconds": 0,
@@ -4588,7 +4558,6 @@
│ │ └── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds
│ ├── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
- │ │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
│ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds
@@ -4815,10 +4784,6 @@
"name": "MoveNextAsync",
"duration in milliseconds": 0,
"children": [
- {
- "name": "Prefetching",
- "duration in milliseconds": 0
- },
{
"name": "[,05C1CFFFFFFFF8) move next",
"duration in milliseconds": 0,
diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/ParallelPrefetchTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/ParallelPrefetchTests.cs
new file mode 100644
index 0000000000..84511c8337
--- /dev/null
+++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/ParallelPrefetchTests.cs
@@ -0,0 +1,1255 @@
+namespace Microsoft.Azure.Cosmos.Tests.Pagination
+{
+ using System;
+ using System.Buffers;
+ using System.Collections;
+ using System.Collections.Concurrent;
+ using System.Collections.Generic;
+ using System.Diagnostics;
+ using System.Linq;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.Azure.Cosmos.Pagination;
+ using Microsoft.Azure.Cosmos.Tracing;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ [TestClass]
+ public class ParallelPrefetchTests
+ {
+ ///
+ /// IPrefetcher which can only be run once, and invokes callbacks as it executes.
+ ///
+ private sealed class TestOncePrefetcher : IPrefetcher
+ {
+ private int hasRun;
+
+ private readonly Action beforeAwait;
+ private readonly Action afterAwait;
+
+ internal TestOncePrefetcher(Action beforeAwait, Action afterAwait)
+ {
+ this.hasRun = 0;
+ this.beforeAwait = beforeAwait;
+ this.afterAwait = afterAwait;
+ }
+
+ public async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
+ {
+ // test that ParallelPrefetch doesn't start the same Task twice.
+ int oldRun = Interlocked.Exchange(ref this.hasRun, 1);
+ Assert.AreEqual(0, oldRun);
+
+ // we use two callbacks to test that ParallelPrefetch is correctly monitoring
+ // continuations - without this, we might incorrectly consider a Task completed
+ // despite it awaiting an inner Task
+
+ this.beforeAwait();
+
+ await Task.Yield();
+ cancellationToken.ThrowIfCancellationRequested();
+
+ this.afterAwait();
+ cancellationToken.ThrowIfCancellationRequested();
+ }
+ }
+
+ ///
+ /// IPrefetcher that does complicated things.
+ ///
+ private sealed class ComplicatedPrefetcher : IPrefetcher
+ {
+ public long StartTimestamp { get; private set; }
+ public long AfterYieldTimestamp { get; private set; }
+ public long AfterDelay1Timestamp { get; private set; }
+ public long AfterSemaphoreTimestamp { get; private set; }
+ public long AfterDelay2Timestamp { get; private set; }
+ public long AfterDelay3Timestamp { get; private set; }
+ public long AfterDelay4Timestamp { get; private set; }
+ public long WhenAllTimestamp { get; private set; }
+ public long EndTimestamp { get; private set; }
+
+ public async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
+ {
+ this.StartTimestamp = Stopwatch.GetTimestamp();
+
+ await Task.Yield();
+
+ this.AfterYieldTimestamp = Stopwatch.GetTimestamp();
+
+ using (SemaphoreSlim semaphore = new(0, 1))
+ {
+ Task delay = Task.Delay(5, cancellationToken).ContinueWith(_ => { this.AfterDelay1Timestamp = Stopwatch.GetTimestamp(); semaphore.Release(); }, cancellationToken);
+
+ await semaphore.WaitAsync(cancellationToken);
+ this.AfterSemaphoreTimestamp = Stopwatch.GetTimestamp();
+
+ await delay;
+ }
+
+ await Task.WhenAll(
+ Task.Delay(2, cancellationToken).ContinueWith(_ => this.AfterDelay2Timestamp = Stopwatch.GetTimestamp(), cancellationToken),
+ Task.Delay(3, cancellationToken).ContinueWith(_ => this.AfterDelay3Timestamp = Stopwatch.GetTimestamp(), cancellationToken),
+ Task.Delay(4, cancellationToken).ContinueWith(_ => this.AfterDelay4Timestamp = Stopwatch.GetTimestamp(), cancellationToken));
+ this.WhenAllTimestamp = Stopwatch.GetTimestamp();
+
+ await Task.Yield();
+
+ this.EndTimestamp = Stopwatch.GetTimestamp();
+ }
+
+ internal void AssertCorrect()
+ {
+ Assert.IsTrue(this.StartTimestamp > 0);
+
+ Assert.IsTrue(this.AfterYieldTimestamp > this.StartTimestamp);
+ Assert.IsTrue(this.AfterDelay1Timestamp > this.AfterYieldTimestamp);
+ Assert.IsTrue(this.AfterSemaphoreTimestamp > this.AfterDelay1Timestamp);
+
+ // these can all fire in any order (delay doesn't guarantee any particular order)
+ Assert.IsTrue(this.AfterDelay2Timestamp > this.AfterSemaphoreTimestamp);
+ Assert.IsTrue(this.AfterDelay3Timestamp > this.AfterSemaphoreTimestamp);
+ Assert.IsTrue(this.AfterDelay4Timestamp > this.AfterSemaphoreTimestamp);
+
+ // but by WhenAll()'ing them, we can assert WhenAll completes after all the other delays
+ Assert.IsTrue(this.WhenAllTimestamp > this.AfterDelay2Timestamp);
+ Assert.IsTrue(this.WhenAllTimestamp > this.AfterDelay3Timestamp);
+ Assert.IsTrue(this.WhenAllTimestamp > this.AfterDelay4Timestamp);
+
+ Assert.IsTrue(this.EndTimestamp > this.WhenAllTimestamp);
+ }
+ }
+
+ ///
+ /// IPrefetcher that asserts it got a trace with an expected parent.
+ ///
+ private sealed class ExpectedParentTracePrefetcher : IPrefetcher
+ {
+ private readonly ITrace expectedParentTrace;
+
+ internal ExpectedParentTracePrefetcher(ITrace expectedParentTrace)
+ {
+ this.expectedParentTrace = expectedParentTrace;
+ }
+
+ public ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
+ {
+ Assert.AreSame(this.expectedParentTrace, trace.Parent);
+
+ return default;
+ }
+ }
+
+ ///
+ /// IEnumerable which throws if touched.
+ ///
+ private sealed class ThrowsEnumerable : IEnumerable
+ {
+ public IEnumerator GetEnumerator()
+ {
+ throw new NotSupportedException();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return this.GetEnumerator();
+ }
+ }
+
+ ///
+ /// IEnumerable whose IEnumerator throws if access concurrently.
+ ///
+ private sealed class NonConcurrentAssertingEnumerable : IEnumerable
+ {
+ private sealed class Enumerator : IEnumerator
+ {
+ private readonly T[] inner;
+ private int index;
+
+ private int active;
+
+ internal Enumerator(T[] inner)
+ {
+ this.inner = inner;
+ }
+
+ public T Current { get; private set; }
+
+ object IEnumerator.Current => this.Current;
+
+ public void Dispose()
+ {
+ }
+
+ public bool MoveNext()
+ {
+ int isActive = Interlocked.Exchange(ref this.active, 1);
+ Assert.AreEqual(0, isActive, "Modified concurrently");
+
+ try
+ {
+ if (this.index < this.inner.Length)
+ {
+ this.Current = this.inner[this.index];
+ this.index++;
+ return true;
+ }
+
+ return false;
+ }
+ finally
+ {
+ int wasActive = Interlocked.Exchange(ref this.active, 0);
+ Assert.AreEqual(1, wasActive, "Modified concurrently");
+ }
+ }
+
+ public void Reset()
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ private readonly T[] inner;
+
+ internal NonConcurrentAssertingEnumerable(T[] inner)
+ {
+ this.inner = inner;
+ }
+
+ public IEnumerator GetEnumerator()
+ {
+ return new Enumerator(this.inner);
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return this.GetEnumerator();
+ }
+ }
+
+ ///
+ /// IEnumerable whose IEnumerator throws if access concurrently.
+ ///
+ private sealed class DisposeTrackingEnumerable : IEnumerable
+ {
+ private sealed class Enumerator : IEnumerator
+ {
+ private readonly DisposeTrackingEnumerable outer;
+
+ private int index;
+
+ internal Enumerator(DisposeTrackingEnumerable outer)
+ {
+ this.outer = outer;
+ }
+
+ public T Current { get; private set; }
+
+ object IEnumerator.Current => this.Current;
+
+ public void Dispose()
+ {
+ this.outer.DisposeCalls++;
+ }
+
+ public bool MoveNext()
+ {
+ if (this.index < this.outer.inner.Length)
+ {
+ this.Current = this.outer.inner[this.index];
+ this.index++;
+ return true;
+ }
+
+ return false;
+ }
+
+ public void Reset()
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ private readonly T[] inner;
+
+ internal DisposeTrackingEnumerable(T[] inner)
+ {
+ this.inner = inner;
+ }
+
+ internal int DisposeCalls { get; private set; }
+
+ public IEnumerator GetEnumerator()
+ {
+ return new Enumerator(this);
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return this.GetEnumerator();
+ }
+ }
+
+ ///
+ /// IEnumerable whose IEnumerator throws after a certain number of
+ /// calls to MoveNext().
+ ///
+ private sealed class ThrowsAfterEnumerable : IEnumerable
+ {
+ private sealed class Enumerator : IEnumerator
+ {
+ private readonly IEnumerator inner;
+ private readonly int throwAfter;
+ private int callNumber;
+
+ public T Current { get; set; }
+
+ object IEnumerator.Current => this.Current;
+
+
+ internal Enumerator(IEnumerator inner, int throwAfter)
+ {
+ this.inner = inner;
+ this.throwAfter = throwAfter;
+ }
+ public void Dispose()
+ {
+ this.inner.Dispose();
+ }
+
+ public bool MoveNext()
+ {
+ if (this.callNumber >= this.throwAfter)
+ {
+ throw new InvalidOperationException();
+ }
+
+ this.callNumber++;
+
+ if (this.inner.MoveNext())
+ {
+ this.Current = this.inner.Current;
+ return true;
+ }
+
+ this.Current = default;
+ return false;
+ }
+
+ public void Reset()
+ {
+ this.inner.Reset();
+ }
+ }
+
+ private readonly IEnumerable inner;
+ private readonly int throwAfter;
+
+ public ThrowsAfterEnumerable(IEnumerable inner, int throwAfter)
+ {
+ this.inner = inner;
+ this.throwAfter = throwAfter;
+ }
+
+ public IEnumerator GetEnumerator()
+ {
+ return new Enumerator(this.inner.GetEnumerator(), this.throwAfter);
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return this.GetEnumerator();
+ }
+ }
+
+ ///
+ /// IPrefetcher which throws if touched.
+ ///
+ private sealed class ThrowsPrefetcher : IPrefetcher
+ {
+ public ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
+ {
+ throw new NotSupportedException();
+ }
+ }
+
+ ///
+ /// ArrayPool that tracks leaks, double returns, and includes non-null values in
+ /// returned arrays.
+ ///
+ private sealed class ValidatingRandomizedArrayPool : ArrayPool
+ where T : class
+ {
+ private readonly T existingValue;
+
+ private readonly ConcurrentBag created;
+ private readonly ConcurrentDictionary rented;
+
+ internal ValidatingRandomizedArrayPool(T existingValue)
+ {
+ this.existingValue = existingValue;
+ this.created = new();
+ this.rented = new();
+ }
+
+ public override T[] Rent(int minimumLength)
+ {
+ int extra = Random.Shared.Next(6);
+
+ T[] ret = new T[minimumLength + extra];
+ for (int i = 0; i < ret.Length; i++)
+ {
+ ret[i] = Random.Shared.Next(2) == 0 ? this.existingValue : null;
+ }
+
+ this.created.Add(ret);
+
+ Assert.IsTrue(this.rented.TryAdd(ret, null));
+
+ return ret;
+ }
+
+ public override void Return(T[] array, bool clearArray = false)
+ {
+ Assert.IsFalse(clearArray, "Caller should clean up array itself");
+
+ Assert.IsTrue(this.rented.TryRemove(array, out _), "Tried to return array that isn't rented");
+
+ for (int i = 0; i < array.Length; i++)
+ {
+ object value = array[i];
+
+ if (object.ReferenceEquals(value, this.existingValue))
+ {
+ continue;
+ }
+
+ Assert.IsNull(value, "Returned array shouldn't have any non-null values, except those included by the original Rent call");
+ }
+ }
+
+ internal void AssertAllReturned()
+ {
+ Assert.IsTrue(this.rented.IsEmpty);
+ }
+ }
+
+ ///
+ /// ITrace which only traces children and parents.
+ ///
+ private sealed class SimpleTrace : ITrace
+ {
+ public string Name { get; private set; }
+
+ public Guid Id { get; } = Guid.NewGuid();
+
+ public DateTime StartTime { get; } = DateTime.UtcNow;
+
+ public TimeSpan Duration => DateTime.UtcNow - this.StartTime;
+
+ public Cosmos.Tracing.TraceLevel Level { get; private set; }
+
+ public TraceComponent Component { get; private set; }
+
+ public TraceSummary Summary => new();
+
+ public ITrace Parent { get; private set; }
+
+ public IReadOnlyList Children { get; } = new List();
+
+ public IReadOnlyDictionary Data { get; } = new Dictionary();
+
+ internal SimpleTrace(ITrace parent, string name, TraceComponent component, Cosmos.Tracing.TraceLevel level)
+ {
+ this.Parent = parent;
+ this.Name = name;
+ this.Component = component;
+ this.Level = level;
+ }
+
+ public void AddChild(ITrace trace)
+ {
+ List children = (List)this.Children;
+ lock (children)
+ {
+ children.Add(trace);
+ }
+ }
+
+ public void AddDatum(string key, TraceDatum traceDatum)
+ {
+ }
+
+ public void AddDatum(string key, object value)
+ {
+ }
+
+ public void AddOrUpdateDatum(string key, object value)
+ {
+ }
+
+ public void Dispose()
+ {
+ }
+
+ public ITrace StartChild(string name)
+ {
+ return this.StartChild(name, TraceComponent.Unknown, Cosmos.Tracing.TraceLevel.Off);
+ }
+
+ public ITrace StartChild(string name, TraceComponent component, Cosmos.Tracing.TraceLevel level)
+ {
+ ITrace child = new SimpleTrace(this, name, component, level);
+
+ List children = (List)this.Children;
+ lock (children)
+ {
+ children.Add(child);
+ }
+
+ return child;
+ }
+ }
+
+ ///
+ /// Different task counts which explore different code paths.
+ ///
+ private static readonly int[] TaskCounts = new[] { 0, 1, 2, 511, 512, 513, 1024, 1025 };
+
+ ///
+ /// Different max concurrencies which explore different code paths.
+ ///
+ private static readonly int[] Concurrencies = new[] { 1, 2, 511, 512, 513, int.MaxValue };
+
+ private static readonly ITrace EmptyTrace = NoOpTrace.Singleton;
+
+ [TestMethod]
+ public async Task ParameterValidationAsync()
+ {
+ // test contract for parameters
+
+ ArgumentNullException prefetchersArg = Assert.ThrowsException(
+ () =>
+ ParallelPrefetch.PrefetchInParallelAsync(
+ null,
+ 123,
+ EmptyTrace,
+ default));
+ Assert.AreEqual("prefetchers", prefetchersArg.ParamName);
+
+ ArgumentNullException traceArg = Assert.ThrowsException(
+ () =>
+ ParallelPrefetch.PrefetchInParallelAsync(
+ Array.Empty(),
+ 123,
+ null,
+ default));
+ Assert.AreEqual("trace", traceArg.ParamName);
+
+ // maxConcurrency can be < 0 ; check that that doesn't throw
+ await ParallelPrefetch.PrefetchInParallelAsync(Array.Empty(), -123, EmptyTrace, default);
+ }
+
+ [TestMethod]
+ public async Task ZeroConcurrencyOptimizationAsync()
+ {
+ // test that we correctly special case maxConcurrency == 0 as "do nothing"
+
+ IEnumerable prefetchers = new ThrowsEnumerable();
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ 0,
+ EmptyTrace,
+ default);
+ }
+
+ [TestMethod]
+ public async Task AllExecutedAsync()
+ {
+ // test that all prefetchers are actually invoked
+
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ int executed1 = 0;
+ int executed2 = 0;
+ IEnumerable prefetchers = CreatePrefetchers(taskCount, () => Interlocked.Increment(ref executed1), () => Interlocked.Increment(ref executed2));
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ default);
+
+ Assert.AreEqual(taskCount, executed1);
+ Assert.AreEqual(taskCount, executed2);
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task EnumeratorNotConcurrentlyAccessedAsync()
+ {
+ // test that the IEnumerator is only accessed by one thread at a time
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ IEnumerable prefetchers = CreatePrefetchers(taskCount, static () => { }, static () => { });
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ default);
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ IPrefetcher[] inner = new IPrefetcher[count];
+ for (int i = 0; i < count; i++)
+ {
+ inner[i] = new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+
+ return new NonConcurrentAssertingEnumerable(inner);
+ }
+ }
+
+ [TestMethod]
+ public async Task EnumeratorDisposedAsync()
+ {
+ // test that the IEnumerator is only accessed by one thread at a time
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ DisposeTrackingEnumerable prefetchers = CreatePrefetchers(taskCount, static () => { }, static () => { });
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ default);
+
+ Assert.AreEqual(1, prefetchers.DisposeCalls);
+ }
+ }
+
+ static DisposeTrackingEnumerable CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ IPrefetcher[] inner = new IPrefetcher[count];
+ for (int i = 0; i < count; i++)
+ {
+ inner[i] = new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+
+ return new DisposeTrackingEnumerable(inner);
+ }
+ }
+
+ [TestMethod]
+ public async Task ComplicatedPrefetcherAsync()
+ {
+ // test that a complicated prefetcher is full started and completed
+ //
+ // the rest of the tests don't use a completely trivial
+ // IPrefetcher, but they are substantially simpler
+
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ ComplicatedPrefetcher prefetcher = new();
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ new IPrefetcher[] { prefetcher },
+ maxConcurrency,
+ EmptyTrace,
+ default);
+
+ prefetcher.AssertCorrect();
+ }
+ }
+
+ [TestMethod]
+ public async Task MaxConcurrencyRespectedAsync()
+ {
+ // test that we never get above maxConcurrency
+ //
+ // whether or not we _reach_ it is dependent on the scheduler
+ // so we can't reliably test that
+
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ int observedMax = 0;
+ int current = 0;
+
+ IEnumerable prefetchers =
+ CreatePrefetchers(
+ taskCount,
+ () =>
+ {
+ int newCurrent = Interlocked.Increment(ref current);
+ Assert.IsTrue(newCurrent <= maxConcurrency);
+
+ int oldMax = Volatile.Read(ref observedMax);
+
+ while (newCurrent > oldMax)
+ {
+ oldMax = Interlocked.CompareExchange(ref observedMax, newCurrent, oldMax);
+ }
+ },
+ () =>
+ {
+ int newCurrent = Interlocked.Decrement(ref current);
+
+ Assert.IsTrue(current >= 0);
+ });
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ default);
+
+ Assert.IsTrue(Volatile.Read(ref observedMax) <= maxConcurrency);
+ Assert.AreEqual(0, Volatile.Read(ref current));
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task TraceCorrectlyPassedAsync()
+ {
+ // test that we make ONE ITrace per invocation
+ // and that it the returned child trace is correctly
+ // passed to all IPrefetchers
+
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ using ITrace simpleTrace = new SimpleTrace(null, "Root", TraceComponent.Batch, Cosmos.Tracing.TraceLevel.Off);
+
+ IEnumerable prefetchers = CreatePrefetchers(taskCount, simpleTrace);
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ simpleTrace,
+ default);
+
+ // our prefetchers don't create any children, but we expect one
+ // to be created by ParallelPrefetch
+ Assert.AreEqual(1, simpleTrace.Children.Count);
+ Assert.AreEqual(0, simpleTrace.Children[0].Children.Count);
+
+ // the one trace we start has a well known set of attributes, so check them
+ Assert.AreEqual("Prefetching", simpleTrace.Children[0].Name);
+ Assert.AreEqual(TraceComponent.Pagination, simpleTrace.Children[0].Component);
+ Assert.AreEqual(Cosmos.Tracing.TraceLevel.Info, simpleTrace.Children[0].Level);
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, ITrace expectedParentTrace)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new ExpectedParentTracePrefetcher(expectedParentTrace);
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task RentedBuffersAllReturnedAsync()
+ {
+ // test that all rented buffers are correctly returned
+ // (and in the expected state)
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ IEnumerable prefetchers = CreatePrefetchers(taskCount, static () => { }, static () => { });
+
+ ValidatingRandomizedArrayPool prefetcherPool = new(new ThrowsPrefetcher());
+ ValidatingRandomizedArrayPool taskPool = new(faultedTask);
+ ValidatingRandomizedArrayPool objectPool = new("unexpected value");
+
+ ParallelPrefetch.ParallelPrefetchTestConfig config =
+ new(
+ prefetcherPool,
+ taskPool,
+ objectPool
+ );
+
+ await ParallelPrefetch.PrefetchInParallelCoreAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ config,
+ default);
+
+ Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}; some tasks left unawaited");
+
+ prefetcherPool.AssertAllReturned();
+ taskPool.AssertAllReturned();
+ objectPool.AssertAllReturned();
+ }
+ }
+ }
+ finally
+ {
+ // observe this intentionally faulted task, no matter what
+ try
+ {
+ await faultedTask;
+ }
+ catch
+ {
+ // intentionally empty
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task TaskSingleExceptionHandledAsync()
+ {
+ // test that raising exceptions during processing tasks
+ // doesn't leak or otherwise fail
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ if (maxConcurrency <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ foreach (int taskCount in TaskCounts)
+ {
+ if (taskCount <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ for (int faultOnTask = 0; faultOnTask < taskCount; faultOnTask++)
+ {
+ IEnumerable prefetchers = CreatePrefetchers(taskCount, faultOnTask, static () => { }, static () => { });
+
+ ValidatingRandomizedArrayPool prefetcherPool = new(new ThrowsPrefetcher());
+ ValidatingRandomizedArrayPool taskPool = new(faultedTask);
+ ValidatingRandomizedArrayPool objectPool = new("unexpected value");
+
+ ParallelPrefetch.ParallelPrefetchTestConfig config =
+ new(
+ prefetcherPool,
+ taskPool,
+ objectPool
+ );
+
+ Exception caught = null;
+ try
+ {
+ await ParallelPrefetch.PrefetchInParallelCoreAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ config,
+ default);
+ }
+ catch (Exception e)
+ {
+ caught = e;
+ }
+
+ Assert.IsNotNull(caught, $"concurrency={maxConcurrency}, tasks={taskCount}, faultOn={faultOnTask} - didn't produce exception as expected");
+
+ Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}, faultOnTask={faultedTask}; some tasks left unawaited");
+
+ // buffer management can't break in the face of errors, so check here too
+ prefetcherPool.AssertAllReturned();
+ taskPool.AssertAllReturned();
+ objectPool.AssertAllReturned();
+ }
+ }
+ }
+ }
+ finally
+ {
+ // observe this intentionally faulted task, no matter what
+ try
+ {
+ await faultedTask;
+ }
+ catch
+ {
+ // intentionally empty
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, int faultOnTask, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ if (faultOnTask == i)
+ {
+ yield return new ThrowsPrefetcher();
+ }
+ else
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task TaskMultipleExceptionsHandledAsync()
+ {
+ // we throw a lot of exceptions in this test, which is expensive
+ // so we only probe proportionally to this constant for expediency's
+ // sake
+ const int StepRatio = 10;
+
+ // test that raising exceptions during processing tasks
+ // doesn't leak or otherwise fail
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ if (maxConcurrency <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ foreach (int taskCount in TaskCounts)
+ {
+ if (taskCount <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ int step = Math.Max(1, taskCount / StepRatio);
+
+ for (int faultOnAndAfterTask = 0; faultOnAndAfterTask < taskCount; faultOnAndAfterTask += step)
+ {
+ IEnumerable prefetchers = CreatePrefetchers(taskCount, faultOnAndAfterTask, static () => { }, static () => { });
+
+ ValidatingRandomizedArrayPool prefetcherPool = new(new ThrowsPrefetcher());
+ ValidatingRandomizedArrayPool taskPool = new(faultedTask);
+ ValidatingRandomizedArrayPool objectPool = new("unexpected value");
+
+ ParallelPrefetch.ParallelPrefetchTestConfig config =
+ new(
+ prefetcherPool,
+ taskPool,
+ objectPool
+ );
+
+ Exception caught = null;
+ try
+ {
+ await ParallelPrefetch.PrefetchInParallelCoreAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ config,
+ default);
+ }
+ catch (Exception e)
+ {
+ caught = e;
+ }
+
+ Assert.IsNotNull(caught, $"concurrency={maxConcurrency}, tasks={taskCount}, faultOnAndAfterTask={faultOnAndAfterTask} - didn't produce exception as expected");
+
+ Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}, faultedOnAndAfterTask={faultOnAndAfterTask}; some tasks left unawaited");
+
+ // buffer management can't break in the face of errors, so check here too
+ prefetcherPool.AssertAllReturned();
+ taskPool.AssertAllReturned();
+ objectPool.AssertAllReturned();
+ }
+ }
+ }
+ }
+ finally
+ {
+ // observe this intentionally faulted task, no matter what
+ try
+ {
+ await faultedTask;
+ }
+ catch
+ {
+ // intentionally empty
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, int faultOnTask, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ if (faultOnTask >= i)
+ {
+ yield return new ThrowsPrefetcher();
+ }
+ else
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task EnumerableExceptionsHandledAsync()
+ {
+ // test that raising exceptions during enumeration
+ // doesn't leak or otherwise fail
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies.Reverse())
+ {
+ if (maxConcurrency <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ foreach (int taskCount in TaskCounts)
+ {
+ for (int faultAfter = 0; faultAfter < taskCount; faultAfter++)
+ {
+ IEnumerable prefetchersRaw = CreatePrefetchers(taskCount, faultAfter, static () => { }, static () => { });
+ IEnumerable prefetchers = new ThrowsAfterEnumerable(prefetchersRaw, faultAfter);
+
+ ValidatingRandomizedArrayPool prefetcherPool = new(new ThrowsPrefetcher());
+ ValidatingRandomizedArrayPool taskPool = new(faultedTask);
+ ValidatingRandomizedArrayPool objectPool = new("unexpected value");
+
+ ParallelPrefetch.ParallelPrefetchTestConfig config =
+ new(
+ prefetcherPool,
+ taskPool,
+ objectPool
+ );
+
+ Exception caught = null;
+ try
+ {
+ await ParallelPrefetch.PrefetchInParallelCoreAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ config,
+ default);
+ }
+ catch (Exception e)
+ {
+ caught = e;
+ }
+
+ Assert.IsNotNull(caught, $"concurrency={maxConcurrency}, tasks={taskCount}, faultAfter={faultAfter} - didn't produce exception as expected");
+
+ Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}; some tasks left unawaited");
+
+ // buffer management can't break in the face of errors, so check here too
+ prefetcherPool.AssertAllReturned();
+ taskPool.AssertAllReturned();
+ objectPool.AssertAllReturned();
+ }
+ }
+ }
+ }
+ finally
+ {
+ // observe this intentionally faulted task, no matter what
+ try
+ {
+ await faultedTask;
+ }
+ catch
+ {
+ // intentionally empty
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, int faultOnTask, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task CancellationHandledAsync()
+ {
+ // cancellation is expensive, so rather than check every
+ // cancellation point - we just probe some proportional
+ // to this constant
+ const int StepRatio = 10;
+
+ // test that cancellation during processing
+ // doesn't leak or otherwise fail
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ if (maxConcurrency <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ foreach (int taskCount in TaskCounts)
+ {
+ if (taskCount <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ int step = Math.Max(1, taskCount / StepRatio);
+
+ for (int cancelBeforeTask = 0; cancelBeforeTask < taskCount; cancelBeforeTask += step)
+ {
+ using CancellationTokenSource cts = new();
+
+ int startedBeforeCancellation = 0;
+ object sync = new();
+
+ IEnumerable prefetchers =
+ CreatePrefetchers(
+ taskCount,
+ () =>
+ {
+ if (!cts.IsCancellationRequested)
+ {
+ int newValue = Interlocked.Increment(ref startedBeforeCancellation);
+
+ if (newValue >= cancelBeforeTask)
+ {
+ cts.Cancel();
+ }
+ }
+ },
+ () => { });
+
+ ValidatingRandomizedArrayPool prefetcherPool = new(new ThrowsPrefetcher());
+ ValidatingRandomizedArrayPool taskPool = new(faultedTask);
+ ValidatingRandomizedArrayPool objectPool = new("unexpected value");
+
+ ParallelPrefetch.ParallelPrefetchTestConfig config =
+ new(
+ prefetcherPool,
+ taskPool,
+ objectPool
+ );
+
+ Exception caught = null;
+ try
+ {
+ await ParallelPrefetch.PrefetchInParallelCoreAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ config,
+ cts.Token);
+ }
+ catch (Exception e)
+ {
+ caught = e;
+ }
+
+ Assert.IsNotNull(caught, $"concurrency={maxConcurrency}, tasks={taskCount}, cancelBeforeTask={cancelBeforeTask} - didn't produce exception as expected");
+
+ // we might burst above this, but we should always at least _reach_ it
+ Assert.IsTrue(cancelBeforeTask <= startedBeforeCancellation, $"{cancelBeforeTask} > {startedBeforeCancellation} ; we should have reach our cancellation point");
+
+ Assert.IsTrue(caught is OperationCanceledException);
+
+ Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}, cancelBeforeTask={cancelBeforeTask}; some tasks left unawaited");
+
+ // buffer management can't break in the face of cancellation, so check here too
+ prefetcherPool.AssertAllReturned();
+ taskPool.AssertAllReturned();
+ objectPool.AssertAllReturned();
+ }
+ }
+
+ static IEnumerable CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+ }
+ finally
+ {
+ // observe this intentionally faulted task, no matter what
+ try
+ {
+ await faultedTask;
+ }
+ catch
+ {
+ // intentionally empty
+ }
+ }
+ }
+ }
+}