diff --git a/Microsoft.Azure.Cosmos/src/Pagination/CrossPartitionRangePageAsyncEnumerator.cs b/Microsoft.Azure.Cosmos/src/Pagination/CrossPartitionRangePageAsyncEnumerator.cs index 4feb66363e..1ffca500e6 100644 --- a/Microsoft.Azure.Cosmos/src/Pagination/CrossPartitionRangePageAsyncEnumerator.cs +++ b/Microsoft.Azure.Cosmos/src/Pagination/CrossPartitionRangePageAsyncEnumerator.cs @@ -267,7 +267,7 @@ private static async Task 1) { await ParallelPrefetch.PrefetchInParallelAsync(bufferedEnumerators, maxConcurrency.Value, trace, token); } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosItemTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosItemTests.cs index 66bf5a2559..25a4f4c248 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosItemTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosItemTests.cs @@ -1629,10 +1629,8 @@ public async Task ItemMultiplePartitionQuery() QueryRequestOptions requestOptions = new QueryRequestOptions() { - MaxBufferedItemCount = 10, - ResponseContinuationTokenLimitInKb = 500, MaxItemCount = 1, - MaxConcurrency = 1, + MaxConcurrency = -1, }; FeedIterator feedIterator = this.Container.GetItemQueryIterator( @@ -1654,7 +1652,8 @@ public async Task ItemMultiplePartitionQuery() ServerSideCumulativeMetrics metrics = iter.Diagnostics.GetQueryMetrics(); if (metrics != null) - { + { + // This assumes that we are using parallel prefetch to hit multiple partitions concurrently Assert.IsTrue(metrics.PartitionedMetrics.Count == 3); Assert.IsTrue(metrics.CumulativeMetrics.TotalTime > TimeSpan.Zero); Assert.IsTrue(metrics.CumulativeMetrics.QueryPreparationTime > TimeSpan.Zero); diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/CrossPartitionPartitionRangeEnumeratorTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/CrossPartitionPartitionRangeEnumeratorTests.cs index e803ab1731..66df149199 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/CrossPartitionPartitionRangeEnumeratorTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/CrossPartitionPartitionRangeEnumeratorTests.cs @@ -11,6 +11,7 @@ namespace Microsoft.Azure.Cosmos.Tests.Pagination using System.Text; using System.Threading; using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.CosmosElements; using Microsoft.Azure.Cosmos.Pagination; using Microsoft.Azure.Cosmos.Query.Core.Monads; using Microsoft.Azure.Cosmos.ReadFeed.Pagination; @@ -161,6 +162,104 @@ PartitionRangePageAsyncEnumerator createEnumerator( Assert.AreEqual(1, createdEnumerators[2].GetNextPageAsyncCounter, "Second enumerator should not be used"); } + [TestMethod] + public async Task TestParallelPrefetch() + { + List feedRanges = new List() + { + new FeedRangeEpk(new Documents.Routing.Range( "", "AA", isMinInclusive: true, isMaxInclusive: false)), + new FeedRangeEpk(new Documents.Routing.Range("AA", "BB", isMinInclusive: true, isMaxInclusive: false)), + new FeedRangeEpk(new Documents.Routing.Range("BB", "CC", isMinInclusive: true, isMaxInclusive: false)), + new FeedRangeEpk(new Documents.Routing.Range("CC", "DD", isMinInclusive: true, isMaxInclusive: false)), + new FeedRangeEpk(new Documents.Routing.Range("DD", "EE", isMinInclusive: true, isMaxInclusive: false)), + new FeedRangeEpk(new Documents.Routing.Range("EE", "FF", isMinInclusive: true, isMaxInclusive: false)), + }; + + Mock mockFeedRangeProvider = new Mock(); + mockFeedRangeProvider.Setup(p => p.GetFeedRangesAsync( + It.IsAny(), + It.IsAny())) + .ReturnsAsync(feedRanges); + + foreach (int maxConcurrency in new[] { 0, 1, 2, 10, 100 }) + { + List rangeEnumerators = feedRanges + .Select(feedRange => new MockEnumerator(new FeedRangeState(feedRange, null), 1)) + .ToList(); + + IEnumerator enumerator = rangeEnumerators.GetEnumerator(); + + MockEnumerator CreateMockEnumerator(FeedRangeState feedRangeState) + { + Assert.IsTrue(enumerator.MoveNext()); + return enumerator.Current; + }; + + CrossPartitionRangePageAsyncEnumerator crossPartitionEnumerator = new CrossPartitionRangePageAsyncEnumerator( + feedRangeProvider: mockFeedRangeProvider.Object, + createPartitionRangeEnumerator: CreateMockEnumerator, + comparer: null, + prefetchPolicy: PrefetchPolicy.PrefetchSinglePage, + maxConcurrency: maxConcurrency, + state: null); + + await crossPartitionEnumerator.MoveNextAsync(NoOpTrace.Singleton, cancellationToken: default); + + if (maxConcurrency <= 1) + { + Assert.AreEqual(1, rangeEnumerators.First().InvocationCount); + Assert.IsTrue(rangeEnumerators.Skip(1).All(x => x.InvocationCount == 0)); + } + else + { + Assert.IsTrue(rangeEnumerators.All(x => x.InvocationCount == 1)); + } + } + } + + private class MockEnumerator : PartitionRangePageAsyncEnumerator + { + private static readonly IReadOnlyDictionary EmptyHeaders = new Dictionary(); + + private static readonly Stream EmptyStream = new MemoryStream(Encoding.UTF8.GetBytes("{\"Documents\": [], \"_count\": 0, \"_rid\": \"asdf\"}")); + + private readonly int pageCount; + + public int InvocationCount { get; private set; } + + public MockEnumerator(FeedRangeState feedRangeState, int pageCount) + : base(feedRangeState) + { + this.pageCount = pageCount; + } + + public override ValueTask DisposeAsync() + { + return default; + } + + protected override Task> GetNextPageAsync(ITrace trace, CancellationToken cancellationToken) + { + if (this.InvocationCount >= this.pageCount) + { + return Task.FromResult(TryCatch.FromException(new InvalidOperationException( + "Trying to move next on an enumerator that is finished"))); + } + + ++this.InvocationCount; + + ReadFeedState state = null; + if (this.InvocationCount < this.pageCount) + { + CosmosElement continuationToken = CosmosString.Create("asdf"); + state = new ReadFeedContinuationState(continuationToken); + } + + return Task.FromResult(TryCatch.FromResult( + new ReadFeedPage(EmptyStream, 2.8, 0, "activityId", EmptyHeaders, state))); + } + } + private class EnumeratorThatSplits : PartitionRangePageAsyncEnumerator { private readonly bool throwError;