diff --git a/src/libraries/Common/src/System/Net/StreamBuffer.cs b/src/libraries/Common/src/System/Net/StreamBuffer.cs index 6759fcdd8e20b0..32bc0f3f4e45ce 100644 --- a/src/libraries/Common/src/System/Net/StreamBuffer.cs +++ b/src/libraries/Common/src/System/Net/StreamBuffer.cs @@ -192,8 +192,6 @@ public void EndWrite() private (bool wait, int bytesRead) TryReadFromBuffer(Span buffer) { - Debug.Assert(buffer.Length > 0); - Debug.Assert(!Monitor.IsEntered(SyncObject)); lock (SyncObject) { @@ -225,11 +223,6 @@ public void EndWrite() public int Read(Span buffer) { - if (buffer.Length == 0) - { - return 0; - } - (bool wait, int bytesRead) = TryReadFromBuffer(buffer); if (wait) { @@ -246,11 +239,6 @@ public async ValueTask ReadAsync(Memory buffer, CancellationToken can { cancellationToken.ThrowIfCancellationRequested(); - if (buffer.Length == 0) - { - return 0; - } - (bool wait, int bytesRead) = TryReadFromBuffer(buffer.Span); if (wait) { diff --git a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs index 43451d45301128..7dc9519665c588 100644 --- a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs +++ b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs @@ -114,7 +114,7 @@ public static IEnumerable AllSeekModesAndValue(object value) => from mode in Enum.GetValues() select new object[] { mode, value }; - protected async Task ReadAsync(ReadWriteMode mode, Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + public static async Task ReadAsync(ReadWriteMode mode, Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) { if (mode == ReadWriteMode.SyncByte) { diff --git a/src/libraries/Common/tests/Tests/System/IO/ConnectedStreamsTests.cs b/src/libraries/Common/tests/Tests/System/IO/ConnectedStreamsTests.cs index 0f700bf68f0f71..0d69a89de890c3 100644 --- a/src/libraries/Common/tests/Tests/System/IO/ConnectedStreamsTests.cs +++ b/src/libraries/Common/tests/Tests/System/IO/ConnectedStreamsTests.cs @@ -9,6 +9,7 @@ public class UnidirectionalConnectedStreamsTests : ConnectedStreamConformanceTes { protected override int BufferedSize => StreamBuffer.DefaultMaxBufferSize; protected override bool FlushRequiredToWriteData => false; + protected override bool BlocksOnZeroByteReads => true; protected override Task CreateConnectedStreamsAsync() => Task.FromResult(ConnectedStreams.CreateUnidirectional()); @@ -18,6 +19,7 @@ public class BidirectionalConnectedStreamsTests : ConnectedStreamConformanceTest { protected override int BufferedSize => StreamBuffer.DefaultMaxBufferSize; protected override bool FlushRequiredToWriteData => false; + protected override bool BlocksOnZeroByteReads => true; protected override Task CreateConnectedStreamsAsync() => Task.FromResult(ConnectedStreams.CreateBidirectional()); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs index ff87d5a89f9432..e4ecc124a1e469 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs @@ -37,17 +37,27 @@ public ChunkedEncodingReadStream(HttpConnection connection, HttpResponseMessage public override int Read(Span buffer) { - if (_connection == null || buffer.Length == 0) + if (_connection == null) { - // Response body fully consumed or the caller didn't ask for any data. + // Response body fully consumed return 0; } - // Try to consume from data we already have in the buffer. - int bytesRead = ReadChunksFromConnectionBuffer(buffer, cancellationRegistration: default); - if (bytesRead > 0) + if (buffer.Length == 0) + { + if (PeekChunkFromConnectionBuffer()) + { + return 0; + } + } + else { - return bytesRead; + // Try to consume from data we already have in the buffer. + int bytesRead = ReadChunksFromConnectionBuffer(buffer, cancellationRegistration: default); + if (bytesRead > 0) + { + return bytesRead; + } } // Nothing available to consume. Fall back to I/O. @@ -68,7 +78,8 @@ public override int Read(Span buffer) // as the connection buffer. That avoids an unnecessary copy while still reading // the maximum amount we'd otherwise read at a time. Debug.Assert(_connection.RemainingBuffer.Length == 0); - bytesRead = _connection.Read(buffer.Slice(0, (int)Math.Min((ulong)buffer.Length, _chunkBytesRemaining))); + Debug.Assert(buffer.Length != 0); + int bytesRead = _connection.Read(buffer.Slice(0, (int)Math.Min((ulong)buffer.Length, _chunkBytesRemaining))); if (bytesRead == 0) { throw new IOException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, _chunkBytesRemaining)); @@ -81,15 +92,35 @@ public override int Read(Span buffer) return bytesRead; } + if (buffer.Length == 0) + { + // User requested a zero-byte read, and we have no data available in the buffer for processing. + // This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read + // for reduced memory consumption when data is not immediately available. + // So, we will issue our own zero-byte read against the underlying stream to allow it to make use of + // optimizations, such as deferring buffer allocation until data is actually available. + _connection.Read(buffer); + } + // We're only here if we need more data to make forward progress. _connection.Fill(); // Now that we have more, see if we can get any response data, and if // we can we're done. - int bytesCopied = ReadChunksFromConnectionBuffer(buffer, cancellationRegistration: default); - if (bytesCopied > 0) + if (buffer.Length == 0) { - return bytesCopied; + if (PeekChunkFromConnectionBuffer()) + { + return 0; + } + } + else + { + int bytesCopied = ReadChunksFromConnectionBuffer(buffer, cancellationRegistration: default); + if (bytesCopied > 0) + { + return bytesCopied; + } } } } @@ -102,17 +133,27 @@ public override ValueTask ReadAsync(Memory buffer, CancellationToken return ValueTask.FromCanceled(cancellationToken); } - if (_connection == null || buffer.Length == 0) + if (_connection == null) { - // Response body fully consumed or the caller didn't ask for any data. + // Response body fully consumed return new ValueTask(0); } - // Try to consume from data we already have in the buffer. - int bytesRead = ReadChunksFromConnectionBuffer(buffer.Span, cancellationRegistration: default); - if (bytesRead > 0) + if (buffer.Length == 0) { - return new ValueTask(bytesRead); + if (PeekChunkFromConnectionBuffer()) + { + return new ValueTask(0); + } + } + else + { + // Try to consume from data we already have in the buffer. + int bytesRead = ReadChunksFromConnectionBuffer(buffer.Span, cancellationRegistration: default); + if (bytesRead > 0) + { + return new ValueTask(bytesRead); + } } // We may have just consumed the remainder of the response (with no actual data @@ -132,7 +173,6 @@ private async ValueTask ReadAsyncCore(Memory buffer, CancellationToke // Should only be called if ReadChunksFromConnectionBuffer returned 0. Debug.Assert(_connection != null); - Debug.Assert(buffer.Length > 0); CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); try @@ -154,6 +194,7 @@ private async ValueTask ReadAsyncCore(Memory buffer, CancellationToke // as the connection buffer. That avoids an unnecessary copy while still reading // the maximum amount we'd otherwise read at a time. Debug.Assert(_connection.RemainingBuffer.Length == 0); + Debug.Assert(buffer.Length != 0); int bytesRead = await _connection.ReadAsync(buffer.Slice(0, (int)Math.Min((ulong)buffer.Length, _chunkBytesRemaining))).ConfigureAwait(false); if (bytesRead == 0) { @@ -167,15 +208,35 @@ private async ValueTask ReadAsyncCore(Memory buffer, CancellationToke return bytesRead; } + if (buffer.Length == 0) + { + // User requested a zero-byte read, and we have no data available in the buffer for processing. + // This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read + // for reduced memory consumption when data is not immediately available. + // So, we will issue our own zero-byte read against the underlying stream to allow it to make use of + // optimizations, such as deferring buffer allocation until data is actually available. + await _connection.ReadAsync(buffer).ConfigureAwait(false); + } + // We're only here if we need more data to make forward progress. await _connection.FillAsync(async: true).ConfigureAwait(false); // Now that we have more, see if we can get any response data, and if // we can we're done. - int bytesCopied = ReadChunksFromConnectionBuffer(buffer.Span, ctr); - if (bytesCopied > 0) + if (buffer.Length == 0) { - return bytesCopied; + if (PeekChunkFromConnectionBuffer()) + { + return 0; + } + } + else + { + int bytesCopied = ReadChunksFromConnectionBuffer(buffer.Span, ctr); + if (bytesCopied > 0) + { + return bytesCopied; + } } } } @@ -208,8 +269,7 @@ private async Task CopyToAsyncCore(Stream destination, CancellationToken cancell { while (true) { - ReadOnlyMemory bytesRead = ReadChunkFromConnectionBuffer(int.MaxValue, ctr); - if (bytesRead.Length == 0) + if (ReadChunkFromConnectionBuffer(int.MaxValue, ctr) is not ReadOnlyMemory bytesRead || bytesRead.Length == 0) { break; } @@ -235,18 +295,23 @@ private async Task CopyToAsyncCore(Stream destination, CancellationToken cancell } } + private bool PeekChunkFromConnectionBuffer() + { + return ReadChunkFromConnectionBuffer(maxBytesToRead: 0, cancellationRegistration: default).HasValue; + } + private int ReadChunksFromConnectionBuffer(Span buffer, CancellationTokenRegistration cancellationRegistration) { + Debug.Assert(buffer.Length > 0); int totalBytesRead = 0; while (buffer.Length > 0) { - ReadOnlyMemory bytesRead = ReadChunkFromConnectionBuffer(buffer.Length, cancellationRegistration); - Debug.Assert(bytesRead.Length <= buffer.Length); - if (bytesRead.Length == 0) + if (ReadChunkFromConnectionBuffer(buffer.Length, cancellationRegistration) is not ReadOnlyMemory bytesRead || bytesRead.Length == 0) { break; } + Debug.Assert(bytesRead.Length <= buffer.Length); totalBytesRead += bytesRead.Length; bytesRead.Span.CopyTo(buffer); buffer = buffer.Slice(bytesRead.Length); @@ -254,9 +319,9 @@ private int ReadChunksFromConnectionBuffer(Span buffer, CancellationTokenR return totalBytesRead; } - private ReadOnlyMemory ReadChunkFromConnectionBuffer(int maxBytesToRead, CancellationTokenRegistration cancellationRegistration) + private ReadOnlyMemory? ReadChunkFromConnectionBuffer(int maxBytesToRead, CancellationTokenRegistration cancellationRegistration) { - Debug.Assert(maxBytesToRead > 0 && _connection != null); + Debug.Assert(_connection != null); try { @@ -310,7 +375,7 @@ private ReadOnlyMemory ReadChunkFromConnectionBuffer(int maxBytesToRead, C } int bytesToConsume = Math.Min(maxBytesToRead, (int)Math.Min((ulong)connectionBuffer.Length, _chunkBytesRemaining)); - Debug.Assert(bytesToConsume > 0); + Debug.Assert(bytesToConsume > 0 || maxBytesToRead == 0); _connection.ConsumeFromRemainingBuffer(bytesToConsume); _chunkBytesRemaining -= (ulong)bytesToConsume; @@ -441,8 +506,7 @@ public override async ValueTask DrainAsync(int maxDrainBytes) drainedBytes += _connection.RemainingBuffer.Length; while (true) { - ReadOnlyMemory bytesRead = ReadChunkFromConnectionBuffer(int.MaxValue, ctr); - if (bytesRead.Length == 0) + if (ReadChunkFromConnectionBuffer(int.MaxValue, ctr) is not ReadOnlyMemory bytesRead || bytesRead.Length == 0) { break; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs index 0fd011037b9310..7bddf399572202 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs @@ -18,14 +18,14 @@ public ConnectionCloseReadStream(HttpConnection connection) : base(connection) public override int Read(Span buffer) { HttpConnection? connection = _connection; - if (connection == null || buffer.Length == 0) + if (connection == null) { - // Response body fully consumed or the caller didn't ask for any data + // Response body fully consumed return 0; } int bytesRead = connection.Read(buffer); - if (bytesRead == 0) + if (bytesRead == 0 && buffer.Length != 0) { // We cannot reuse this connection, so close it. _connection = null; @@ -40,9 +40,9 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation CancellationHelper.ThrowIfCancellationRequested(cancellationToken); HttpConnection? connection = _connection; - if (connection == null || buffer.Length == 0) + if (connection == null) { - // Response body fully consumed or the caller didn't ask for any data + // Response body fully consumed return 0; } @@ -69,7 +69,7 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } } - if (bytesRead == 0) + if (bytesRead == 0 && buffer.Length != 0) { // If cancellation is requested and tears down the connection, it could cause the read // to return 0, which would otherwise signal the end of the data, but that would lead diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs index 786f285a93c75e..97f9ddf6b2b5ce 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs @@ -22,9 +22,9 @@ public ContentLengthReadStream(HttpConnection connection, ulong contentLength) : public override int Read(Span buffer) { - if (_connection == null || buffer.Length == 0) + if (_connection == null) { - // Response body fully consumed or the caller didn't ask for any data. + // Response body fully consumed return 0; } @@ -35,7 +35,7 @@ public override int Read(Span buffer) } int bytesRead = _connection.Read(buffer); - if (bytesRead <= 0) + if (bytesRead <= 0 && buffer.Length != 0) { // Unexpected end of response stream. throw new IOException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, _contentBytesRemaining)); @@ -58,9 +58,9 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation { CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - if (_connection == null || buffer.Length == 0) + if (_connection == null) { - // Response body fully consumed or the caller didn't ask for any data + // Response body fully consumed return 0; } @@ -94,7 +94,7 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } } - if (bytesRead <= 0) + if (bytesRead == 0 && buffer.Length != 0) { // A cancellation request may have caused the EOF. CancellationHelper.ThrowIfCancellationRequested(cancellationToken); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 5c3fde8ea8fc62..b5d5ffa8b84698 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -1040,8 +1040,6 @@ public async Task ReadResponseHeadersAsync(CancellationToken cancellationToken) private (bool wait, int bytesRead) TryReadFromBuffer(Span buffer, bool partOfSyncRead = false) { - Debug.Assert(buffer.Length > 0); - Debug.Assert(!Monitor.IsEntered(SyncObject)); lock (SyncObject) { @@ -1073,11 +1071,6 @@ public async Task ReadResponseHeadersAsync(CancellationToken cancellationToken) public int ReadData(Span buffer, HttpResponseMessage responseMessage) { - if (buffer.Length == 0) - { - return 0; - } - (bool wait, int bytesRead) = TryReadFromBuffer(buffer, partOfSyncRead: true); if (wait) { @@ -1092,7 +1085,7 @@ public int ReadData(Span buffer, HttpResponseMessage responseMessage) { _windowManager.AdjustWindow(bytesRead, this); } - else + else if (buffer.Length != 0) { // We've hit EOF. Pull in from the Http2Stream any trailers that were temporarily stored there. MoveTrailersToResponseMessage(responseMessage); @@ -1103,11 +1096,6 @@ public int ReadData(Span buffer, HttpResponseMessage responseMessage) public async ValueTask ReadDataAsync(Memory buffer, HttpResponseMessage responseMessage, CancellationToken cancellationToken) { - if (buffer.Length == 0) - { - return 0; - } - (bool wait, int bytesRead) = TryReadFromBuffer(buffer.Span); if (wait) { @@ -1121,7 +1109,7 @@ public async ValueTask ReadDataAsync(Memory buffer, HttpResponseMessa { _windowManager.AdjustWindow(bytesRead, this); } - else + else if (buffer.Length != 0) { // We've hit EOF. Pull in from the Http2Stream any trailers that were temporarily stored there. MoveTrailersToResponseMessage(responseMessage); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 64f9833e0aa237..b76cf878ae9571 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -1050,7 +1050,7 @@ private int ReadResponseContent(HttpResponseMessage response, Span buffer) { int totalBytesRead = 0; - while (buffer.Length != 0) + do { // Sync over async here -- QUIC implementation does it per-I/O already; this is at least more coarse-grained. if (_responseDataPayloadRemaining <= 0 && !ReadNextDataFrameAsync(response, CancellationToken.None).AsTask().GetAwaiter().GetResult()) @@ -1086,7 +1086,7 @@ private int ReadResponseContent(HttpResponseMessage response, Span buffer) int copyLen = (int)Math.Min(buffer.Length, _responseDataPayloadRemaining); int bytesRead = _stream.Read(buffer.Slice(0, copyLen)); - if (bytesRead == 0) + if (bytesRead == 0 && buffer.Length != 0) { throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, _responseDataPayloadRemaining)); } @@ -1100,6 +1100,7 @@ private int ReadResponseContent(HttpResponseMessage response, Span buffer) break; } } + while (buffer.Length != 0); return totalBytesRead; } @@ -1120,7 +1121,7 @@ private async ValueTask ReadResponseContentAsync(HttpResponseMessage respon { int totalBytesRead = 0; - while (buffer.Length != 0) + do { if (_responseDataPayloadRemaining <= 0 && !await ReadNextDataFrameAsync(response, cancellationToken).ConfigureAwait(false)) { @@ -1155,7 +1156,7 @@ private async ValueTask ReadResponseContentAsync(HttpResponseMessage respon int copyLen = (int)Math.Min(buffer.Length, _responseDataPayloadRemaining); int bytesRead = await _stream.ReadAsync(buffer.Slice(0, copyLen), cancellationToken).ConfigureAwait(false); - if (bytesRead == 0) + if (bytesRead == 0 && buffer.Length != 0) { throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, _responseDataPayloadRemaining)); } @@ -1169,6 +1170,7 @@ private async ValueTask ReadResponseContentAsync(HttpResponseMessage respon break; } } + while (buffer.Length != 0); return totalBytesRead; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index c7814464a87f66..d75ed2779f0428 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -1708,8 +1708,6 @@ private async ValueTask ReadAsync(Memory destination) private int ReadBuffered(Span destination) { // This is called when reading the response body. - Debug.Assert(destination.Length != 0); - int remaining = _readLength - _readOffset; if (remaining > 0) { @@ -1731,7 +1729,7 @@ private int ReadBuffered(Span destination) // Do a buffered read directly against the underlying stream. Debug.Assert(_readAheadTask == null, "Read ahead task should have been consumed as part of the headers."); - int bytesRead = _stream.Read(_readBuffer, 0, _readBuffer.Length); + int bytesRead = _stream.Read(_readBuffer, 0, destination.Length == 0 ? 0 : _readBuffer.Length); if (NetEventSource.Log.IsEnabled()) Trace($"Received {bytesRead} bytes."); _readLength = bytesRead; @@ -1747,7 +1745,9 @@ private ValueTask ReadBufferedAsync(Memory destination) // If the caller provided buffer, and thus the amount of data desired to be read, // is larger than the internal buffer, there's no point going through the internal // buffer, so just do an unbuffered read. - return destination.Length >= _readBuffer.Length ? + // Also avoid avoid using the internal buffer if the user requested a zero-byte read to allow + // underlying streams to efficiently handle such a read (e.g. SslStream defering buffer allocation). + return destination.Length >= _readBuffer.Length || destination.Length == 0 ? ReadAsync(destination) : ReadBufferedAsyncCore(destination); } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs index 44376b01a95b23..7a45db55211c12 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs @@ -23,14 +23,14 @@ public RawConnectionStream(HttpConnection connection) : base(connection) public override int Read(Span buffer) { HttpConnection? connection = _connection; - if (connection == null || buffer.Length == 0) + if (connection == null) { // Response body fully consumed or the caller didn't ask for any data return 0; } int bytesRead = connection.ReadBuffered(buffer); - if (bytesRead == 0) + if (bytesRead == 0 && buffer.Length != 0) { // We cannot reuse this connection, so close it. _connection = null; @@ -45,9 +45,9 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation CancellationHelper.ThrowIfCancellationRequested(cancellationToken); HttpConnection? connection = _connection; - if (connection == null || buffer.Length == 0) + if (connection == null) { - // Response body fully consumed or the caller didn't ask for any data + // Response body fully consumed return 0; } @@ -74,7 +74,7 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } } - if (bytesRead == 0) + if (bytesRead == 0 && buffer.Length != 0) { // A cancellation request may have caused the EOF. CancellationHelper.ThrowIfCancellationRequested(cancellationToken); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamConformanceTests.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamConformanceTests.cs index 37b93847ccccbc..f7f28da5d192a3 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamConformanceTests.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamConformanceTests.cs @@ -80,6 +80,7 @@ public abstract class ResponseConnectedStreamConformanceTests : ConnectedStreamC { protected override Type UnsupportedConcurrentExceptionType => null; protected override bool UsableAfterCanceledReads => false; + protected override bool BlocksOnZeroByteReads => true; protected abstract string GetResponseHeaders(); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs new file mode 100644 index 00000000000000..77acc77b344eff --- /dev/null +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs @@ -0,0 +1,317 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.IO.Tests; +using System.Linq; +using System.Net.Quic; +using System.Net.Quic.Implementations; +using System.Net.Security; +using System.Net.Test.Common; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.Http.Functional.Tests +{ + public sealed class Http1CloseResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase + { + protected override string GetResponseHeaders() => "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n"; + + protected override async Task WriteAsync(Stream stream, byte[] data) => await stream.WriteAsync(data); + } + + public sealed class Http1RawResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase + { + protected override string GetResponseHeaders() => "HTTP/1.1 101 Switching Protocols\r\n\r\n"; + + protected override async Task WriteAsync(Stream stream, byte[] data) => await stream.WriteAsync(data); + } + + public sealed class Http1ContentLengthResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase + { + protected override string GetResponseHeaders() => "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + + protected override async Task WriteAsync(Stream stream, byte[] data) => await stream.WriteAsync(data); + } + + public sealed class Http1SingleChunkResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase + { + protected override string GetResponseHeaders() => "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + + protected override async Task WriteAsync(Stream stream, byte[] data) + { + await stream.WriteAsync(Encoding.ASCII.GetBytes($"{data.Length:X}\r\n")); + await stream.WriteAsync(data); + await stream.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); + } + } + + public sealed class Http1MultiChunkResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase + { + protected override string GetResponseHeaders() => "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + + protected override async Task WriteAsync(Stream stream, byte[] data) + { + for (int i = 0; i < data.Length; i++) + { + await stream.WriteAsync(Encoding.ASCII.GetBytes($"1\r\n")); + await stream.WriteAsync(data.AsMemory(i, 1)); + await stream.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); + } + } + } + + [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public abstract class Http1ResponseStreamZeroByteReadTestBase + { + protected abstract string GetResponseHeaders(); + + protected abstract Task WriteAsync(Stream stream, byte[] data); + + public static IEnumerable ZeroByteRead_IssuesZeroByteReadOnUnderlyingStream_MemberData() => + from readMode in Enum.GetValues() + .Where(mode => mode != StreamConformanceTests.ReadWriteMode.SyncByte) // Can't test zero-byte reads with ReadByte + from useSsl in new[] { true, false } + select new object[] { readMode, useSsl }; + + [Theory] + [MemberData(nameof(ZeroByteRead_IssuesZeroByteReadOnUnderlyingStream_MemberData))] + public async Task ZeroByteRead_IssuesZeroByteReadOnUnderlyingStream(StreamConformanceTests.ReadWriteMode readMode, bool useSsl) + { + (Stream httpConnection, Stream server) = ConnectedStreams.CreateBidirectional(4096, int.MaxValue); + try + { + var sawZeroByteRead = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + httpConnection = new ReadInterceptStream(httpConnection, read => + { + if (read == 0) + { + sawZeroByteRead.TrySetResult(); + } + }); + + using var handler = new SocketsHttpHandler + { + ConnectCallback = delegate { return ValueTask.FromResult(httpConnection); } + }; + handler.SslOptions.RemoteCertificateValidationCallback = delegate { return true; }; + + using var client = new HttpClient(handler); + + Task clientTask = client.GetAsync($"http{(useSsl ? "s" : "")}://doesntmatter", HttpCompletionOption.ResponseHeadersRead); + + if (useSsl) + { + var sslStream = new SslStream(server, false, delegate { return true; }); + server = sslStream; + + using (X509Certificate2 cert = Test.Common.Configuration.Certificates.GetServerCertificate()) + { + await ((SslStream)server).AuthenticateAsServerAsync( + cert, + clientCertificateRequired: true, + enabledSslProtocols: SslProtocols.Tls12, + checkCertificateRevocation: false).WaitAsync(TimeSpan.FromSeconds(10)); + } + } + + await ResponseConnectedStreamConformanceTests.ReadHeadersAsync(server).WaitAsync(TimeSpan.FromSeconds(10)); + await server.WriteAsync(Encoding.ASCII.GetBytes(GetResponseHeaders())); + + using HttpResponseMessage response = await clientTask.WaitAsync(TimeSpan.FromSeconds(10)); + using Stream clientStream = response.Content.ReadAsStream(); + Assert.False(sawZeroByteRead.Task.IsCompleted); + + Task zeroByteReadTask = Task.Run(() => StreamConformanceTests.ReadAsync(readMode, clientStream, Array.Empty(), 0, 0, CancellationToken.None) ); + Assert.False(zeroByteReadTask.IsCompleted); + + // The zero-byte read should block until data is actually available + await sawZeroByteRead.Task.WaitAsync(TimeSpan.FromSeconds(10)); + Assert.False(zeroByteReadTask.IsCompleted); + + byte[] data = Encoding.UTF8.GetBytes("Hello"); + await WriteAsync(server, data); + await server.FlushAsync(); + + Assert.Equal(0, await zeroByteReadTask.WaitAsync(TimeSpan.FromSeconds(10))); + + // Now that data is available, a zero-byte read should complete synchronously + zeroByteReadTask = StreamConformanceTests.ReadAsync(readMode, clientStream, Array.Empty(), 0, 0, CancellationToken.None); + Assert.True(zeroByteReadTask.IsCompleted); + Assert.Equal(0, await zeroByteReadTask); + + var readBuffer = new byte[10]; + int read = 0; + while (read < data.Length) + { + read += await StreamConformanceTests.ReadAsync(readMode, clientStream, readBuffer, read, readBuffer.Length - read, CancellationToken.None).WaitAsync(TimeSpan.FromSeconds(10)); + } + + Assert.Equal(data.Length, read); + Assert.Equal(data, readBuffer.AsSpan(0, read).ToArray()); + } + finally + { + httpConnection.Dispose(); + server.Dispose(); + } + } + + private sealed class ReadInterceptStream : DelegatingStream + { + private readonly Action _readCallback; + + public ReadInterceptStream(Stream innerStream, Action readCallback) + : base(innerStream) + { + _readCallback = readCallback; + } + + public override int Read(Span buffer) + { + _readCallback(buffer.Length); + return base.Read(buffer); + } + + public override int Read(byte[] buffer, int offset, int count) + { + _readCallback(count); + return base.Read(buffer, offset, count); + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + _readCallback(buffer.Length); + return base.ReadAsync(buffer, cancellationToken); + } + } + } + + public sealed class Http1ResponseStreamZeroByteReadTest : ResponseStreamZeroByteReadTestBase + { + public Http1ResponseStreamZeroByteReadTest(ITestOutputHelper output) : base(output) { } + + protected override Version UseVersion => HttpVersion.Version11; + } + + [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public sealed class Http2ResponseStreamZeroByteReadTest : ResponseStreamZeroByteReadTestBase + { + public Http2ResponseStreamZeroByteReadTest(ITestOutputHelper output) : base(output) { } + + protected override Version UseVersion => HttpVersion.Version20; + } + + [ConditionalClass(typeof(HttpClientHandlerTestBase), nameof(IsMsQuicSupported))] + public sealed class Http3ResponseStreamZeroByteReadTest_MsQuic : ResponseStreamZeroByteReadTestBase + { + public Http3ResponseStreamZeroByteReadTest_MsQuic(ITestOutputHelper output) : base(output) { } + + protected override Version UseVersion => HttpVersion.Version30; + + protected override QuicImplementationProvider UseQuicImplementationProvider => QuicImplementationProviders.MsQuic; + } + + [ConditionalClass(typeof(HttpClientHandlerTestBase), nameof(IsMockQuicSupported))] + public sealed class Http3ResponseStreamZeroByteReadTest_Mock : ResponseStreamZeroByteReadTestBase + { + public Http3ResponseStreamZeroByteReadTest_Mock(ITestOutputHelper output) : base(output) { } + + protected override Version UseVersion => HttpVersion.Version30; + + protected override QuicImplementationProvider UseQuicImplementationProvider => QuicImplementationProviders.Mock; + } + + [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public abstract class ResponseStreamZeroByteReadTestBase : HttpClientHandlerTestBase + { + public ResponseStreamZeroByteReadTestBase(ITestOutputHelper output) : base(output) { } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ZeroByteRead_BlocksUntilDataIsAvailable(bool async) + { + var zeroByteReadIssued = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await LoopbackServerFactory.CreateClientAndServerAsync(async uri => + { + HttpRequestMessage request = CreateRequest(HttpMethod.Get, uri, UseVersion, exactVersion: true); + + using HttpClient client = CreateHttpClient(); + using HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + using Stream responseStream = await response.Content.ReadAsStreamAsync(); + + var responseBuffer = new byte[1]; + Assert.Equal(1, await ReadAsync(async, responseStream, responseBuffer)); + Assert.Equal(42, responseBuffer[0]); + + Task zeroByteReadTask = ReadAsync(async, responseStream, Array.Empty()); + Assert.False(zeroByteReadTask.IsCompleted); + + zeroByteReadIssued.SetResult(); + Assert.Equal(0, await zeroByteReadTask); + Assert.Equal(0, await ReadAsync(async, responseStream, Array.Empty())); + + Assert.Equal(1, await ReadAsync(async, responseStream, responseBuffer)); + Assert.Equal(1, responseBuffer[0]); + + Assert.Equal(0, await ReadAsync(async, responseStream, Array.Empty())); + + Assert.Equal(1, await ReadAsync(async, responseStream, responseBuffer)); + Assert.Equal(2, responseBuffer[0]); + + zeroByteReadTask = ReadAsync(async, responseStream, Array.Empty()); + Assert.False(zeroByteReadTask.IsCompleted); + + zeroByteReadIssued.SetResult(); + Assert.Equal(0, await zeroByteReadTask); + Assert.Equal(0, await ReadAsync(async, responseStream, Array.Empty())); + + Assert.Equal(1, await ReadAsync(async, responseStream, responseBuffer)); + Assert.Equal(3, responseBuffer[0]); + + Assert.Equal(0, await ReadAsync(async, responseStream, responseBuffer)); + }, + async server => + { + await server.AcceptConnectionAsync(async connection => + { + await connection.ReadRequestDataAsync(); + + await connection.SendResponseAsync(headers: new[] { new HttpHeaderData("Content-Length", "4") }, isFinal: false); + + await connection.SendResponseBodyAsync(new byte[] { 42 }, isFinal: false); + + await zeroByteReadIssued.Task; + zeroByteReadIssued = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await connection.SendResponseBodyAsync(new byte[] { 1, 2 }, isFinal: false); + + await zeroByteReadIssued.Task; + + await connection.SendResponseBodyAsync(new byte[] { 3 }, isFinal: true); + }); + }); + + static Task ReadAsync(bool async, Stream stream, byte[] buffer) + { + if (async) + { + return stream.ReadAsync(buffer).AsTask(); + } + else + { + return Task.Run(() => stream.Read(buffer)); + } + } + } + } +} diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj b/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj index 29bbd6f8f083e9..343267df792ed5 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj @@ -186,6 +186,7 @@ + QuicImplementationProviders.Mock; + protected override bool BlocksOnZeroByteReads => true; } [ConditionalClass(typeof(QuicTestBase), nameof(QuicTestBase.IsSupported))]