Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix corner-case handling of cancellation exception in ForEachAsync
If code in Parallel.ForEachAsync throws OperationCanceledExceptions containing the CancellationToken passed to the iteration and that token has _not_ had cancellation requested (so why are they throwing with it) and there are no other exceptions, the ForEachAsync will effectively hang after failing to complete the task returned from it.

The issue stems from how we treat cancellation.  If the user-supplied token hasn't been canceled but we have OperationCanceledExceptions for the token passed into the iteration (the "internal" token), it can only have been canceled because an exception occurred.  We filter out these cancellation exceptions, leaving just the exceptions that are deemed to have caused the failure in the first place.  But the code doesn't currently account for the possibility that the developer is (arguably erroneously) throwing such an OperationCanceledException with the internal cancellation token as that root failure. The fix is to only filter out these OCEs if there are other exceptions besides them.
  • Loading branch information
stephentoub committed Sep 13, 2021
commit f5d5c4e22f4c7027664c1eb9e6019fe439b89dab
Original file line number Diff line number Diff line change
Expand Up @@ -481,17 +481,40 @@ public void Complete()
else
{
// Fault with all of the received exceptions, but filter out those due to inner cancellation,
// as they're effectively an implementation detail and stem from the original exception.
// as they're effectively an implementation detail and stem from the original exception. However,
// if _all_ of the exceptions are OperationCanceledExceptions that contain the internal cancellation
// token, then they were all thrown by the user code explicitly without cancellation having been
// requested (since without the external token being canceled, the only way internal cancellation
// would be triggered is upon receipt of an exception). In such a case, just treat all of them
// as real exceptions.
Debug.Assert(_exceptions.Count > 0, "If _exceptions was created, it should have also been populated.");
for (int i = 0; i < _exceptions.Count; i++)
Debug.Assert(Cancellation.IsCancellationRequested, "Any exception should trigger internal cancellation being requested.");

// Count how many of the exceptions are for internal cancellation
int oceCount = 0;
foreach (Exception e in _exceptions)
{
if (_exceptions[i] is OperationCanceledException oce && oce.CancellationToken == Cancellation.Token)
if (e is OperationCanceledException oce && oce.CancellationToken == Cancellation.Token)
{
_exceptions[i] = null!;
oceCount++;
}
}
_exceptions.RemoveAll(e => e is null);
Debug.Assert(_exceptions.Count > 0, "Since external cancellation wasn't requested, there should have been a non-cancellation exception that triggered internal cancellation.");

// If some but not all were, filter them out.
if (oceCount > 0 && oceCount < _exceptions.Count)
{
for (int i = 0; i < _exceptions.Count; i++)
{
if (_exceptions[i] is OperationCanceledException oce && oce.CancellationToken == Cancellation.Token)
{
_exceptions[i] = null!;
}
}
_exceptions.RemoveAll(e => e is null);
Debug.Assert(_exceptions.Count > 0, "Since external cancellation wasn't requested, there should have been a non-cancellation exception that triggered internal cancellation.");
}

// Fail the task with the resulting exceptions
taskSet = TrySetException(_exceptions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,75 @@ static async IAsyncEnumerable<int> Iterate()
Assert.True(t.IsCanceled);
}

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
[InlineData(false)]
[InlineData(true)]
public async Task Cancellation_FaultsForOceForNonCancellation(bool internalToken)
{
static async IAsyncEnumerable<int> Iterate()
{
int counter = 0;
while (true)
{
await Task.Yield();
yield return counter++;
}
}

var cts = new CancellationTokenSource();

Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { CancellationToken = cts.Token }, (item, cancellationToken) =>
{
throw new OperationCanceledException(internalToken ? cancellationToken : cts.Token);
});

await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
Assert.True(t.IsFaulted);
}

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
[InlineData(0, 4)]
[InlineData(1, 4)]
[InlineData(2, 4)]
[InlineData(3, 4)]
[InlineData(4, 4)]
public async Task Cancellation_InternalCancellationExceptionsFilteredWhenOtherExceptions(int numThrowingNonCanceledOce, int total)
{
var cts = new CancellationTokenSource();

var barrier = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
int remainingCount = total;

Task t = Parallel.ForEachAsync(Enumerable.Range(0, total), new ParallelOptions { CancellationToken = cts.Token, MaxDegreeOfParallelism = total }, async (item, cancellationToken) =>
{
// Wait for all operations to be started
if (Interlocked.Decrement(ref remainingCount) == 0)
{
barrier.SetResult();
}
await barrier.Task;

throw item < numThrowingNonCanceledOce ?
new OperationCanceledException(cancellationToken) :
throw new FormatException();
});

await Assert.ThrowsAnyAsync<Exception>(() => t);
if (numThrowingNonCanceledOce == 0)
{
Assert.Equal(total, t.Exception.InnerExceptions.Count(e => e is FormatException));
}
else if (numThrowingNonCanceledOce == total)
{
Assert.Equal(total, t.Exception.InnerExceptions.Count(e => e is OperationCanceledException));
}
else
{
Assert.Equal(total - numThrowingNonCanceledOce, t.Exception.InnerExceptions.Count);
Assert.All(t.Exception.InnerExceptions, e => Assert.IsType<FormatException>(e));
}
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public void Exception_FromGetEnumerator_Sync()
{
Expand Down