diff --git a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs index ee56e0d1a18d15..3a7d40b7d0489b 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs +++ b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs @@ -480,18 +480,10 @@ public void Complete() } else { - // Fault with all of the received exceptions, but filter out those due to inner cancellation, - // as they're effectively an implementation detail and stem from the original exception. - Debug.Assert(_exceptions.Count > 0, "If _exceptions was created, it should have also been populated."); - for (int i = 0; i < _exceptions.Count; i++) - { - if (_exceptions[i] is OperationCanceledException oce && oce.CancellationToken == Cancellation.Token) - { - _exceptions[i] = null!; - } - } - _exceptions.RemoveAll(e => e is null); - Debug.Assert(_exceptions.Count > 0, "Since external cancellation wasn't requested, there should have been a non-cancellation exception that triggered internal cancellation."); + // Fail the task with the resulting exceptions. The first should be the initial + // exception that triggered the operation to shut down. The others, if any, may + // include cancellation exceptions from other concurrent operations being canceled + // in response to the primary exception. taskSet = TrySetException(_exceptions); } diff --git a/src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs b/src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs index 91747d3828b622..97ac99bc159600 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs +++ b/src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs @@ -618,6 +618,64 @@ static async IAsyncEnumerable Iterate() Assert.True(t.IsCanceled); } + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(false)] + [InlineData(true)] + public async Task Cancellation_FaultsForOceForNonCancellation(bool internalToken) + { + static async IAsyncEnumerable Iterate() + { + int counter = 0; + while (true) + { + await Task.Yield(); + yield return counter++; + } + } + + var cts = new CancellationTokenSource(); + + Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { CancellationToken = cts.Token }, (item, cancellationToken) => + { + throw new OperationCanceledException(internalToken ? cancellationToken : cts.Token); + }); + + await Assert.ThrowsAnyAsync(() => t); + Assert.True(t.IsFaulted); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(0, 4)] + [InlineData(1, 4)] + [InlineData(2, 4)] + [InlineData(3, 4)] + [InlineData(4, 4)] + public async Task Cancellation_InternalCancellationExceptionsArentFilteredOut(int numThrowingNonCanceledOce, int total) + { + var cts = new CancellationTokenSource(); + + var barrier = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + int remainingCount = total; + + Task t = Parallel.ForEachAsync(Enumerable.Range(0, total), new ParallelOptions { CancellationToken = cts.Token, MaxDegreeOfParallelism = total }, async (item, cancellationToken) => + { + // Wait for all operations to be started + if (Interlocked.Decrement(ref remainingCount) == 0) + { + barrier.SetResult(); + } + await barrier.Task; + + throw item < numThrowingNonCanceledOce ? + new OperationCanceledException(cancellationToken) : + throw new FormatException(); + }); + + await Assert.ThrowsAnyAsync(() => t); + Assert.Equal(total, t.Exception.InnerExceptions.Count); + Assert.Equal(numThrowingNonCanceledOce, t.Exception.InnerExceptions.Count(e => e is OperationCanceledException)); + } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] public void Exception_FromGetEnumerator_Sync() { @@ -672,7 +730,6 @@ static IEnumerable Iterate() Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default); await Assert.ThrowsAsync(() => t); Assert.True(t.IsFaulted); - Assert.Equal(1, t.Exception.InnerExceptions.Count); } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] @@ -694,7 +751,6 @@ static async IAsyncEnumerable Iterate() Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default); await Assert.ThrowsAsync(() => t); Assert.True(t.IsFaulted); - Assert.Equal(1, t.Exception.InnerExceptions.Count); } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] @@ -771,7 +827,6 @@ public async Task Exception_FromDispose_Sync() Task t = Parallel.ForEachAsync((IEnumerable)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default); await Assert.ThrowsAsync(() => t); Assert.True(t.IsFaulted); - Assert.Equal(1, t.Exception.InnerExceptions.Count); } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] @@ -780,7 +835,6 @@ public async Task Exception_FromDispose_Async() Task t = Parallel.ForEachAsync((IAsyncEnumerable)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default); await Assert.ThrowsAsync(() => t); Assert.True(t.IsFaulted); - Assert.Equal(1, t.Exception.InnerExceptions.Count); } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]