diff --git a/src/System.Net.Http/src/System/Net/Http/HttpContent.cs b/src/System.Net.Http/src/System/Net/Http/HttpContent.cs index b06bf9509f3b..bb08dde0799f 100644 --- a/src/System.Net.Http/src/System/Net/Http/HttpContent.cs +++ b/src/System.Net.Http/src/System/Net/Http/HttpContent.cs @@ -246,25 +246,13 @@ public Task ReadAsStreamAsync() return _contentReadStream != null ? Task.FromResult(_contentReadStream) : - ReadAsStreamAsyncCore(); + ReadAsStreamAsyncCore(CreateContentReadStreamAsync()); } - private Task ReadAsStreamAsyncCore() + private async Task ReadAsStreamAsyncCore(Task createContentStreamTask) { - TaskCompletionSource tcs = new TaskCompletionSource(this); - - CreateContentReadStreamAsync().ContinueWithStandard(tcs, (task, state) => - { - var innerTcs = (TaskCompletionSource)state; - var innerThis = (HttpContent)innerTcs.Task.AsyncState; - if (!HttpUtilities.HandleFaultsAndCancelation(task, innerTcs)) - { - innerThis._contentReadStream = task.Result; - innerTcs.TrySetResult(innerThis._contentReadStream); - } - }); - - return tcs.Task; + _contentReadStream = await createContentStreamTask.ConfigureAwait(false); + return _contentReadStream; } protected abstract Task SerializeToStreamAsync(Stream stream, TransportContext context); @@ -277,7 +265,6 @@ public Task CopyToAsync(Stream stream, TransportContext context) throw new ArgumentNullException(nameof(stream)); } - TaskCompletionSource tcs = new TaskCompletionSource(); try { Task task = null; @@ -292,34 +279,24 @@ public Task CopyToAsync(Stream stream, TransportContext context) CheckTaskNotNull(task); } - // If the copy operation fails, wrap the exception in an HttpRequestException() if appropriate. - task.ContinueWithStandard(tcs, (copyTask, state) => - { - var innerTcs = (TaskCompletionSource)state; - if (copyTask.IsFaulted) - { - innerTcs.TrySetException(GetStreamCopyException(copyTask.Exception.GetBaseException())); - } - else if (copyTask.IsCanceled) - { - innerTcs.TrySetCanceled(); - } - else - { - innerTcs.TrySetResult(null); - } - }); + return CopyToAsyncCore(task); } - catch (IOException e) + catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e)) { - tcs.TrySetException(GetStreamCopyException(e)); + throw GetStreamCopyException(e); } - catch (ObjectDisposedException e) + } + + private static async Task CopyToAsyncCore(Task copyTask) + { + try { - tcs.TrySetException(GetStreamCopyException(e)); + await copyTask.ConfigureAwait(false); + } + catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e)) + { + throw GetStreamCopyException(e); } - - return tcs.Task; } public Task CopyToAsync(Stream stream) @@ -353,84 +330,59 @@ public Task LoadIntoBufferAsync(long maxBufferSize) return Task.CompletedTask; } - TaskCompletionSource tcs = new TaskCompletionSource(); - Exception error = null; MemoryStream tempBuffer = CreateMemoryStream(maxBufferSize, out error); - if (tempBuffer == null) { - // We don't throw in LoadIntoBufferAsync(): set the task as faulted and return the task. - Debug.Assert(error != null); - tcs.TrySetException(error); + // We don't throw in LoadIntoBufferAsync(): return a faulted task. + return Task.FromException(error); } - else + + try { - try - { - Task task = SerializeToStreamAsync(tempBuffer, null); - CheckTaskNotNull(task); + Task task = SerializeToStreamAsync(tempBuffer, null); + CheckTaskNotNull(task); + return LoadIntoBufferAsyncCore(task, tempBuffer); + } + catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e)) + { + return Task.FromException(GetStreamCopyException(e)); + } + // other synchronous exceptions from SerializeToStreamAsync/CheckTaskNotNull will propagate + } - task.ContinueWithStandard(copyTask => - { - try - { - if (copyTask.IsFaulted) - { - tempBuffer.Dispose(); // Cleanup partially filled stream. - tcs.TrySetException(GetStreamCopyException(copyTask.Exception.GetBaseException())); - return; - } - - if (copyTask.IsCanceled) - { - tempBuffer.Dispose(); // Cleanup partially filled stream. - tcs.TrySetCanceled(); - return; - } - - tempBuffer.Seek(0, SeekOrigin.Begin); // Rewind after writing data. - _bufferedContent = tempBuffer; - tcs.TrySetResult(null); - } - catch (Exception e) - { - // Make sure we catch any exception, otherwise the task will catch it and throw in the finalizer. - tcs.TrySetException(e); - if (NetEventSource.Log.IsEnabled()) NetEventSource.Exception(NetEventSource.ComponentType.Http, this, "LoadIntoBufferAsync", e); - } - }); - } - catch (IOException e) - { - tcs.TrySetException(GetStreamCopyException(e)); - } - catch (ObjectDisposedException e) - { - tcs.TrySetException(GetStreamCopyException(e)); - } + private async Task LoadIntoBufferAsyncCore(Task serializeToStreamTask, MemoryStream tempBuffer) + { + try + { + await serializeToStreamTask.ConfigureAwait(false); + } + catch (Exception e) + { + tempBuffer.Dispose(); // Cleanup partially filled stream. + Exception we = GetStreamCopyException(e); + if (we != e) throw we; + throw; } - return tcs.Task; + try + { + tempBuffer.Seek(0, SeekOrigin.Begin); // Rewind after writing data. + _bufferedContent = tempBuffer; + } + catch (Exception e) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Exception(NetEventSource.ComponentType.Http, this, nameof(LoadIntoBufferAsync), e); + throw; + } } protected virtual Task CreateContentReadStreamAsync() { - var tcs = new TaskCompletionSource(this); // By default just buffer the content to a memory stream. Derived classes can override this behavior // if there is a better way to retrieve the content as stream (e.g. byte array/string use a more efficient // way, like wrapping a read-only MemoryStream around the bytes/string) - LoadIntoBufferAsync().ContinueWithStandard(tcs, (task, state) => - { - var innerTcs = (TaskCompletionSource)state; - var innerThis = (HttpContent)innerTcs.Task.AsyncState; - if (!HttpUtilities.HandleFaultsAndCancelation(task, innerTcs)) - { - innerTcs.TrySetResult(innerThis._bufferedContent); - } - }); - - return tcs.Task; + return WaitAndReturnAsync(LoadIntoBufferAsync(), this, s => (Stream)s._bufferedContent); } // Derived types return true if they're able to compute the length. It's OK if derived types return false to @@ -540,6 +492,8 @@ private void CheckTaskNotNull(Task task) } } + private static bool StreamCopyExceptionNeedsWrapping(Exception e) => e is IOException || e is ObjectDisposedException; + private static Exception GetStreamCopyException(Exception originalException) { // HttpContent derived types should throw HttpRequestExceptions if there is an error. However, since the stream @@ -551,12 +505,9 @@ private static Exception GetStreamCopyException(Exception originalException) // don't want to hide such "usage error" exceptions in HttpRequestException. // ObjectDisposedException is also wrapped, since aborting HWR after a request is complete will result in // the response stream being closed. - Exception result = originalException; - if ((result is IOException) || (result is ObjectDisposedException)) - { - result = new HttpRequestException(SR.net_http_content_stream_copy_error, result); - } - return result; + return StreamCopyExceptionNeedsWrapping(originalException) ? + new HttpRequestException(SR.net_http_content_stream_copy_error, originalException) : + originalException; } private static int GetPreambleLength(ArraySegment buffer, Encoding encoding) diff --git a/src/System.Net.Http/tests/FunctionalTests/HttpContentTest.cs b/src/System.Net.Http/tests/FunctionalTests/HttpContentTest.cs index bbea4d761671..6453093013fa 100644 --- a/src/System.Net.Http/tests/FunctionalTests/HttpContentTest.cs +++ b/src/System.Net.Http/tests/FunctionalTests/HttpContentTest.cs @@ -574,7 +574,7 @@ protected override Task SerializeToStreamAsync(Stream stream, TransportContext c throw _customException; } - return Task.Factory.StartNew(() => + return Task.Run(() => { CheckThrow(); return stream.WriteAsync(_mockData, 0, _mockData.Length);