diff --git a/sdk/clientcore/http-netty4/pom.xml b/sdk/clientcore/http-netty4/pom.xml index 6ba17179a4ba..d3306077c811 100644 --- a/sdk/clientcore/http-netty4/pom.xml +++ b/sdk/clientcore/http-netty4/pom.xml @@ -205,6 +205,12 @@ 2.5.2 test + + org.mockito + mockito-core + 4.11.0 + test + diff --git a/sdk/clientcore/http-netty4/spotbugs-exclude.xml b/sdk/clientcore/http-netty4/spotbugs-exclude.xml index 9b8d71f16243..5eee4b5d1967 100644 --- a/sdk/clientcore/http-netty4/spotbugs-exclude.xml +++ b/sdk/clientcore/http-netty4/spotbugs-exclude.xml @@ -9,6 +9,7 @@ + @@ -61,4 +62,18 @@ + + + + + + + + + + + + + + diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClient.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClient.java index 90fbafd96b16..66dea873d2c1 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClient.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClient.java @@ -7,6 +7,7 @@ import io.clientcore.core.http.client.HttpProtocolVersion; import io.clientcore.core.http.models.HttpHeaderName; import io.clientcore.core.http.models.HttpRequest; +import io.clientcore.core.http.models.ProxyOptions; import io.clientcore.core.http.models.Response; import io.clientcore.core.http.models.ServerSentEventListener; import io.clientcore.core.instrumentation.logging.ClientLogger; @@ -20,8 +21,10 @@ import io.clientcore.http.netty4.implementation.ChannelInitializationProxyHandler; import io.clientcore.http.netty4.implementation.Netty4AlpnHandler; import io.clientcore.http.netty4.implementation.Netty4ChannelBinaryData; -import io.clientcore.http.netty4.implementation.Netty4EagerConsumeChannelHandler; -import io.clientcore.http.netty4.implementation.Netty4HandlerNames; +import io.clientcore.http.netty4.implementation.Netty4ConnectionPool; +import io.clientcore.http.netty4.implementation.Netty4ConnectionPoolKey; +import io.clientcore.http.netty4.implementation.Netty4PipelineCleanupEvent; +import io.clientcore.http.netty4.implementation.Netty4PipelineCleanupHandler; import io.clientcore.http.netty4.implementation.Netty4ProgressAndTimeoutHandler; import io.clientcore.http.netty4.implementation.Netty4ResponseHandler; import io.clientcore.http.netty4.implementation.Netty4SslInitializationHandler; @@ -32,33 +35,47 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; -import io.netty.handler.codec.http2.Http2SecurityUtil; import io.netty.handler.proxy.ProxyHandler; -import io.netty.handler.ssl.ApplicationProtocolConfig; -import io.netty.handler.ssl.ApplicationProtocolNames; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslHandler; -import io.netty.handler.ssl.SslProvider; -import io.netty.handler.ssl.SupportedCipherSuiteFilter; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import javax.net.ssl.SSLException; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.nio.channels.ClosedChannelException; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; import static io.clientcore.core.utils.ServerSentEventUtils.attemptRetry; import static io.clientcore.core.utils.ServerSentEventUtils.processTextEventStream; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.ALPN; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.HTTP2_GOAWAY; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.HTTP_CODEC; import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.HTTP_RESPONSE; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.PIPELINE_CLEANUP; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.POOL_CONNECTION_HEALTH; import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.PROGRESS_AND_TIMEOUT; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.PROXY; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.PROXY_EXCEPTION_WARNING_SUPPRESSION; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.SSL; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.SSL_INITIALIZER; import static io.clientcore.http.netty4.implementation.Netty4Utility.awaitLatch; +import static io.clientcore.http.netty4.implementation.Netty4Utility.buildSslContext; import static io.clientcore.http.netty4.implementation.Netty4Utility.createCodec; +import static io.clientcore.http.netty4.implementation.Netty4Utility.createHttp2Codec; import static io.clientcore.http.netty4.implementation.Netty4Utility.sendHttp11Request; +import static io.clientcore.http.netty4.implementation.Netty4Utility.sendHttp2Request; import static io.clientcore.http.netty4.implementation.Netty4Utility.setOrSuppressError; /** @@ -73,34 +90,97 @@ class NettyHttpClient implements HttpClient { private static final String NO_LISTENER_ERROR_MESSAGE = "No ServerSentEventListener attached to HttpRequest to handle the text/event-stream response"; - private final Bootstrap bootstrap; - private final Consumer sslContextModifier; + private final EventLoopGroup eventLoopGroup; + private final Netty4ConnectionPool connectionPool; + private final ProxyOptions proxyOptions; private final ChannelInitializationProxyHandler channelInitializationProxyHandler; - private final AtomicReference> proxyChallenges; private final long readTimeoutMillis; private final long responseTimeoutMillis; private final long writeTimeoutMillis; + + private final Bootstrap bootstrap; + private final Consumer sslContextModifier; private final HttpProtocolVersion maximumHttpVersion; - NettyHttpClient(Bootstrap bootstrap, Consumer sslContextModifier, - HttpProtocolVersion maximumHttpVersion, ChannelInitializationProxyHandler channelInitializationProxyHandler, - long readTimeoutMillis, long responseTimeoutMillis, long writeTimeoutMillis) { + NettyHttpClient(Bootstrap bootstrap, EventLoopGroup eventLoopGroup, Netty4ConnectionPool connectionPool, + ProxyOptions proxyOptions, ChannelInitializationProxyHandler channelInitializationProxyHandler, + Consumer sslContextModifier, HttpProtocolVersion maximumHttpVersion, long readTimeoutMillis, + long responseTimeoutMillis, long writeTimeoutMillis) { this.bootstrap = bootstrap; + this.eventLoopGroup = eventLoopGroup; + this.connectionPool = connectionPool; + this.proxyOptions = proxyOptions; + this.channelInitializationProxyHandler = channelInitializationProxyHandler; this.sslContextModifier = sslContextModifier; this.maximumHttpVersion = maximumHttpVersion; - this.channelInitializationProxyHandler = channelInitializationProxyHandler; - this.proxyChallenges = new AtomicReference<>(); this.readTimeoutMillis = readTimeoutMillis; this.responseTimeoutMillis = responseTimeoutMillis; this.writeTimeoutMillis = writeTimeoutMillis; } Bootstrap getBootstrap() { - return bootstrap; + return connectionPool != null ? connectionPool.getBootstrap() : bootstrap; } @Override public Response send(HttpRequest request) { + return connectionPool != null ? sendWithConnectionPool(request) : sendWithoutConnectionPool(request); + } + + private Response sendWithConnectionPool(HttpRequest request) { + final URI uri = request.getUri(); + final boolean isHttps = "https".equalsIgnoreCase(uri.getScheme()); + final int port = uri.getPort() == -1 ? (isHttps ? 443 : 80) : uri.getPort(); + final SocketAddress finalDestination = new InetSocketAddress(uri.getHost(), port); + + final Netty4ConnectionPoolKey connectionPoolKey = constructConnectionPoolKey(finalDestination, isHttps); + + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference responseReference = new AtomicReference<>(); + final AtomicReference errorReference = new AtomicReference<>(); + + Future channelFuture = connectionPool.acquire(connectionPoolKey, isHttps); + + channelFuture.addListener((GenericFutureListener>) future -> { + if (!future.isSuccess()) { + LOGGER.atError().setThrowable(future.cause()).log("Failed connection."); + errorReference.set(future.cause()); + latch.countDown(); + return; + } + + final Channel channel = future.getNow(); + try { + configurePooledRequestPipeline(channel, request, responseReference, errorReference, latch, isHttps); + } catch (Exception e) { + // An exception occurred during the pipeline setup. + // We fire the exception through the pipeline to trigger the cleanup handler, + // which will ensure the channel is properly closed and not returned to the pool. + setOrSuppressError(errorReference, e); + if (channel.isActive()) { + channel.pipeline().fireExceptionCaught(e); + } + latch.countDown(); + } + }); + + awaitLatch(latch); + + ResponseStateInfo info = responseReference.get(); + if (info != null) { + return createResponse(request, info); + } + + if (errorReference.get() != null) { + throw LOGGER.throwableAtError().log(errorReference.get(), CoreException::from); + } else { + throw LOGGER.throwableAtError() + .log("The request pipeline completed without producing a response or an error.", + IllegalStateException::new); + } + } + + private Response sendWithoutConnectionPool(HttpRequest request) { URI uri = request.getUri(); String host = uri.getHost(); int port = uri.getPort() == -1 ? ("https".equalsIgnoreCase(uri.getScheme()) ? 443 : 80) : uri.getPort(); @@ -113,6 +193,7 @@ public Response send(HttpRequest request) { AtomicReference responseReference = new AtomicReference<>(); AtomicReference errorReference = new AtomicReference<>(); + AtomicReference> proxyChallenges = new AtomicReference<>(); CountDownLatch latch = new CountDownLatch(1); // Configure an immutable ChannelInitializer in the builder with everything that can be added on a non-per @@ -130,47 +211,21 @@ protected void initChannel(Channel channel) throws SSLException { } }); - channel.pipeline().addFirst(Netty4HandlerNames.PROXY, proxyHandler); + ChannelPipeline pipeline = channel.pipeline(); + pipeline.addFirst(PROXY, proxyHandler); + pipeline.addAfter(PROXY, PROXY_EXCEPTION_WARNING_SUPPRESSION, + Netty4ConnectionPool.SuppressProxyConnectExceptionWarningHandler.INSTANCE); } // Add SSL handling if the request is HTTPS. if (isHttps) { - SslContextBuilder sslContextBuilder - = SslContextBuilder.forClient().endpointIdentificationAlgorithm("HTTPS"); - if (maximumHttpVersion == HttpProtocolVersion.HTTP_2) { - // If HTTP/2 is the maximum version, we need to ensure that ALPN is enabled. - SslProvider sslProvider = SslContext.defaultClientProvider(); - ApplicationProtocolConfig.SelectorFailureBehavior selectorBehavior; - ApplicationProtocolConfig.SelectedListenerFailureBehavior selectedBehavior; - if (sslProvider == SslProvider.JDK) { - selectorBehavior = ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT; - selectedBehavior = ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT; - } else { - // Netty OpenSslContext doesn't support FATAL_ALERT, use NO_ADVERTISE and ACCEPT - // instead. - selectorBehavior = ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE; - selectedBehavior = ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT; - } - - sslContextBuilder.ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE) - .applicationProtocolConfig(new ApplicationProtocolConfig( - ApplicationProtocolConfig.Protocol.ALPN, selectorBehavior, selectedBehavior, - ApplicationProtocolNames.HTTP_2, ApplicationProtocolNames.HTTP_1_1)); - } - if (sslContextModifier != null) { - // Allow the caller to modify the SslContextBuilder before it is built. - sslContextModifier.accept(sslContextBuilder); - } - - SslContext ssl = sslContextBuilder.build(); + SslContext ssl = buildSslContext(maximumHttpVersion, sslContextModifier); // SSL handling is added last here. This is done as proxying could require SSL handling too. - channel.pipeline().addLast(Netty4HandlerNames.SSL, ssl.newHandler(channel.alloc(), host, port)); - channel.pipeline() - .addLast(Netty4HandlerNames.SSL_INITIALIZER, new Netty4SslInitializationHandler()); + channel.pipeline().addLast(SSL, ssl.newHandler(channel.alloc(), host, port)); + channel.pipeline().addLast(SSL_INITIALIZER, new Netty4SslInitializationHandler()); channel.pipeline() - .addLast(Netty4HandlerNames.ALPN, - new Netty4AlpnHandler(request, responseReference, errorReference, latch)); + .addLast(ALPN, new Netty4AlpnHandler(request, responseReference, errorReference, latch)); } } }); @@ -194,8 +249,8 @@ protected void initChannel(Channel channel) throws SSLException { // Only add CoreProgressAndTimeoutHandler if it will do anything, ex it is reporting progress or is // applying timeouts. - // This is done to keep the ChannelPipeline shorter, therefore more performant, if this would - // effectively be a no-op. + // This is done to keep the ChannelPipeline shorter, therefore more performant if this + // effectively is a no-op. if (addProgressAndTimeoutHandler) { channel.pipeline() .addLast(PROGRESS_AND_TIMEOUT, new Netty4ProgressAndTimeoutHandler(progressReporter, @@ -204,7 +259,7 @@ protected void initChannel(Channel channel) throws SSLException { Throwable earlyError = errorReference.get(); if (earlyError != null) { - // If an error occurred between the connect and the request being sent, don't proceed with sending + // If an error occurred between the connecting and the request being sent, don't proceed with sending // the request. latch.countDown(); return; @@ -238,7 +293,7 @@ protected void initChannel(Channel channel) throws SSLException { channel.pipeline().addLast(HTTP_RESPONSE, responseHandler); String addBefore = addProgressAndTimeoutHandler ? PROGRESS_AND_TIMEOUT : HTTP_RESPONSE; - channel.pipeline().addBefore(addBefore, Netty4HandlerNames.HTTP_CODEC, createCodec()); + channel.pipeline().addBefore(addBefore, HTTP_CODEC, createCodec()); sendHttp11Request(request, channel, errorReference) .addListener((ChannelFutureListener) sendListener -> { @@ -260,67 +315,151 @@ protected void initChannel(Channel channel) throws SSLException { throw LOGGER.throwableAtError().log(errorReference.get(), CoreException::from); } - Response response; - if (info.isChannelConsumptionComplete()) { - // The network response is already complete, handle creating our Response based on the request method and - // response headers. - BinaryData body = BinaryData.empty(); - ByteArrayOutputStream eagerContent = info.getEagerContent(); - if (info.getResponseBodyHandling() != ResponseBodyHandling.IGNORE && eagerContent.size() > 0) { - // Set the response body as the first HttpContent received if the request wasn't a HEAD request and - // there was body content. - body = BinaryData.fromBytes(eagerContent.toByteArray()); + return createResponse(request, info); + } + + private void configurePooledRequestPipeline(Channel channel, HttpRequest request, + AtomicReference responseReference, AtomicReference errorReference, + CountDownLatch latch, boolean isHttps) { + + ReentrantLock lock = channel.attr(Netty4ConnectionPool.CHANNEL_LOCK).get(); + lock.lock(); + try { + channel.config().setAutoRead(false); + + // It's possible that the channel was closed between the time it was acquired and now. + // This check ensures that we don't try to add handlers to a closed channel. + // Read handlers are responsible after this check for not being added in a closed channel. + if (!channel.isActive()) { + LOGGER.atWarning().log("Channel acquired from the pool is inactive, failing the request."); + setOrSuppressError(errorReference, new ClosedChannelException()); + latch.countDown(); + return; } - response = new Response<>(request, info.getStatusCode(), info.getHeaders(), body); - } else { - // Otherwise we aren't finished, handle the remaining content according to the documentation in - // 'channelRead()'. - BinaryData body = BinaryData.empty(); - ResponseBodyHandling bodyHandling = info.getResponseBodyHandling(); - Channel channel = info.getResponseChannel(); - if (bodyHandling == ResponseBodyHandling.IGNORE) { - // We're ignoring the response content. - CountDownLatch drainLatch = new CountDownLatch(1); - channel.pipeline().addLast(new Netty4EagerConsumeChannelHandler(drainLatch, ignored -> { - }, info.isHttp2())); - channel.config().setAutoRead(true); - awaitLatch(drainLatch); - } else if (bodyHandling == ResponseBodyHandling.STREAM) { - // Body streaming uses a special BinaryData that tracks the firstContent read and the Channel it came - // from so it can be consumed when the BinaryData is being used. - // autoRead should have been disabled already but lets make sure that it is. - channel.config().setAutoRead(false); - String contentLength = info.getHeaders().getValue(HttpHeaderName.CONTENT_LENGTH); - Long length = null; - if (!CoreUtils.isNullOrEmpty(contentLength)) { - try { - length = Long.parseLong(contentLength); - } catch (NumberFormatException ignored) { - // Ignore, we'll just read until the channel is closed. - } + final Object pipelineOwnerToken = new Object(); + channel.attr(Netty4ConnectionPool.PIPELINE_OWNER_TOKEN).set(pipelineOwnerToken); + ChannelPipeline pipeline = channel.pipeline(); + + HttpProtocolVersion protocol = channel.attr(Netty4AlpnHandler.HTTP_PROTOCOL_VERSION_KEY).get(); + boolean isHttp2 = protocol == HttpProtocolVersion.HTTP_2; + + if (protocol == null) { + // Ideally, this should never happen, but as a safeguard. + setOrSuppressError(errorReference, new IllegalStateException("Channel from pool is missing protocol.")); + latch.countDown(); + return; + } + + if (isHttp2) { + // For H2 (which is always HTTPS), the codec is persistent. + // Add it only if it's not already there (first request). + if (pipeline.get(HTTP_CODEC) == null) { + pipeline.addAfter(SSL, HTTP_CODEC, createHttp2Codec()); + pipeline.addAfter(HTTP_CODEC, HTTP2_GOAWAY, new Netty4ConnectionPool.Http2GoAwayHandler()); } + } else { // HTTP/1.1 (can be HTTP or HTTPS) + // For H1, the codec is transient and must be added for every request. + // The cleanup handler is responsible for removing it. + String after = isHttps ? SSL : POOL_CONNECTION_HEALTH; + pipeline.addAfter(after, HTTP_CODEC, createCodec()); + } + + ProgressReporter progressReporter = request.getContext() == null + ? null + : (ProgressReporter) request.getContext().getMetadata("progressReporter"); + + boolean addProgressAndTimeoutHandler = progressReporter != null + || writeTimeoutMillis > 0 + || responseTimeoutMillis > 0 + || readTimeoutMillis > 0; - body = new Netty4ChannelBinaryData(info.getEagerContent(), channel, length, info.isHttp2()); + Netty4ResponseHandler responseHandler + = new Netty4ResponseHandler(request, responseReference, errorReference, latch); + + if (addProgressAndTimeoutHandler) { + Netty4ProgressAndTimeoutHandler progressAndTimeoutHandler = new Netty4ProgressAndTimeoutHandler( + progressReporter, writeTimeoutMillis, responseTimeoutMillis, readTimeoutMillis); + + pipeline.addAfter(HTTP_CODEC, PROGRESS_AND_TIMEOUT, progressAndTimeoutHandler); + pipeline.addAfter(PROGRESS_AND_TIMEOUT, HTTP_RESPONSE, responseHandler); } else { - // All cases otherwise assume BUFFER. - CountDownLatch drainLatch = new CountDownLatch(1); - channel.pipeline().addLast(new Netty4EagerConsumeChannelHandler(drainLatch, buf -> { - try { - buf.readBytes(info.getEagerContent(), buf.readableBytes()); - } catch (IOException ex) { - throw LOGGER.throwableAtError().log(ex, CoreException::from); - } - }, info.isHttp2())); - channel.config().setAutoRead(true); - awaitLatch(drainLatch); + pipeline.addAfter(HTTP_CODEC, HTTP_RESPONSE, responseHandler); + } + + pipeline.addLast(PIPELINE_CLEANUP, + new Netty4PipelineCleanupHandler(connectionPool, errorReference, pipelineOwnerToken)); + + channel.eventLoop().execute(() -> { + if (isHttp2) { + sendHttp2Request(request, channel, errorReference, latch); + } else { // HTTP/1.1 + send(request, channel, errorReference, latch); + } + }); + } finally { + lock.unlock(); + } + } + + private void send(HttpRequest request, Channel channel, AtomicReference errorReference, + CountDownLatch latch) { + sendHttp11Request(request, channel, errorReference).addListener(f -> { + if (f.isSuccess()) { + channel.read(); + } else { + setOrSuppressError(errorReference, f.cause()); + channel.pipeline().fireExceptionCaught(f.cause()); + latch.countDown(); + } + }); + } + + private Response createResponse(HttpRequest request, ResponseStateInfo info) { + BinaryData body; + Response response; + Channel channelToCleanup = info.getResponseChannel(); - body = BinaryData.fromBytes(info.getEagerContent().toByteArray()); + channelToCleanup.eventLoop().execute(() -> { + if (channelToCleanup.pipeline().get(Netty4ResponseHandler.class) != null) { + channelToCleanup.pipeline().remove(Netty4ResponseHandler.class); } + }); - response = new Response<>(request, info.getStatusCode(), info.getHeaders(), body); + final Runnable cleanupTask = () -> { + if (connectionPool != null) { + channelToCleanup.pipeline().fireUserEventTriggered(Netty4PipelineCleanupEvent.CLEANUP_PIPELINE); + } else { + channelToCleanup.close(); + } + }; + + if (info.isChannelConsumptionComplete()) { + ByteArrayOutputStream eagerContent = info.getEagerContent(); + + body = (info.getResponseBodyHandling() != ResponseBodyHandling.IGNORE + && eagerContent != null + && eagerContent.size() > 0) ? BinaryData.fromBytes(eagerContent.toByteArray()) : BinaryData.empty(); + + channelToCleanup.eventLoop().execute(cleanupTask); + } else { + // For all other cases, create a streaming response body. + // This delegates all body consumption and cleanup logic to Netty4ChannelBinaryData. + String contentLength = info.getHeaders().getValue(HttpHeaderName.CONTENT_LENGTH); + Long length = null; + if (!CoreUtils.isNullOrEmpty(contentLength)) { + try { + length = Long.parseLong(contentLength); + } catch (NumberFormatException ignored) { + // Ignore, we'll just read until the channel is closed. + } + } + body = new Netty4ChannelBinaryData(info.getEagerContent(), info.getResponseChannel(), length, + info.isHttp2(), cleanupTask); } + response = new Response<>(request, info.getStatusCode(), info.getHeaders(), body); + if (response.getValue() != BinaryData.empty() && ServerSentEventUtils .isTextEventStreamContentType(response.getHeaders().getValue(HttpHeaderName.CONTENT_TYPE))) { @@ -338,6 +477,7 @@ protected void initChannel(Channel channel) throws SSLException { // If an error occurred or we want to reconnect if (!Thread.currentThread().isInterrupted() && attemptRetry(serverSentResult, request)) { + response.close(); return this.send(request); } @@ -347,17 +487,23 @@ protected void initChannel(Channel channel) throws SSLException { throw LOGGER.throwableAtError().log(ex, CoreException::from); } } else { + response.close(); throw LOGGER.throwableAtError().log(NO_LISTENER_ERROR_MESSAGE, IllegalStateException::new); } } - return response; } public void close() { - EventLoopGroup group = bootstrap.config().group(); - if (group != null) { - group.shutdownGracefully(); + if (connectionPool != null) { + try { + connectionPool.close(); + } catch (IOException e) { + LOGGER.atWarning().setThrowable(e).log("Failed to close the Netty Connection pool."); + } + } + if (eventLoopGroup != null && !eventLoopGroup.isShuttingDown()) { + eventLoopGroup.shutdownGracefully(); } } @@ -367,4 +513,26 @@ private static BinaryData createBodyFromServerSentResult(ServerSentResult server : BinaryData.empty(); } + private Netty4ConnectionPoolKey constructConnectionPoolKey(SocketAddress finalDestination, boolean isHttps) { + final Netty4ConnectionPoolKey key; + + final boolean useProxy = channelInitializationProxyHandler.test(finalDestination); + if (useProxy) { + SocketAddress proxyAddress = proxyOptions.getAddress(); + if (isHttps) { + // For proxied HTTPS, the pool is keyed by the unique combination of the proxy + // and the final destination. This creates dedicated pools for each tunneled destination. + key = new Netty4ConnectionPoolKey(proxyAddress, finalDestination); + } else { + // For proxied plain HTTP, the pool is keyed only by the proxy address. + // This allows reusing the same connection to the proxy for different final destinations. + key = new Netty4ConnectionPoolKey(proxyAddress, proxyAddress); + } + } else { + key = new Netty4ConnectionPoolKey(finalDestination, finalDestination); + } + + return key; + } + } diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClientBuilder.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClientBuilder.java index b74853327e3a..68c163abd001 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClientBuilder.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClientBuilder.java @@ -7,8 +7,10 @@ import io.clientcore.core.http.client.HttpProtocolVersion; import io.clientcore.core.http.models.ProxyOptions; import io.clientcore.core.instrumentation.logging.ClientLogger; +import io.clientcore.core.instrumentation.logging.LoggingEvent; import io.clientcore.core.utils.configuration.Configuration; import io.clientcore.http.netty4.implementation.ChannelInitializationProxyHandler; +import io.clientcore.http.netty4.implementation.Netty4ConnectionPool; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelOption; @@ -135,6 +137,13 @@ private static Class getChannelClass(String className) private Duration writeTimeout; private HttpProtocolVersion maximumHttpVersion = HttpProtocolVersion.HTTP_2; + // --- Connection Pool Configuration --- + private int connectionPoolSize = 1000; + private Duration connectionIdleTimeout = Duration.ofSeconds(60); + private Duration maxConnectionLifetime; + private Duration pendingAcquireTimeout = Duration.ofSeconds(60); // Default wait time for a connection + private int maxPendingAcquires = 10_000; // Default pending queue size + /** * Creates a new instance of {@link NettyHttpClientBuilder}. */ @@ -281,6 +290,86 @@ public NettyHttpClientBuilder maximumHttpVersion(HttpProtocolVersion httpVersion return this; } + /** + * Sets the maximum number of connections allowed per remote address in the connection pool. + *

+ * If not set, a default value of 1000 is used. + *

+ * A value of {@code 0} or less will disable connection pooling, and hence each request will + * get a newly created connection. + * + * @param connectionPoolSize The maximum number of connections. + * @return The updated builder. + */ + public NettyHttpClientBuilder connectionPoolSize(int connectionPoolSize) { + this.connectionPoolSize = connectionPoolSize; + return this; + } + + /** + * Sets the maximum time a connection can remain idle in the pool before it is closed and removed. + *

+ * If not set, a default value of 60 seconds is used. + *

+ * A {@link Duration} of zero or less will make the connections never expire. Note: While this is + * provided as an option, it is not recommended for most use cases, as it can lead to + * request failures if network intermediaries (like load balancers or firewalls) silently drop idle + * connections. + * + * @param connectionIdleTimeout The idle timeout duration. + * @return The updated builder. + */ + public NettyHttpClientBuilder connectionIdleTimeout(Duration connectionIdleTimeout) { + this.connectionIdleTimeout = connectionIdleTimeout; + return this; + } + + /** + * Sets the maximum time a connection is allowed to exist. + *

+ * By default, connections have no lifetime limit and can be used indefinitely. + *

+ * After this time is met or exceeded, the connection will be closed upon release. A {@link Duration} of zero or + * less, or a null value, will also result in connections having no lifetime limit. + * + * @param maxConnectionLifetime The maximum connection lifetime. + * @return The updated builder. + */ + public NettyHttpClientBuilder maxConnectionLifetime(Duration maxConnectionLifetime) { + this.maxConnectionLifetime = maxConnectionLifetime; + return this; + } + + /** + * Sets the maximum time to wait for a connection from the pool. + *

+ * If not set, a default value of 60 seconds is used. + * @param pendingAcquireTimeout The timeout for pending acquires. + * @return The updated builder. + */ + public NettyHttpClientBuilder pendingAcquireTimeout(Duration pendingAcquireTimeout) { + this.pendingAcquireTimeout = pendingAcquireTimeout; + return this; + } + + /** + * Sets the maximum number of requests that can be queued waiting for a connection. + *

+ * This limit is applied on a per-route (per-host) basis. + * If not set, a default value of 10_000 is used. + * + * @param maxPendingAcquires The maximum number of pending acquires. + * @return The updated builder. + */ + public NettyHttpClientBuilder maxPendingAcquires(int maxPendingAcquires) { + if (maxPendingAcquires <= 0) { + throw LOGGER.throwableAtError() + .log("maxPendingAcquires must be greater than 0", IllegalArgumentException::new); + } + this.maxPendingAcquires = maxPendingAcquires; + return this; + } + /** * Builds the NettyHttpClient. * @@ -293,18 +382,33 @@ public HttpClient build() { = getChannelClass(this.channelClass, group.getClass(), IS_EPOLL_AVAILABLE, IS_KQUEUE_AVAILABLE); // Leave breadcrumbs about the NettyHttpClient configuration, in case troubleshooting is needed. - LOGGER.atVerbose() + LoggingEvent loggingEvent = LOGGER.atVerbose() .addKeyValue("customEventLoopGroup", eventLoopGroup != null) .addKeyValue("eventLoopGroupClass", group.getClass()) .addKeyValue("customChannelClass", this.channelClass != null) - .addKeyValue("channelClass", channelClass) - .log("NettyHttpClient was built with these configurations."); + .addKeyValue("channelClass", channelClass); + + if (connectionPoolSize > 0) { + loggingEvent.addKeyValue("connectionPoolSize", this.connectionPoolSize) + .addKeyValue("connectionIdleTimeout", this.connectionIdleTimeout) + .addKeyValue("maxConnectionLifetime", this.maxConnectionLifetime) + .addKeyValue("pendingAcquireTimeout", this.pendingAcquireTimeout) + .addKeyValue("maxPendingAcquires", this.maxPendingAcquires); + } + + loggingEvent.log("NettyHttpClient was built with these configurations."); Bootstrap bootstrap = new Bootstrap().group(group) .channel(channelClass) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) getTimeoutMillis(connectTimeout, 10_000)); // Disable auto-read as we want to control when and how data is read from the channel. bootstrap.option(ChannelOption.AUTO_READ, false); + // Enable TCP keep-alive to proactively detect and clean up stale connections in the pool. This helps evict + // connections that have been silently dropped by network intermediaries. + bootstrap.option(ChannelOption.SO_KEEPALIVE, true); + // Allow the channel to remain open for writing even after the server has closed its sending side. + // This helps detect half-closures with a ChannelInputShutdownEvent in the PoolConnectionHealthHandler. + bootstrap.option(ChannelOption.ALLOW_HALF_CLOSURE, true); Configuration buildConfiguration = (configuration == null) ? Configuration.getGlobalConfiguration() : configuration; @@ -312,9 +416,17 @@ public HttpClient build() { ProxyOptions buildProxyOptions = (proxyOptions == null) ? ProxyOptions.fromConfiguration(buildConfiguration, true) : proxyOptions; - return new NettyHttpClient(bootstrap, sslContextModifier, maximumHttpVersion, - new ChannelInitializationProxyHandler(buildProxyOptions), getTimeoutMillis(readTimeout), - getTimeoutMillis(responseTimeout), getTimeoutMillis(writeTimeout)); + Netty4ConnectionPool connectionPool = null; + if (connectionPoolSize > 0) { + connectionPool + = new Netty4ConnectionPool(bootstrap, new ChannelInitializationProxyHandler(buildProxyOptions), + sslContextModifier, connectionPoolSize, connectionIdleTimeout, maxConnectionLifetime, + pendingAcquireTimeout, maxPendingAcquires, maximumHttpVersion); + } + + return new NettyHttpClient(bootstrap, group, connectionPool, buildProxyOptions, + new ChannelInitializationProxyHandler(buildProxyOptions), sslContextModifier, maximumHttpVersion, + getTimeoutMillis(readTimeout), getTimeoutMillis(responseTimeout), getTimeoutMillis(writeTimeout)); } static long getTimeoutMillis(Duration duration) { diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClientProvider.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClientProvider.java index 9151620ee25c..5541317fb0f9 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClientProvider.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/NettyHttpClientProvider.java @@ -30,6 +30,20 @@ public HttpClient getHttpClient() { public NettyHttpClientProvider() { } + /** + * Creates a new {@link HttpClient} instance with a default, shared connection pool. + *

+ * For more advanced customization, such as disabling pooling entirely, use the {@link NettyHttpClientBuilder}. + *

+ * Example: Creating a client without a connection pool + *

{@code
+     * HttpClient client = new NettyHttpClientBuilder()
+     * .connectionPoolSize(0)
+     * .build();
+     * }
+ * + * @return A new {@link HttpClient} instance. + */ @Override public HttpClient getNewInstance() { return new NettyHttpClientBuilder().build(); diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/ChannelInitializationProxyHandler.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/ChannelInitializationProxyHandler.java index 56b79d04d92c..9e0717730775 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/ChannelInitializationProxyHandler.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/ChannelInitializationProxyHandler.java @@ -67,7 +67,7 @@ public boolean test(SocketAddress socketAddress) { InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; String hostString = inetSocketAddress.getHostString(); - return hostString != null && nonProxyHostsPattern.matcher(hostString).matches(); + return hostString != null && !nonProxyHostsPattern.matcher(hostString).matches(); } /** diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4AlpnHandler.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4AlpnHandler.java index 6b066103ab6d..0a869b84f761 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4AlpnHandler.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4AlpnHandler.java @@ -2,26 +2,22 @@ // Licensed under the MIT License. package io.clientcore.http.netty4.implementation; +import io.clientcore.core.http.client.HttpProtocolVersion; import io.clientcore.core.http.models.HttpRequest; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; -import io.netty.handler.codec.http2.DefaultHttp2Connection; -import io.netty.handler.codec.http2.DelegatingDecompressorFrameListener; -import io.netty.handler.codec.http2.Http2Connection; -import io.netty.handler.codec.http2.Http2FrameListener; -import io.netty.handler.codec.http2.Http2Settings; -import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandler; -import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandlerBuilder; -import io.netty.handler.codec.http2.InboundHttp2ToHttpAdapterBuilder; import io.netty.handler.ssl.ApplicationProtocolNames; import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; +import io.netty.util.AttributeKey; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; -import static io.clientcore.http.netty4.implementation.Netty4Utility.createCodec; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.ALPN; +import static io.clientcore.http.netty4.implementation.Netty4Utility.configureHttpsPipeline; import static io.clientcore.http.netty4.implementation.Netty4Utility.sendHttp11Request; +import static io.clientcore.http.netty4.implementation.Netty4Utility.sendHttp2Request; import static io.clientcore.http.netty4.implementation.Netty4Utility.setOrSuppressError; /** @@ -29,7 +25,15 @@ * either HTTP/1.1 or HTTP/2 based on the result of negotiation. */ public final class Netty4AlpnHandler extends ApplicationProtocolNegotiationHandler { - private static final int TWO_FIFTY_SIX_KB = 256 * 1024; + + /** + * An Attribute key for the channel storing the HTTP protocol version that was negotiated. + * This information will be used in case the same channel is reused in the future, so we can + * adjust the correct handlers because there's no need for ALPN to run again. + */ + public static final AttributeKey HTTP_PROTOCOL_VERSION_KEY + = AttributeKey.valueOf("http-protocol-version"); + private final HttpRequest request; private final AtomicReference responseReference; private final AtomicReference errorReference; @@ -38,9 +42,9 @@ public final class Netty4AlpnHandler extends ApplicationProtocolNegotiationHandl /** * Creates a new instance of {@link Netty4AlpnHandler} with a fallback to using HTTP/1.1. * - * @param request The request to send once ALPN negotiation completes. + * @param request The request to send once ALPN negotiation completes. * @param errorReference An AtomicReference keeping track of errors during the request lifecycle. - * @param latch A CountDownLatch that will be released once the request completes. + * @param latch A CountDownLatch that will be released once the request completes. */ public Netty4AlpnHandler(HttpRequest request, AtomicReference responseReference, AtomicReference errorReference, CountDownLatch latch) { @@ -58,78 +62,37 @@ public boolean isSharable() { @Override protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + HttpProtocolVersion protocolVersion; if (ApplicationProtocolNames.HTTP_2.equals(protocol)) { - // TODO (alzimmer): InboundHttp2ToHttpAdapter buffers the entire response into a FullHttpResponse. Need to - // create a streaming version of this to support huge response payloads. - Http2Connection http2Connection = new DefaultHttp2Connection(false); - Http2Settings settings = new Http2Settings().headerTableSize(4096) - .maxHeaderListSize(TWO_FIFTY_SIX_KB) - .pushEnabled(false) - .initialWindowSize(TWO_FIFTY_SIX_KB); - Http2FrameListener frameListener = new DelegatingDecompressorFrameListener(http2Connection, - new InboundHttp2ToHttpAdapterBuilder(http2Connection).maxContentLength(Integer.MAX_VALUE) - .propagateSettings(true) - .validateHttpHeaders(true) - .build(), - 0); + protocolVersion = HttpProtocolVersion.HTTP_2; + } else if (ApplicationProtocolNames.HTTP_1_1.equals(protocol)) { + protocolVersion = HttpProtocolVersion.HTTP_1_1; + } else { + ctx.fireExceptionCaught(new IllegalStateException("unknown protocol: " + protocol)); + return; + } - HttpToHttp2ConnectionHandler connectionHandler - = new HttpToHttp2ConnectionHandlerBuilder().initialSettings(settings) - .frameListener(frameListener) - .connection(http2Connection) - .validateHeaders(true) - .build(); + ctx.channel().attr(HTTP_PROTOCOL_VERSION_KEY).set(protocolVersion); - if (ctx.pipeline().get(Netty4HandlerNames.PROGRESS_AND_TIMEOUT) != null) { - ctx.pipeline() - .addAfter(Netty4HandlerNames.PROGRESS_AND_TIMEOUT, Netty4HandlerNames.HTTP_RESPONSE, - new Netty4ResponseHandler(request, responseReference, errorReference, latch)); - ctx.pipeline() - .addBefore(Netty4HandlerNames.PROGRESS_AND_TIMEOUT, Netty4HandlerNames.HTTP_CODEC, - connectionHandler); - } else { - ctx.pipeline().addAfter(Netty4HandlerNames.SSL, Netty4HandlerNames.HTTP_CODEC, connectionHandler); - ctx.pipeline() - .addAfter(Netty4HandlerNames.HTTP_CODEC, Netty4HandlerNames.HTTP_RESPONSE, - new Netty4ResponseHandler(request, responseReference, errorReference, latch)); - } + configureHttpsPipeline(ctx.pipeline(), request, protocolVersion, responseReference, errorReference, latch); + if (protocolVersion == HttpProtocolVersion.HTTP_2) { + sendHttp2Request(request, ctx.channel(), errorReference, latch); + } else { sendHttp11Request(request, ctx.channel(), errorReference) .addListener((ChannelFutureListener) sendListener -> { if (!sendListener.isSuccess()) { setOrSuppressError(errorReference, sendListener.cause()); - sendListener.channel().close(); + sendListener.channel().pipeline().fireExceptionCaught(sendListener.cause()); latch.countDown(); } else { sendListener.channel().read(); } }); - } else if (ApplicationProtocolNames.HTTP_1_1.equals(protocol)) { - if (ctx.pipeline().get(Netty4HandlerNames.PROGRESS_AND_TIMEOUT) != null) { - ctx.pipeline() - .addAfter(Netty4HandlerNames.PROGRESS_AND_TIMEOUT, Netty4HandlerNames.HTTP_RESPONSE, - new Netty4ResponseHandler(request, responseReference, errorReference, latch)); - ctx.pipeline() - .addBefore(Netty4HandlerNames.PROGRESS_AND_TIMEOUT, Netty4HandlerNames.HTTP_CODEC, createCodec()); - } else { - ctx.pipeline().addAfter(Netty4HandlerNames.SSL, Netty4HandlerNames.HTTP_CODEC, createCodec()); - ctx.pipeline() - .addAfter(Netty4HandlerNames.HTTP_CODEC, Netty4HandlerNames.HTTP_RESPONSE, - new Netty4ResponseHandler(request, responseReference, errorReference, latch)); - } + } - sendHttp11Request(request, ctx.channel(), errorReference) - .addListener((ChannelFutureListener) sendListener -> { - if (!sendListener.isSuccess()) { - setOrSuppressError(errorReference, sendListener.cause()); - sendListener.channel().close(); - latch.countDown(); - } else { - sendListener.channel().read(); - } - }); - } else { - throw new IllegalStateException("unknown protocol: " + protocol); + if (ctx.pipeline().get(ALPN) != null) { + ctx.pipeline().remove(this); } } } diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ChannelBinaryData.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ChannelBinaryData.java index 837f4d164b1b..f1c8500dbb39 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ChannelBinaryData.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ChannelBinaryData.java @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets; import java.util.Objects; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; import static io.clientcore.http.netty4.implementation.Netty4Utility.awaitLatch; @@ -36,6 +37,11 @@ public final class Netty4ChannelBinaryData extends BinaryData { private final Channel channel; private final Long length; private final boolean isHttp2; + private final AtomicBoolean streamDrained = new AtomicBoolean(false); + private final CountDownLatch drainLatch = new CountDownLatch(1); + // Manages the "closed" state, ensuring cleanup happens only once. + private final AtomicBoolean closed = new AtomicBoolean(false); + private final Runnable onClose; // Non-final to allow nulling out after use. private ByteArrayOutputStream eagerContent; @@ -49,12 +55,23 @@ public final class Netty4ChannelBinaryData extends BinaryData { * @param channel The Netty {@link Channel}. * @param length Size of the response body (if known). * @param isHttp2 Flag indicating whether the handler is used for HTTP/2 or not. + * @param onClose The Runnable to run when the {@code close()} method is triggered. */ + public Netty4ChannelBinaryData(ByteArrayOutputStream eagerContent, Channel channel, Long length, boolean isHttp2, + Runnable onClose) { + this.eagerContent = eagerContent; + this.channel = channel; + this.length = length; + this.isHttp2 = isHttp2; + this.onClose = onClose; + } + public Netty4ChannelBinaryData(ByteArrayOutputStream eagerContent, Channel channel, Long length, boolean isHttp2) { this.eagerContent = eagerContent; this.channel = channel; this.length = length; this.isHttp2 = isHttp2; + this.onClose = null; } @Override @@ -64,27 +81,8 @@ public byte[] toBytes() { } if (bytes == null) { - CountDownLatch latch = new CountDownLatch(1); - Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, - buf -> buf.readBytes(eagerContent, buf.readableBytes()), isHttp2); - channel.pipeline().addLast(Netty4HandlerNames.EAGER_CONSUME, handler); - channel.config().setAutoRead(true); - - awaitLatch(latch); - - Throwable exception = handler.channelException(); - if (exception != null) { - if (exception instanceof Error) { - throw (Error) exception; - } else { - throw CoreException.from(exception); - } - } else { - bytes = eagerContent.toByteArray(); - } - eagerContent = null; + drainStream(); } - return bytes; } @@ -105,7 +103,7 @@ public T toObject(Type type, ObjectSerializer serializer) { @Override public InputStream toStream() { if (bytes == null) { - return new Netty4ChannelInputStream(eagerContent, channel, isHttp2); + return new Netty4ChannelInputStream(eagerContent, channel, isHttp2, this::close); } else { return new ByteArrayInputStream(bytes); } @@ -124,36 +122,51 @@ public void writeTo(JsonWriter jsonWriter) { @Override public void writeTo(OutputStream outputStream) { + Objects.requireNonNull(outputStream, "'outputStream' cannot be null."); + try { - if (bytes == null) { - // Channel hasn't been read yet, don't buffer it, just write it to the OutputStream as it's being read. - if (eagerContent.size() > 0) { - outputStream.write(eagerContent.toByteArray()); - } + if (bytes != null) { + outputStream.write(bytes); + return; + } - CountDownLatch latch = new CountDownLatch(1); - Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, - buf -> buf.readBytes(outputStream, buf.readableBytes()), isHttp2); - channel.pipeline().addLast(Netty4HandlerNames.EAGER_CONSUME, handler); - channel.config().setAutoRead(true); + if (streamDrained.compareAndSet(false, true)) { + try { + // Channel hasn't been read yet, don't buffer it, just write it to the OutputStream as it's being read. + if (eagerContent != null && eagerContent.size() > 0) { + eagerContent.writeTo(outputStream); + } - awaitLatch(latch); + CountDownLatch latch = new CountDownLatch(1); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, + buf -> buf.readBytes(outputStream, buf.readableBytes()), isHttp2); + channel.pipeline().addLast(Netty4HandlerNames.EAGER_CONSUME, handler); + channel.config().setAutoRead(true); - Throwable exception = handler.channelException(); - if (exception != null) { - if (exception instanceof Error) { - throw (Error) exception; - } else { - throw CoreException.from(exception); + awaitLatch(latch); + + Throwable exception = handler.channelException(); + if (exception != null) { + if (exception instanceof Error) { + throw (Error) exception; + } else { + throw CoreException.from(exception); + } + } + } finally { + eagerContent = null; + drainLatch.countDown(); + + if (onClose != null) { + onClose.run(); } } - eagerContent = null; } else { - // Already converted the Channel to a byte[], use it. - outputStream.write(bytes); + throw LOGGER.throwableAtError() + .log("The stream has already been consumed and is not replayable.", IllegalStateException::new); } - } catch (IOException ex) { - throw LOGGER.throwableAtError().log(ex, CoreException::from); + } catch (IOException e) { + throw CoreException.from(e); } } @@ -182,10 +195,96 @@ public BinaryData toReplayableBinaryData() { return BinaryData.fromBytes(toBytes()); } + /** + * Ensures the underlying network stream is fully consumed but does not close the channel, + * allowing it to be reused by the connection pool. + */ @Override public void close() { - eagerContent = null; - channel.disconnect(); - channel.close(); + if (closed.compareAndSet(false, true)) { + // If draining hasn't started, it means the stream was not consumed. + // We need to drain it to ensure the connection can be safely reused. + if (!streamDrained.get()) { + drainAndCleanupAsync(); + } else { + if (onClose != null) { + onClose.run(); + } + } + } + } + + private void drainAndCleanupAsync() { + if (streamDrained.compareAndSet(false, true)) { + if (!channel.isActive()) { + if (onClose != null) { + onClose.run(); + } + drainLatch.countDown(); + return; + } + + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(() -> { + if (onClose != null) { + onClose.run(); + } + drainLatch.countDown(); + }, isHttp2); + + channel.pipeline().addLast(Netty4HandlerNames.EAGER_CONSUME, handler); + channel.config().setAutoRead(true); + } else { + awaitLatch(drainLatch); + } + } + + private void drainStream() { + if (streamDrained.compareAndSet(false, true)) { + try { + if (length != null && eagerContent != null && eagerContent.size() >= length) { + bytes = eagerContent.toByteArray(); + return; + } + + if (!channel.isActive()) { + // The connection was closed before we could read the full body. + // Check if, by chance, the eager content we already have satisfies the full length. + // This can happen if the body was very small and the server closed immediately. + if (length != null && eagerContent != null && length == eagerContent.size()) { + bytes = eagerContent.toByteArray(); + return; + } + throw LOGGER.throwableAtError() + .log("Connection closed prematurely while reading response body.", IOException::new); + } + + CountDownLatch ioLatch = new CountDownLatch(1); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(ioLatch, + buf -> buf.readBytes(eagerContent, buf.readableBytes()), isHttp2); + + channel.pipeline().addLast(Netty4HandlerNames.EAGER_CONSUME, handler); + channel.config().setAutoRead(true); + + awaitLatch(ioLatch); + Throwable exception = handler.channelException(); + + if (exception != null) { + if (exception instanceof Error) { + throw (Error) exception; + } else { + throw CoreException.from(exception); + } + } else { + bytes = eagerContent.toByteArray(); + } + } catch (IOException e) { + throw LOGGER.throwableAtError().log(e, CoreException::from); + } finally { + eagerContent = null; + drainLatch.countDown(); + } + } else { + awaitLatch(drainLatch); + } } } diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ChannelInputStream.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ChannelInputStream.java index 418df7945248..de5d4c96c0af 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ChannelInputStream.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ChannelInputStream.java @@ -7,7 +7,8 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; -import java.util.LinkedList; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; /** @@ -16,6 +17,7 @@ public final class Netty4ChannelInputStream extends InputStream { private final Channel channel; private final boolean isHttp2; + private final Runnable onClose; // Indicator for the Channel being fully read. // This will become true before 'streamDone' becomes true, but both may become true in the same operation. @@ -27,9 +29,9 @@ public final class Netty4ChannelInputStream extends InputStream { // Once this is true, the stream will never return data again. private boolean streamDone = false; - // Linked list of byte[]s that maintains the last available contents from the Channel / eager content. - // A list is needed as each Channel.read() may result in many channelRead calls. - private final LinkedList additionalBuffers; + // Queue of byte[]s that maintains the last available contents from the Channel / eager content. + // A queue is needed as each Channel.read() may result in many channelRead calls. + private final Queue additionalBuffers; private byte[] currentBuffer; @@ -46,20 +48,23 @@ public final class Netty4ChannelInputStream extends InputStream { * status line and response headers. * @param channel The {@link Channel} to read from. * @param isHttp2 Flag indicating whether the Channel is used for HTTP/2 or not. + * @param onClose A runnable to execute when the stream is closed. */ - Netty4ChannelInputStream(ByteArrayOutputStream eagerContent, Channel channel, boolean isHttp2) { + Netty4ChannelInputStream(ByteArrayOutputStream eagerContent, Channel channel, boolean isHttp2, Runnable onClose) { if (eagerContent != null && eagerContent.size() > 0) { this.currentBuffer = eagerContent.toByteArray(); + eagerContent.reset(); } else { this.currentBuffer = new byte[0]; } this.readIndex = 0; - this.additionalBuffers = new LinkedList<>(); + this.additionalBuffers = new ConcurrentLinkedQueue<>(); this.channel = channel; if (channel.pipeline().get(Netty4InitiateOneReadHandler.class) != null) { channel.pipeline().remove(Netty4InitiateOneReadHandler.class); } this.isHttp2 = isHttp2; + this.onClose = onClose; } byte[] getCurrentBuffer() { @@ -167,19 +172,28 @@ public long skip(long n) throws IOException { return n - toSkip; } + /** + * Closes this input stream and ensures the underlying connection can be returned to the pool. + * This method does not close the underlying channel. Instead, it triggers the onClose + * callback which is responsible for draining the rest of the stream content. + */ @Override - public void close() { - currentBuffer = null; - additionalBuffers.clear(); - if (channel.isOpen() || channel.isActive()) { - channel.disconnect(); - channel.close(); + public void close() throws IOException { + try { + if (onClose != null) { + onClose.run(); + } + } finally { + super.close(); + currentBuffer = null; + additionalBuffers.clear(); + streamDone = true; } } private boolean setupNextBuffer() throws IOException { if (!additionalBuffers.isEmpty()) { - currentBuffer = additionalBuffers.pop(); + currentBuffer = additionalBuffers.poll(); readIndex = 0; return true; } else if (readMore()) { @@ -214,7 +228,7 @@ private boolean readMore() throws IOException { byte[] buffer = new byte[byteBuf.readableBytes()]; byteBuf.readBytes(buffer); - additionalBuffers.add(buffer); + additionalBuffers.offer(buffer); }, isHttp2); channel.pipeline().addLast(Netty4HandlerNames.READ_ONE, handler); } @@ -242,7 +256,7 @@ private boolean readMore() throws IOException { } if (!additionalBuffers.isEmpty()) { - currentBuffer = additionalBuffers.pop(); + currentBuffer = additionalBuffers.poll(); readIndex = 0; } else if (channelDone) { // Don't listen to IntelliJ here, channelDone may be false. // This read contained no data and the channel completed, therefore the stream is also completed. diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPool.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPool.java new file mode 100644 index 000000000000..f5645652f5bb --- /dev/null +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPool.java @@ -0,0 +1,709 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4.implementation; + +import io.clientcore.core.http.client.HttpProtocolVersion; +import io.clientcore.core.instrumentation.logging.ClientLogger; +import io.clientcore.core.models.CoreException; +import io.clientcore.core.utils.AuthenticateChallenge; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.handler.codec.http2.Http2GoAwayFrame; +import io.netty.handler.proxy.HttpProxyHandler; +import io.netty.handler.proxy.ProxyHandler; +import io.netty.handler.ssl.ApplicationProtocolNames; +import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; +import io.netty.handler.ssl.SslCloseCompletionEvent; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.util.AttributeKey; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.ScheduledFuture; + +import javax.net.ssl.SSLException; +import java.io.Closeable; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; + +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.CONNECTION_POOL_ALPN; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.POOL_CONNECTION_HEALTH; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.PROXY; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.PROXY_EXCEPTION_WARNING_SUPPRESSION; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.SSL; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.SSL_GRACEFUL_SHUTDOWN; +import static io.clientcore.http.netty4.implementation.Netty4Utility.buildSslContext; + +/** + * A thread-safe pool of Netty {@link Channel}s that are reused for requests to the same route. + *

+ * This connection pool manages the entire connection lifecycle, including TCP connected, proxy handshakes, and the + * asynchronous SSL/ALPN negotiation. It is designed to return fully configured channels that are + * ready for immediate use, thereby eliminating per-request handshake latency. + */ +public class Netty4ConnectionPool implements Closeable { + + /** + * An AttributeKey referring to a channel-specific {@link ReentrantLock}. + *

+ * This lock is used to ensure that the setup and cleanup of a channel's pipeline are atomic operations. + * It protects against race conditions where a channel might be acquired from the pool and configured for a new + * request before the cleanup from the previous request has fully completed. Each channel gets its own unique + * lock instance, making the lock contention extremely low. + */ + public static final AttributeKey CHANNEL_LOCK = AttributeKey.valueOf("channel-lock"); + + /** + * A unique token created for each request to identify the current owner of a channel pipeline. + *

+ * It protects against stale cleanup handlers from previous, timed-out, or failed requests, + * ensuring that only the {@link Netty4PipelineCleanupHandler} that belongs to the current, + * active request is allowed to modify the pipeline. + */ + public static final AttributeKey PIPELINE_OWNER_TOKEN = AttributeKey.valueOf("pipeline-owner-token"); + + private static final AttributeKey HTTP2_GOAWAY_RECEIVED = AttributeKey.valueOf("http2-goaway-received"); + private static final AttributeKey POOLED_CONNECTION_KEY + = AttributeKey.valueOf("pooled-connection-key"); + + private static final ClientLogger LOGGER = new ClientLogger(Netty4ConnectionPool.class); + private static final String CLOSED_POOL_ERROR_MESSAGE = "Connection pool has been closed."; + + private final ConcurrentMap pool = new ConcurrentHashMap<>(); + private final AtomicBoolean closed = new AtomicBoolean(false); + + private final Bootstrap bootstrap; + private final int maxConnectionsPerRoute; + private final long idleTimeoutNanos; + private final long maxLifetimeNanos; + private final Duration pendingAcquireTimeout; + private final int maxPendingAcquires; + private final Future cleanupTask; + + private final ChannelInitializationProxyHandler channelInitializationProxyHandler; + private final Consumer sslContextModifier; + private final AtomicReference> proxyChallenges; + private final HttpProtocolVersion maximumHttpVersion; + + @ChannelHandler.Sharable + public static class Http2GoAwayHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof Http2GoAwayFrame) { + // A GOAWAY frame was received. + // Mark the channel so the pool knows not to reuse it for new requests. + ctx.channel().attr(HTTP2_GOAWAY_RECEIVED).set(true); + } + super.channelRead(ctx, msg); + } + } + + @ChannelHandler.Sharable + public static final class SuppressProxyConnectExceptionWarningHandler extends ChannelInboundHandlerAdapter { + public static final SuppressProxyConnectExceptionWarningHandler INSTANCE + = new SuppressProxyConnectExceptionWarningHandler(); + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (cause instanceof HttpProxyHandler.HttpProxyConnectException) { + return; + } + ctx.fireExceptionCaught(cause); + } + } + + public static class SslGracefulShutdownHandler extends ChannelInboundHandlerAdapter { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslCloseCompletionEvent) { + ctx.channel().close(); + } + super.userEventTriggered(ctx, evt); + } + } + + @ChannelHandler.Sharable + public static class PoolConnectionHealthHandler extends ChannelInboundHandlerAdapter { + private static final PoolConnectionHealthHandler INSTANCE = new PoolConnectionHealthHandler(); + + private static final AttributeKey CONNECTION_INVALIDATED + = AttributeKey.valueOf("connection-invalidated"); + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + // This event signals the server has closed its sending side (TCP half-closure). + // The channel is now considered unusable and must be closed. + if (evt instanceof ChannelInputShutdownEvent) { + invalidateAndClose(ctx.channel()); + } + super.userEventTriggered(ctx, evt); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + // This is a fallback for when the connection is fully closed for any reason. + invalidateAndClose(ctx.channel()); + super.channelInactive(ctx); + } + + private void invalidateAndClose(Channel channel) { + // Mark the channel as invalid. The 'isHealthy' check uses this attribute + // to immediately reject the channel without waiting for it to be fully closed. + channel.attr(CONNECTION_INVALIDATED).set(true); + channel.close(); + } + } + + /** + * A specialized ALPN handler that signals when a channel is fully negotiated and ready for use by the pool. + */ + public static class ConnectionPoolAlpnHandler extends ApplicationProtocolNegotiationHandler { + private final Promise promise; + + ConnectionPoolAlpnHandler(Promise promise) { + super(ApplicationProtocolNames.HTTP_1_1); + this.promise = promise; + } + + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + HttpProtocolVersion protocolVersion; + if (ApplicationProtocolNames.HTTP_2.equals(protocol)) { + protocolVersion = HttpProtocolVersion.HTTP_2; + } else { + protocolVersion = HttpProtocolVersion.HTTP_1_1; + } + + ctx.channel().attr(Netty4AlpnHandler.HTTP_PROTOCOL_VERSION_KEY).set(protocolVersion); + + // After setting the protocol, fulfill the promise to signal the channel is ready. + promise.setSuccess(ctx.channel()); + + ctx.pipeline().remove(this); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + // If an error happens during negotiation, fail the promise. + promise.setFailure(cause); + ctx.fireExceptionCaught(cause); + } + } + + public Netty4ConnectionPool(Bootstrap bootstrap, + ChannelInitializationProxyHandler channelInitializationProxyHandler, + Consumer sslContextModifier, int maxConnectionsPerRoute, Duration connectionIdleTimeout, + Duration maxConnectionLifetime, Duration pendingAcquireTimeout, int maxPendingAcquires, + HttpProtocolVersion maximumHttpVersion) { + this.bootstrap = bootstrap; + this.channelInitializationProxyHandler = channelInitializationProxyHandler; + this.sslContextModifier = sslContextModifier; + this.proxyChallenges = new AtomicReference<>(); + this.maxConnectionsPerRoute = maxConnectionsPerRoute; + this.idleTimeoutNanos = durationToNanos(connectionIdleTimeout); + this.maxLifetimeNanos = durationToNanos(maxConnectionLifetime); + this.pendingAcquireTimeout = pendingAcquireTimeout; + this.maxPendingAcquires = maxPendingAcquires; + this.maximumHttpVersion = maximumHttpVersion; + + if (this.idleTimeoutNanos > 0) { + EventLoopGroup eventLoopGroup = bootstrap.config().group(); + // This scheduled task cleans up idle connections periodically. + // The 30-second interval is a trade-off between precision and performance. + // Running it more frequently would be more precise but add more overhead. + // This means a connection may stay idle for up to (idleTimeout + 30s) before being closed, + // which is an acceptable behavior for preventing resource leaks. + this.cleanupTask + = eventLoopGroup.scheduleAtFixedRate(this::cleanupIdleConnections, 30, 30, TimeUnit.SECONDS); + } else { + this.cleanupTask = null; + } + } + + /** + * Acquires a channel for the given route from the pool. + *

+ * The returned {@link Future} will be notified with a channel that is fully connected, authenticated by any + * proxy, and has completed its SSL/ALPN handshake (for HTTPS). This method will first attempt to reuse an + * available idle channel. If none are available and the pool has not reached its maximum capacity, a new + * channel will be created. If the pool is at maximum capacity, the acquisition request will be queued. + * + * @param key The composite key representing the connection route. + * @param isHttps Flag indicating if the connection should be secured. + * @return A {@link Future} that will complete with a ready-to-use {@link Channel}, or a failed {@link Future} + * in case the connection pool has been closed. + */ + public Future acquire(Netty4ConnectionPoolKey key, boolean isHttps) { + if (closed.get()) { + return bootstrap.config() + .group() + .next() + .newFailedFuture(new IllegalStateException(CLOSED_POOL_ERROR_MESSAGE)); + } + + PerRoutePool perRoutePool = pool.computeIfAbsent(key, k -> new PerRoutePool(k, isHttps)); + return perRoutePool.acquire(); + } + + /** + * Releases a healthy channel back to the connection pool to be reused for future requests. + *

+ * The channel's pipeline must be cleaned of all request-specific handlers before being released. + * Unhealthy channels (e.g., those that are inactive or have received a GOAWAY frame) will be closed and discarded. + * + * @param channel The channel to release back to the connection pool. + */ + public void release(Channel channel) { + if (channel == null) { + return; + } + + PooledConnection pooledConnection = channel.attr(POOLED_CONNECTION_KEY).get(); + if (pooledConnection == null) { + channel.close(); + return; + } + + if (closed.get()) { + pooledConnection.close(); + return; + } + + PerRoutePool perRoutePool = pool.get(pooledConnection.key); + if (perRoutePool != null) { + perRoutePool.release(pooledConnection); + } else { + pooledConnection.close(); + } + } + + /** + * Periodically cleans up connections that have been idle for too long. + */ + private void cleanupIdleConnections() { + if (idleTimeoutNanos <= 0 || pool.isEmpty()) { + return; + } + + for (PerRoutePool perRoutePool : pool.values()) { + perRoutePool.cleanup(); + } + } + + @Override + public void close() throws IOException { + if (closed.compareAndSet(false, true)) { + if (cleanupTask != null) { + cleanupTask.cancel(false); + } + pool.values().forEach(PerRoutePool::close); + pool.clear(); + } + } + + public Bootstrap getBootstrap() { + return bootstrap; + } + + private static long durationToNanos(Duration duration) { + return (duration == null || duration.isNegative() || duration.isZero()) ? -1 : duration.toNanos(); + } + + /** + * A wrapper for a Netty Channel that holds pooling-related metadata. + */ + private static final class PooledConnection { + private final Channel channel; + private final Netty4ConnectionPoolKey key; + private final OffsetDateTime creationTime; + private volatile OffsetDateTime idleSince; + + PooledConnection(Channel channel, Netty4ConnectionPoolKey key) { + this.channel = channel; + this.key = key; + this.creationTime = OffsetDateTime.now(ZoneOffset.UTC); + channel.attr(POOLED_CONNECTION_KEY).set(this); + } + + private boolean isActiveAndWriteable() { + return channel.isActive() && channel.isWritable(); + } + + private void close() { + channel.close(); + } + } + + /** + * Manages connections and pending acquirers for a single route. + */ + class PerRoutePool { + private final Deque idleConnections = new ConcurrentLinkedDeque<>(); + private final Deque> pendingAcquirers = new ConcurrentLinkedDeque<>(); + // Counter for all connections for a specific route (active and idle). + private final AtomicInteger totalConnections = new AtomicInteger(0); + private final Netty4ConnectionPoolKey key; + private final SocketAddress route; + private final boolean isHttps; + + // A lock to protect the pool's internal state during acquire/release decisions. + private final ReentrantLock poolLock = new ReentrantLock(); + + PerRoutePool(Netty4ConnectionPoolKey key, boolean isHttps) { + this.key = key; + this.route = key.getConnectionTarget(); + this.isHttps = isHttps; + } + + /** + * Acquires a connection for this specific route, following the pool's logic flow: + *

    + *
  1. Attempt to poll a healthy, idle connection from the queue.
  2. + *
  3. If none is available, attempt to create a new connection if capacity allows.
  4. + *
  5. If at capacity, queue the acquisition request.
  6. + *
+ * + * @return A {@link Future} that completes with a {@link Channel}. + */ + private Future acquire() { + if (closed.get()) { + return bootstrap.config() + .group() + .next() + .newFailedFuture(new IllegalStateException(CLOSED_POOL_ERROR_MESSAGE)); + } + + poolLock.lock(); + try { + // First, check for an available idle connection. + PooledConnection connection = pollIdleAndCheckHealth(); + if (connection != null) { + // Found a valid, idle connection. Return it immediately. + return connection.channel.eventLoop().newSucceededFuture(connection.channel); + } + + // No idle connections. Check if we can create a new one. + if (totalConnections.get() < maxConnectionsPerRoute) { + // Increment count and create a new connection outside the lock. + totalConnections.getAndIncrement(); + return createNewConnection(); + } + + // The Pool is full. Queue the acquisition request. + if (pendingAcquirers.size() >= maxPendingAcquires) { + return bootstrap.config() + .group() + .next() + .newFailedFuture(CoreException.from("Pending acquisition queue is full.")); + } + + Promise promise = bootstrap.config().group().next().newPromise(); + + if (pendingAcquireTimeout != null && pendingAcquireTimeout.toMillis() > 0) { + final ScheduledFuture timeoutFuture = bootstrap.config().group().schedule(() -> { + if (promise.tryFailure( + CoreException.from("Connection acquisition timed out after " + pendingAcquireTimeout))) { + poolLock.lock(); + try { + pendingAcquirers.remove(promise); + } finally { + poolLock.unlock(); + } + } + }, pendingAcquireTimeout.toMillis(), TimeUnit.MILLISECONDS); + + promise.addListener(f -> { + if (f.isDone()) { + timeoutFuture.cancel(false); + } + }); + } + + promise.addListener(future -> { + if (future.isCancelled()) { + poolLock.lock(); + try { + pendingAcquirers.remove(promise); + } finally { + poolLock.unlock(); + } + } + }); + + pendingAcquirers.offer(promise); + return promise; + } finally { + poolLock.unlock(); + } + } + + /** + * Releases a connection back to this route's pool. + *

+ * First offers the connection to any pending acquirers before adding it to the idle queue. + * + * @param connection The connection to release. + */ + private void release(PooledConnection connection) { + poolLock.lock(); + try { + if (!isHealthy(connection)) { + // The connection is unhealthy. Close it. + // The asynchronous closeFuture listener ('handleConnectionClosure') will eventually + // decrement the connection count and create a new connections for available waiters. + connection.close(); + return; + } + + // The connection is healthy. Offer it to the waiters. + while (!pendingAcquirers.isEmpty()) { + Promise waiter = pollNextWaiter(); + if (waiter == null) { + // All remaining waiters were canceled. + break; + } + + if (waiter.trySuccess(connection.channel)) { + // The waiter accepted the connection, so we are done. + return; + } + // If the waiter didn't accept it (e.g., timed out), the loop will + // offer the connection to the next waiter. + } + + // There are no pending waiters, so add the healthy connection to the idle queue. + connection.idleSince = OffsetDateTime.now(ZoneOffset.UTC); + idleConnections.offer(connection); + } finally { + poolLock.unlock(); + } + } + + private void handleConnectionClosure() { + poolLock.lock(); + try { + totalConnections.getAndDecrement(); + + // A slot has opened up. Loop and create new connections + // for as long as there are waiters, and we have capacity. + while (totalConnections.get() < maxConnectionsPerRoute && !pendingAcquirers.isEmpty()) { + Promise waiter = pollNextWaiter(); + if (waiter == null) { + break; + } + + totalConnections.getAndIncrement(); + createNewConnection().addListener(future -> { + if (future.isSuccess()) { + // Try to give the new channel to the waiter. + // If it fails (e.g., the waiter timed out in the meantime), + // release the brand-new channel back to the pool. + if (!waiter.trySuccess((Channel) future.getNow())) { + release(((Channel) future.getNow()).attr(POOLED_CONNECTION_KEY).get()); + } + } else { + // The connection failed, so notify the waiter. + // The connection's own close handler will decrement totalConnections again. + waiter.tryFailure(future.cause()); + } + }); + } + } finally { + poolLock.unlock(); + } + } + + private Promise pollNextWaiter() { + while (!pendingAcquirers.isEmpty()) { + Promise waiter = pendingAcquirers.poll(); + if (!waiter.isCancelled()) { + return waiter; + } + } + return null; + } + + private PooledConnection pollIdleAndCheckHealth() { + while (!idleConnections.isEmpty()) { + PooledConnection connection = idleConnections.poll(); + if (isHealthy(connection)) { + connection.idleSince = null; // Mark as active + return connection; + } + connection.close(); + } + return null; + } + + /** + * Creates a new channel and asynchronously orchestrates its full setup. + *

+ * This method configures a {@link ChannelInitializer} to set up the base pipeline with handlers for health, + * proxying, and SSL. The readiness of the channel is signaled by a {@link Promise}, which is only fulfilled + * after all asynchronous setup stages (TCP connect, proxy handshake, and SSL/ALPN negotiation) are complete. + * + * @return A {@link Future} that will complete with a new, fully configured channel. + */ + private Future createNewConnection() { + Bootstrap newConnectionBootstrap = bootstrap.clone(); + Promise promise = newConnectionBootstrap.config().group().next().newPromise(); + newConnectionBootstrap.handler(new ChannelInitializer() { + @Override + public void initChannel(Channel channel) throws SSLException { + channel.attr(CHANNEL_LOCK).set(new ReentrantLock()); + + // Create the connection wrapper and attach it to the channel. + new PooledConnection(channel, key); + + ChannelPipeline pipeline = channel.pipeline(); + pipeline.addLast(POOL_CONNECTION_HEALTH, PoolConnectionHealthHandler.INSTANCE); + + // Test whether proxying should be applied to this Channel. If so, add it. + // Proxy detection MUST use the final destination address from the key. + boolean hasProxy = channelInitializationProxyHandler.test(key.getFinalDestination()); + if (hasProxy) { + ProxyHandler proxyHandler = channelInitializationProxyHandler.createProxy(proxyChallenges); + pipeline.addFirst(PROXY, proxyHandler); + pipeline.addAfter(PROXY, PROXY_EXCEPTION_WARNING_SUPPRESSION, + SuppressProxyConnectExceptionWarningHandler.INSTANCE); + } + + // Add SSL handling if the request is HTTPS. + if (isHttps) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) key.getFinalDestination(); + SslContext ssl = buildSslContext(maximumHttpVersion, sslContextModifier); + pipeline.addLast(SSL, ssl.newHandler(channel.alloc(), inetSocketAddress.getHostString(), + inetSocketAddress.getPort())); + pipeline.addAfter(SSL, SSL_GRACEFUL_SHUTDOWN, new SslGracefulShutdownHandler()); + pipeline.addLast(CONNECTION_POOL_ALPN, new ConnectionPoolAlpnHandler(promise)); + } else { + channel.attr(Netty4AlpnHandler.HTTP_PROTOCOL_VERSION_KEY).set(HttpProtocolVersion.HTTP_1_1); + } + } + }); + + newConnectionBootstrap.connect(route).addListener(future -> { + if (!future.isSuccess()) { + LOGGER.atError().setThrowable(future.cause()).log("Failed to connect to the route."); + handleConnectionClosure(); + promise.setFailure(future.cause()); + return; + } + + Channel newChannel = ((ChannelFuture) future).channel(); + newChannel.closeFuture().addListener(closeFuture -> handleConnectionClosure()); + + Runnable connectionSuccessRunner = () -> { + if (!isHttps) { + promise.trySuccess(newChannel); + } + // If it IS https, we do nothing. The ConnectionPoolAlpnHandler is in charge. + }; + + ProxyHandler proxyHandler = (ProxyHandler) newChannel.pipeline().get(PROXY); + if (proxyHandler != null) { + proxyHandler.connectFuture().addListener(proxyFuture -> { + if (proxyFuture.isSuccess()) { + connectionSuccessRunner.run(); + } else { + promise.tryFailure(proxyFuture.cause()); + newChannel.close(); + } + }); + } else { + connectionSuccessRunner.run(); + } + }); + return promise; + } + + private boolean isHealthy(PooledConnection connection) { + Channel channel = connection.channel; + + if (Boolean.TRUE.equals(channel.attr(PoolConnectionHealthHandler.CONNECTION_INVALIDATED).get())) { + return false; + } + + if (!connection.isActiveAndWriteable() || channel.config().isAutoRead()) { + return false; + } + + OffsetDateTime now = null; // To be initialized only if needed. + + if (maxLifetimeNanos > 0) { + now = OffsetDateTime.now(ZoneOffset.UTC); + if (Duration.between(connection.creationTime, now).toNanos() >= maxLifetimeNanos) { + return false; + } + } + + if (connection.idleSince != null && idleTimeoutNanos > 0) { + if (now == null) { + now = OffsetDateTime.now(ZoneOffset.UTC); + } + if (Duration.between(connection.idleSince, now).toNanos() >= idleTimeoutNanos) { + return false; + } + } + + HttpProtocolVersion protocol = channel.attr(Netty4AlpnHandler.HTTP_PROTOCOL_VERSION_KEY).get(); + if (protocol == HttpProtocolVersion.HTTP_2) { + return !Boolean.TRUE.equals(channel.attr(HTTP2_GOAWAY_RECEIVED).get()); + } + + return true; + } + + private void cleanup() { + if (idleConnections.isEmpty()) { + return; + } + + OffsetDateTime now = OffsetDateTime.now(ZoneOffset.UTC); + for (Iterator it = idleConnections.iterator(); it.hasNext();) { + PooledConnection connection = it.next(); + if (connection.idleSince != null + && Duration.between(connection.idleSince, now).toNanos() >= idleTimeoutNanos) { + it.remove(); + connection.close(); + } + } + } + + private void close() { + PooledConnection connection; + while ((connection = idleConnections.poll()) != null) { + connection.close(); + } + Promise waiter; + while ((waiter = pendingAcquirers.poll()) != null) { + waiter.tryFailure(new IOException(CLOSED_POOL_ERROR_MESSAGE)); + } + } + } +} diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPoolKey.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPoolKey.java new file mode 100644 index 000000000000..1280ebcd7910 --- /dev/null +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPoolKey.java @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4.implementation; + +import java.net.SocketAddress; +import java.util.Objects; + +/** + * A composite key for the connection pool. + *

+ * For direct connections, connectionTarget and finalDestination are the same. + * For proxied connections, connectionTarget is the proxy's address. For plain HTTP through a proxy, + * finalDestination is also the proxy's address to allow connection reuse. For HTTPS through a proxy, + * finalDestination is the target server's address to create a dedicated pool for the tunnel. + */ +public final class Netty4ConnectionPoolKey { + private final SocketAddress connectionTarget; + private final SocketAddress finalDestination; + + public Netty4ConnectionPoolKey(SocketAddress connectionTarget, SocketAddress finalDestination) { + this.connectionTarget = connectionTarget; + this.finalDestination = finalDestination; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Netty4ConnectionPoolKey poolKey = (Netty4ConnectionPoolKey) o; + return Objects.equals(connectionTarget, poolKey.connectionTarget) + && Objects.equals(finalDestination, poolKey.finalDestination); + } + + @Override + public int hashCode() { + return Objects.hash(connectionTarget, finalDestination); + } + + public SocketAddress getConnectionTarget() { + return this.connectionTarget; + } + + public SocketAddress getFinalDestination() { + return this.finalDestination; + } +} diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4EagerConsumeChannelHandler.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4EagerConsumeChannelHandler.java index 74cf9911a145..061e66b65dee 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4EagerConsumeChannelHandler.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4EagerConsumeChannelHandler.java @@ -13,6 +13,7 @@ import io.netty.util.ReferenceCountUtil; import java.io.IOException; +import java.nio.channels.ClosedChannelException; import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; @@ -23,6 +24,7 @@ public final class Netty4EagerConsumeChannelHandler extends ChannelInboundHandlerAdapter { private final CountDownLatch latch; private final IOExceptionCheckedConsumer byteBufConsumer; + private final Runnable onComplete; private final boolean isHttp2; private boolean lastRead; @@ -40,49 +42,64 @@ public Netty4EagerConsumeChannelHandler(CountDownLatch latch, IOExceptionChecked this.latch = latch; this.byteBufConsumer = byteBufConsumer; this.isHttp2 = isHttp2; + this.onComplete = null; + } + + /** + * Creates a new instance of {@link Netty4EagerConsumeChannelHandler} for non-blocking drain operations. + * + * @param onComplete The callback to run when the stream is fully drained or an error occurs. + * @param isHttp2 Flag indicating whether the handler is used for HTTP/2 or not. + */ + public Netty4EagerConsumeChannelHandler(Runnable onComplete, boolean isHttp2) { + this.latch = null; + this.byteBufConsumer = buf -> { + }; + this.onComplete = onComplete; + this.isHttp2 = isHttp2; } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { - ByteBuf buf = null; - if (msg instanceof ByteBufHolder) { - buf = ((ByteBufHolder) msg).content(); - } else if (msg instanceof ByteBuf) { - buf = (ByteBuf) msg; - } + try { + if (byteBufConsumer != null) { + ByteBuf buf = null; + + if (msg instanceof ByteBufHolder) { + buf = ((ByteBufHolder) msg).content(); + } else if (msg instanceof ByteBuf) { + buf = (ByteBuf) msg; + } - if (buf != null && buf.isReadable()) { - try { - byteBufConsumer.accept(buf); - } catch (IOException | RuntimeException ex) { - ReferenceCountUtil.release(buf); - ctx.close(); - return; + if (buf != null && buf.isReadable()) { + byteBufConsumer.accept(buf); + } } - } - if (isHttp2) { - lastRead = msg instanceof Http2DataFrame && ((Http2DataFrame) msg).isEndStream(); - } else { - lastRead = msg instanceof LastHttpContent; + if (isHttp2) { + lastRead = msg instanceof Http2DataFrame && ((Http2DataFrame) msg).isEndStream(); + } else { + lastRead = msg instanceof LastHttpContent; + } + } catch (IOException | RuntimeException ex) { + exceptionCaught(ctx, ex); + } finally { + ReferenceCountUtil.release(msg); } - ctx.fireChannelRead(msg); } @Override public void channelReadComplete(ChannelHandlerContext ctx) { ctx.fireChannelReadComplete(); if (lastRead) { - latch.countDown(); - ctx.close(); + signalComplete(ctx); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { this.exception = cause; - latch.countDown(); - ctx.close(); + signalComplete(ctx); } Throwable channelException() { @@ -92,13 +109,39 @@ Throwable channelException() { // TODO (alzimmer): Are the latch countdowns needed for unregistering and inactivity? @Override public void channelUnregistered(ChannelHandlerContext ctx) { - latch.countDown(); + signalComplete(ctx); ctx.fireChannelUnregistered(); } @Override public void channelInactive(ChannelHandlerContext ctx) { - latch.countDown(); + signalComplete(ctx); ctx.fireChannelInactive(); } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + if (!ctx.channel().isActive()) { + // In case the read handler is added to a closed channel, we fail loudly by firing + // an exception. Simply counting down the latch would cause the caller to receive + // an empty/incomplete data stream without any sign of the underlying network error. + ctx.fireExceptionCaught(new ClosedChannelException()); + } + } + + private void signalComplete(ChannelHandlerContext ctx) { + if (ctx.pipeline().get(Netty4EagerConsumeChannelHandler.class) != null) { + ctx.pipeline().remove(this); + } + + // If in sync mode (for toBytes()), just signal completion. + if (latch != null) { + latch.countDown(); + } + + // If in async mode (for close()), run the cleanup callback. + if (onComplete != null) { + onComplete.run(); + } + } } diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4HandlerNames.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4HandlerNames.java index eee40c7c4fe7..94983260bc9d 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4HandlerNames.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4HandlerNames.java @@ -39,6 +39,11 @@ public final class Netty4HandlerNames { */ public static final String ALPN = "clientcore.alpn"; + /** + * Name for the {@link Netty4ConnectionPool.ConnectionPoolAlpnHandler}. + */ + public static final String CONNECTION_POOL_ALPN = "clientcore.connectionpoolalpn"; + /** * Name for the HTTP/1.1 {@link HttpClientCodec} */ @@ -69,6 +74,31 @@ public final class Netty4HandlerNames { */ public static final String READ_ONE = "clientcore.readone"; + /** + * Name for the {@link Netty4PipelineCleanupHandler} + */ + public static final String PIPELINE_CLEANUP = "clientcore.pipelinecleanup"; + + /** + * Name for the {@link Netty4ConnectionPool.Http2GoAwayHandler} + */ + public static final String HTTP2_GOAWAY = "clientcore.http2goaway"; + + /** + * Name for the {@link Netty4ConnectionPool.SslGracefulShutdownHandler} + */ + public static final String SSL_GRACEFUL_SHUTDOWN = "clientcore.sslgracefulshutdown"; + + /** + * Name for the {@link Netty4ConnectionPool.PoolConnectionHealthHandler} + */ + public static final String POOL_CONNECTION_HEALTH = "clientcore.poolconnectionhealth"; + + /** + * Name for the {@link Netty4ConnectionPool.SuppressProxyConnectExceptionWarningHandler} + */ + public static final String PROXY_EXCEPTION_WARNING_SUPPRESSION = "clientcore.suppressproxyexceptionwarning"; + private Netty4HandlerNames() { } } diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4HttpProxyHandler.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4HttpProxyHandler.java index 5b8883b98261..9907983c8bb6 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4HttpProxyHandler.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4HttpProxyHandler.java @@ -101,7 +101,7 @@ public String authScheme() { @Override protected void addCodec(ChannelHandlerContext ctx) { - // TODO (alzimmer): Need to support HTTP/2 proxying. Check if Netty itself even supports this. + // TODO (alzimmer): Need to support HTTP/2 proxying. Check (issue 12088) if Netty itself even supports this. ctx.pipeline().addBefore(ctx.name(), Netty4HandlerNames.PROXY_CODEC, this.wrapper); } diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4InitiateOneReadHandler.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4InitiateOneReadHandler.java index 6a242ed03246..031b8ea459ab 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4InitiateOneReadHandler.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4InitiateOneReadHandler.java @@ -13,6 +13,7 @@ import io.netty.util.ReferenceCountUtil; import java.io.IOException; +import java.nio.channels.ClosedChannelException; import java.util.concurrent.CountDownLatch; /** @@ -78,7 +79,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { byteBufConsumer.accept(buf); } catch (IOException | RuntimeException ex) { ReferenceCountUtil.release(buf); - ctx.close(); + exceptionCaught(ctx, ex); return; } } @@ -95,9 +96,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { public void channelReadComplete(ChannelHandlerContext ctx) { latch.countDown(); if (lastRead) { - ctx.pipeline().remove(this); - ctx.close(); + if (ctx.pipeline().get(Netty4InitiateOneReadHandler.class) != null) { + ctx.pipeline().remove(this); + } } + ctx.fireChannelReadComplete(); } boolean isChannelConsumed() { @@ -108,7 +111,7 @@ boolean isChannelConsumed() { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { this.exception = cause; latch.countDown(); - ctx.close(); + ctx.fireExceptionCaught(cause); } Throwable channelException() { @@ -118,15 +121,31 @@ Throwable channelException() { // TODO (alzimmer): Are the latch countdowns needed for unregistering and inactivity? @Override public void channelUnregistered(ChannelHandlerContext ctx) { - latch.countDown(); + signalComplete(ctx); ctx.fireChannelUnregistered(); - ctx.pipeline().remove(this); } @Override public void channelInactive(ChannelHandlerContext ctx) { - latch.countDown(); + signalComplete(ctx); ctx.fireChannelInactive(); - ctx.pipeline().remove(this); } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + if (!ctx.channel().isActive()) { + // In case the read handler is added to a closed channel, we fail loudly by firing + // an exception. Simply counting down the latch would cause the caller to receive + // an empty/incomplete data stream without any sign of the underlying network error. + ctx.fireExceptionCaught(new ClosedChannelException()); + } + } + + private void signalComplete(ChannelHandlerContext ctx) { + latch.countDown(); + if (ctx.pipeline().get(Netty4InitiateOneReadHandler.class) != null) { + ctx.pipeline().remove(this); + } + } + } diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupEvent.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupEvent.java new file mode 100644 index 000000000000..3fc5e1806d9e --- /dev/null +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupEvent.java @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4.implementation; + +public enum Netty4PipelineCleanupEvent { + + /** + * Event used to indicate that the Netty channel will be released back to the connection pool. + */ + CLEANUP_PIPELINE +} diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupHandler.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupHandler.java new file mode 100644 index 000000000000..ece4bd9e77c0 --- /dev/null +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupHandler.java @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4.implementation; + +import io.clientcore.core.http.client.HttpProtocolVersion; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; + +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.ALPN; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.CHUNKED_WRITER; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.EAGER_CONSUME; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.HTTP_CODEC; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.HTTP_RESPONSE; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.PROGRESS_AND_TIMEOUT; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.READ_ONE; +import static io.clientcore.http.netty4.implementation.Netty4Utility.setOrSuppressError; + +/** + * A handler that cleans up the pipeline after a request-response cycle and releases + * the channel back to the connection pool. + */ +public class Netty4PipelineCleanupHandler extends ChannelDuplexHandler { + + private final Netty4ConnectionPool connectionPool; + private final AtomicReference errorReference; + private final AtomicBoolean cleanedUp = new AtomicBoolean(false); + private final Object pipelineOwnerToken; + + private static final List HANDLERS_TO_REMOVE; + + static { + List handlers = new ArrayList<>(); + handlers.add(PROGRESS_AND_TIMEOUT); + handlers.add(HTTP_RESPONSE); + handlers.add(HTTP_CODEC); + handlers.add(ALPN); + handlers.add(CHUNKED_WRITER); + handlers.add(EAGER_CONSUME); + handlers.add(READ_ONE); + HANDLERS_TO_REMOVE = Collections.unmodifiableList(handlers); + } + + public Netty4PipelineCleanupHandler(Netty4ConnectionPool connectionPool, AtomicReference errorReference, + Object pipelineOwnerToken) { + this.connectionPool = connectionPool; + this.errorReference = errorReference; + this.pipelineOwnerToken = pipelineOwnerToken; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + setOrSuppressError(errorReference, cause); + cleanup(ctx, true); + ctx.fireExceptionCaught(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + cleanup(ctx, true); + ctx.fireChannelInactive(); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof Netty4PipelineCleanupEvent) { + cleanup(ctx, false); + return; + } + ctx.fireUserEventTriggered(evt); + } + + public void cleanup(ChannelHandlerContext ctx, boolean closeChannel) { + // Check if this handler is still the rightful owner of the pipeline. + // If the tokens don't match, a new request has taken over this channel, + // so this stale cleanup handler must not do anything. + if (ctx.channel().attr(Netty4ConnectionPool.PIPELINE_OWNER_TOKEN).get() != pipelineOwnerToken) { + return; + } + + if (!cleanedUp.compareAndSet(false, true)) { + return; + } + + ReentrantLock lock = ctx.channel().attr(Netty4ConnectionPool.CHANNEL_LOCK).get(); + lock.lock(); + + try { + // Always reset autoRead to false before returning a channel to the pool + // to ensure predictable behavior for the next request. + ctx.channel().config().setAutoRead(false); + + ChannelPipeline pipeline = ctx.channel().pipeline(); + + HttpProtocolVersion protocolVersion = ctx.channel().attr(Netty4AlpnHandler.HTTP_PROTOCOL_VERSION_KEY).get(); + boolean isHttp2 = protocolVersion == HttpProtocolVersion.HTTP_2; + + for (String handlerName : HANDLERS_TO_REMOVE) { + if (isHttp2 && HTTP_CODEC.equals(handlerName)) { + continue; + } + + if (pipeline.get(handlerName) != null) { + pipeline.remove(handlerName); + } + } + + if (pipeline.get(Netty4PipelineCleanupHandler.class) != null) { + pipeline.remove(this); + } + } finally { + lock.unlock(); + } + + if (closeChannel || !ctx.channel().isActive() || connectionPool == null) { + ctx.channel().close(); + } else { + connectionPool.release(ctx.channel()); + } + } +} diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ProgressAndTimeoutHandler.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ProgressAndTimeoutHandler.java index 90e07fea13d6..3df4ba10e7b0 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ProgressAndTimeoutHandler.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ProgressAndTimeoutHandler.java @@ -158,16 +158,17 @@ void writeTimeoutRunnable(ChannelHandlerContext ctx, boolean trackingWriteTimeou // No progress has been made since the last timeout event, channel has timed out. if (!closed) { disposeWriteTimeoutWatcher(); + // Fire the exception up the pipeline. The PipelineCleanupHandler will catch this + // and release the channel. We do not close the channel here. ctx.fireExceptionCaught(new TimeoutException( "Channel write operation timed out after " + writeTimeoutMillis + " milliseconds.")); - ctx.close(); closed = true; } } private void disposeWriteTimeoutWatcher() { trackingWriteTimeout = false; - if (writeTimeoutWatcher != null && !writeTimeoutWatcher.isDone()) { + if (writeTimeoutWatcher != null) { writeTimeoutWatcher.cancel(false); writeTimeoutWatcher = null; } @@ -204,14 +205,13 @@ void responseTimedOut(ChannelHandlerContext ctx, boolean trackingResponseTimeout disposeResponseTimeoutWatcher(); ctx.fireExceptionCaught( new TimeoutException("Channel response timed out after " + responseTimeoutMillis + " milliseconds.")); - ctx.close(); closed = true; } } private void disposeResponseTimeoutWatcher() { trackingResponseTimeout = false; - if (responseTimeoutWatcher != null && !responseTimeoutWatcher.isDone()) { + if (responseTimeoutWatcher != null) { responseTimeoutWatcher.cancel(false); responseTimeoutWatcher = null; } @@ -277,14 +277,13 @@ void readTimeoutRunnable(ChannelHandlerContext ctx, boolean trackingReadTimeout) disposeReadTimeoutWatcher(); ctx.fireExceptionCaught( new TimeoutException("Channel read timed out after " + readTimeoutMillis + " milliseconds.")); - ctx.close(); closed = true; } } private void disposeReadTimeoutWatcher() { trackingReadTimeout = false; - if (readTimeoutWatcher != null && !readTimeoutWatcher.isDone()) { + if (readTimeoutWatcher != null) { readTimeoutWatcher.cancel(false); readTimeoutWatcher = null; } diff --git a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ResponseHandler.java b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ResponseHandler.java index 5da09d7c8924..776b5b778642 100644 --- a/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ResponseHandler.java +++ b/sdk/clientcore/http-netty4/src/main/java/io/clientcore/http/netty4/implementation/Netty4ResponseHandler.java @@ -2,10 +2,10 @@ // Licensed under the MIT License. package io.clientcore.http.netty4.implementation; +import io.clientcore.core.http.client.HttpProtocolVersion; import io.clientcore.core.http.models.HttpHeaders; import io.clientcore.core.http.models.HttpRequest; import io.clientcore.core.http.models.Response; -import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -15,8 +15,10 @@ import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.LastHttpContent; +import io.netty.util.ReferenceCountUtil; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; @@ -47,6 +49,7 @@ public final class Netty4ResponseHandler extends ChannelInboundHandlerAdapter { // and initial response body content. private final ByteArrayOutputStream eagerContent = new ByteArrayOutputStream(); private boolean complete; + private boolean isHttp2; /** * Creates an instance of {@link Netty4ResponseHandler}. @@ -71,6 +74,12 @@ public Netty4ResponseHandler(HttpRequest request, AtomicReference OPTIONAL_NETTY_VERSION_ARTIFACTS = Arrays .asList("netty-transport-native-unix-common", "netty-transport-native-epoll", "netty-transport-native-kqueue"); + private static final int TWO_FIFTY_SIX_KB = 256 * 1024; + /** * Converts Netty HttpHeaders to ClientCore HttpHeaders. *

@@ -112,6 +135,10 @@ public static void awaitLatch(CountDownLatch latch) { *

* Content will only be written to the {@link OutputStream} if the {@link ByteBuf} is non-null and is * {@link ByteBuf#isReadable()}. The entire {@link ByteBuf} will be consumed. + *

+ *

Warning: This is a helper method and does NOT release the {@link ByteBuf} + * after it is consumed, and it must be manually released to avoid memory leaks (either the {@link ByteBuf} + * or the container holding the {@link ByteBuf}). * * @param byteBuf The Netty {@link ByteBuf} to read from. * @param stream The {@link OutputStream} to write to. @@ -125,10 +152,6 @@ static void readByteBufIntoOutputStream(ByteBuf byteBuf, OutputStream stream) th } byteBuf.readBytes(stream, byteBuf.readableBytes()); - if (byteBuf.refCnt() > 0) { - // Release the ByteBuf as we've consumed it. - byteBuf.release(); - } } /** @@ -438,6 +461,143 @@ public static HttpHeaderName fromPossibleAsciiString(CharSequence asciiString) { } } + /** + * Configures the pipeline for either HTTP/1.1 or HTTP/2 based on the negotiated protocol. + *

+ * This method adds the appropriate {@link Netty4HandlerNames#HTTP_CODEC} and + * {@link Netty4HandlerNames#HTTP_RESPONSE} handlers to the pipeline, positioned correctly + * relatively to the {@link Netty4HandlerNames#PROGRESS_AND_TIMEOUT} or {@link Netty4HandlerNames#SSL} handlers. + * + * @param pipeline The channel pipeline to configure. + * @param request The HTTP request. + * @param protocol The negotiated HTTP protocol version. + * @param responseReference The atomic reference to hold the response state. + * @param errorReference The atomic reference to hold any errors. + * @param latch The countdown latch to signal completion. + */ + public static void configureHttpsPipeline(ChannelPipeline pipeline, HttpRequest request, + HttpProtocolVersion protocol, AtomicReference responseReference, + AtomicReference errorReference, CountDownLatch latch) { + final ChannelHandler httpCodec; + if (HttpProtocolVersion.HTTP_2 == protocol) { + httpCodec = createHttp2Codec(); + } else { // HTTP/1.1 + httpCodec = createCodec(); + } + + Netty4ResponseHandler responseHandler + = new Netty4ResponseHandler(request, responseReference, errorReference, latch); + + if (pipeline.get(Netty4HandlerNames.PROGRESS_AND_TIMEOUT) != null) { + pipeline.addAfter(Netty4HandlerNames.PROGRESS_AND_TIMEOUT, Netty4HandlerNames.HTTP_RESPONSE, + responseHandler); + pipeline.addBefore(Netty4HandlerNames.PROGRESS_AND_TIMEOUT, Netty4HandlerNames.HTTP_CODEC, httpCodec); + } else { + pipeline.addAfter(Netty4HandlerNames.SSL, Netty4HandlerNames.HTTP_CODEC, httpCodec); + pipeline.addAfter(Netty4HandlerNames.HTTP_CODEC, Netty4HandlerNames.HTTP_RESPONSE, responseHandler); + } + } + + public static ChannelHandler createHttp2Codec() { + // TODO (alzimmer): InboundHttp2ToHttpAdapter buffers the entire response into a FullHttpResponse. Need to + // create a streaming version of this to support huge response payloads. + Http2Connection http2Connection = new DefaultHttp2Connection(false); + Http2Settings settings = new Http2Settings().headerTableSize(4096) + .maxHeaderListSize(TWO_FIFTY_SIX_KB) + .pushEnabled(false) + .initialWindowSize(TWO_FIFTY_SIX_KB); + Http2FrameListener frameListener = new DelegatingDecompressorFrameListener(http2Connection, + new InboundHttp2ToHttpAdapterBuilder(http2Connection).maxContentLength(Integer.MAX_VALUE) + .propagateSettings(true) + .validateHttpHeaders(true) + .build(), + 0); + + return new HttpToHttp2ConnectionHandlerBuilder().initialSettings(settings) + .frameListener(frameListener) + .connection(http2Connection) + .validateHeaders(true) + .build(); + } + + public static void sendHttp2Request(HttpRequest request, Channel channel, AtomicReference errorReference, + CountDownLatch latch) { + io.netty.handler.codec.http.HttpRequest nettyRequest = toNettyHttpRequest(request); + + final ChannelFuture writeFuture; + + if (nettyRequest instanceof FullHttpRequest) { + writeFuture = channel.writeAndFlush(nettyRequest); + } else { + channel.write(nettyRequest); + + BinaryData requestBody = request.getBody(); + ChunkedInput chunkedInput = new HttpChunkedInput(new ChunkedStream(requestBody.toStream())); + + writeFuture = channel.writeAndFlush(chunkedInput); + } + + writeFuture.addListener(future -> { + if (future.isSuccess()) { + channel.read(); + } else { + setOrSuppressError(errorReference, future.cause()); + latch.countDown(); + } + }); + } + + private static io.netty.handler.codec.http.HttpRequest toNettyHttpRequest(HttpRequest request) { + HttpMethod nettyMethod = HttpMethod.valueOf(request.getHttpMethod().toString()); + String uri = request.getUri().toString(); + WrappedHttp11Headers nettyHeaders = new WrappedHttp11Headers(request.getHeaders()); + nettyHeaders.getCoreHeaders().set(HttpHeaderName.HOST, request.getUri().getHost()); + + BinaryData body = request.getBody(); + if (body == null || body.getLength() == 0 || body.isReplayable()) { + ByteBuf bodyBytes = (body == null || body.getLength() == 0) + ? Unpooled.EMPTY_BUFFER + : Unpooled.wrappedBuffer(body.toBytes()); + + nettyHeaders.getCoreHeaders().set(HttpHeaderName.CONTENT_LENGTH, String.valueOf(bodyBytes.readableBytes())); + return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, nettyMethod, uri, bodyBytes, nettyHeaders, + trailersFactory().newHeaders()); + } else { + return new DefaultHttpRequest(HttpVersion.HTTP_1_1, nettyMethod, uri, nettyHeaders); + } + } + + public static SslContext buildSslContext(HttpProtocolVersion maximumHttpVersion, + Consumer sslContextModifier) throws SSLException { + SslContextBuilder sslContextBuilder = SslContextBuilder.forClient().endpointIdentificationAlgorithm("HTTPS"); + if (maximumHttpVersion == HttpProtocolVersion.HTTP_2) { + // If HTTP/2 is the maximum version, we need to ensure that ALPN is enabled. + SslProvider sslProvider = SslContext.defaultClientProvider(); + ApplicationProtocolConfig.SelectorFailureBehavior selectorBehavior; + ApplicationProtocolConfig.SelectedListenerFailureBehavior selectedBehavior; + if (sslProvider == SslProvider.JDK) { + selectorBehavior = ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT; + selectedBehavior = ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT; + } else { + // Netty OpenSslContext doesn't support FATAL_ALERT, use NO_ADVERTISE and ACCEPT + // instead. + selectorBehavior = ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE; + selectedBehavior = ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT; + } + + sslContextBuilder.ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE) + .applicationProtocolConfig( + new ApplicationProtocolConfig(ApplicationProtocolConfig.Protocol.ALPN, selectorBehavior, + selectedBehavior, ApplicationProtocolNames.HTTP_2, ApplicationProtocolNames.HTTP_1_1)); + } + if (sslContextModifier != null) { + // Allow the caller to modify the SslContextBuilder before it is built. + sslContextModifier.accept(sslContextBuilder); + } + + return sslContextBuilder.build(); + } + private Netty4Utility() { } } diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttp2HttpClientTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttp2HttpClientTests.java index 9fb189ccd741..dba6aa10bbea 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttp2HttpClientTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttp2HttpClientTests.java @@ -31,7 +31,7 @@ public class NettyHttp2HttpClientTests extends HttpClientTests { private static final HttpClient HTTP_CLIENT_INSTANCE; static { - HTTP_CLIENT_INSTANCE = new NettyHttpClientBuilder() + HTTP_CLIENT_INSTANCE = new NettyHttpClientBuilder().connectionPoolSize(0) .sslContextModifier( builder -> builder.trustManager(new InsecureTrustManager()).secureRandom(new SecureRandom())) .maximumHttpVersion(HttpProtocolVersion.HTTP_2) @@ -47,6 +47,9 @@ public static void startTestServer() { @AfterAll public static void stopTestServer() { + if (HTTP_CLIENT_INSTANCE instanceof NettyHttpClient) { + ((NettyHttpClient) HTTP_CLIENT_INSTANCE).close(); + } if (server != null) { server.stop(); } diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientBuilderTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientBuilderTests.java index 1c18c9618f5a..a417ef32f253 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientBuilderTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientBuilderTests.java @@ -3,6 +3,7 @@ package io.clientcore.http.netty4; +import io.clientcore.core.http.client.HttpProtocolVersion; import io.clientcore.core.http.models.HttpMethod; import io.clientcore.core.http.models.HttpRequest; import io.clientcore.core.http.models.ProxyOptions; @@ -32,6 +33,7 @@ import java.io.IOException; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; +import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.time.Duration; import java.util.ArrayList; @@ -47,6 +49,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; /** * Tests {@link NettyHttpClientBuilder}. @@ -372,6 +375,35 @@ public void getEventLoopGroupToUse(Class expected, EventLoopGroup configuredG assertInstanceOf(expected, eventLoopGroup); } + @Test + public void buildNettyClientWithoutConnectionPool() throws NoSuchFieldException, IllegalAccessException { + NettyHttpClient client = (NettyHttpClient) new NettyHttpClientBuilder().connectionPoolSize(0).build(); + + Field connectionPoolField = NettyHttpClient.class.getDeclaredField("connectionPool"); + connectionPoolField.setAccessible(true); + assertNull(connectionPoolField.get(client), "Connection pool should be null when pool size is 0."); + } + + @Test + public void testInvalidMaxPendingAcquires() { + NettyHttpClientBuilder builder = new NettyHttpClientBuilder(); + assertThrows(IllegalArgumentException.class, () -> builder.maxPendingAcquires(0)); + assertThrows(IllegalArgumentException.class, () -> builder.maxPendingAcquires(-1)); + } + + @Test + public void testMaximumHttpVersion() throws NoSuchFieldException, IllegalAccessException { + NettyHttpClientBuilder builder = new NettyHttpClientBuilder(); + + NettyHttpClient clientv1 = (NettyHttpClient) builder.maximumHttpVersion(HttpProtocolVersion.HTTP_1_1).build(); + Field httpVersionField = NettyHttpClient.class.getDeclaredField("maximumHttpVersion"); + httpVersionField.setAccessible(true); + assertEquals(HttpProtocolVersion.HTTP_1_1, httpVersionField.get(clientv1)); + + NettyHttpClient clientv2 = (NettyHttpClient) builder.maximumHttpVersion(null).build(); + assertEquals(HttpProtocolVersion.HTTP_2, httpVersionField.get(clientv2)); + } + private static Stream getEventLoopGroupToUseSupplier() throws ReflectiveOperationException { // Doesn't matter what this is calling, just needs to throw an exception. // This will as it doesn't accept the arguments that it will be called with. diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientHttpClientTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientHttpClientTests.java index e31381e1edaa..b9380d3b7f7b 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientHttpClientTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientHttpClientTests.java @@ -15,11 +15,16 @@ import java.util.concurrent.TimeUnit; /** - * Reactor Netty {@link HttpClientTests}. + * Netty {@link HttpClientTests}. */ @Timeout(value = 3, unit = TimeUnit.MINUTES) public class NettyHttpClientHttpClientTests extends HttpClientTests { private static LocalTestServer server; + private static final HttpClient HTTP_CLIENT_INSTANCE; + + static { + HTTP_CLIENT_INSTANCE = new NettyHttpClientBuilder().connectionPoolSize(0).build(); + } @BeforeAll public static void startTestServer() { @@ -29,6 +34,9 @@ public static void startTestServer() { @AfterAll public static void stopTestServer() { + if (HTTP_CLIENT_INSTANCE instanceof NettyHttpClient) { + ((NettyHttpClient) HTTP_CLIENT_INSTANCE).close(); + } if (server != null) { server.stop(); } @@ -47,6 +55,6 @@ protected String getServerUri(boolean secure) { @Override protected HttpClient getHttpClient() { - return new NettyHttpClientBuilder().build(); + return HTTP_CLIENT_INSTANCE; } } diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientHttpClientWithHttpsTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientHttpClientWithHttpsTests.java index 38a00845ce1a..5904ada0299d 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientHttpClientWithHttpsTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientHttpClientWithHttpsTests.java @@ -17,7 +17,7 @@ import java.util.concurrent.TimeUnit; /** - * Reactor Netty {@link HttpClientTests} with https. + * Netty {@link HttpClientTests} with https. * Some request logic branches out if it's https like file uploads. */ @Timeout(value = 3, unit = TimeUnit.MINUTES) @@ -28,6 +28,7 @@ public class NettyHttpClientHttpClientWithHttpsTests extends HttpClientTests { static { HTTP_CLIENT_INSTANCE = new NettyHttpClientBuilder() //.maximumHttpVersion(HttpProtocolVersion.HTTP_1_1) + .connectionPoolSize(0) .sslContextModifier(ssl -> ssl.trustManager(new InsecureTrustManager()).secureRandom(new SecureRandom())) .build(); } @@ -40,6 +41,9 @@ public static void startTestServer() { @AfterAll public static void stopTestServer() { + if (HTTP_CLIENT_INSTANCE instanceof NettyHttpClient) { + ((NettyHttpClient) HTTP_CLIENT_INSTANCE).close(); + } if (server != null) { server.stop(); } diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientTests.java index 4ba9621b3c73..0798dbd90f74 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientTests.java @@ -4,6 +4,7 @@ package io.clientcore.http.netty4; import io.clientcore.core.http.client.HttpClient; +import io.clientcore.core.http.client.HttpProtocolVersion; import io.clientcore.core.http.models.HttpHeader; import io.clientcore.core.http.models.HttpHeaderName; import io.clientcore.core.http.models.HttpHeaders; @@ -12,6 +13,7 @@ import io.clientcore.core.http.models.ProxyOptions; import io.clientcore.core.http.models.RequestContext; import io.clientcore.core.http.models.Response; +import io.clientcore.core.http.models.ServerSentEvent; import io.clientcore.core.http.pipeline.HttpPipeline; import io.clientcore.core.http.pipeline.HttpPipelineBuilder; import io.clientcore.core.http.pipeline.HttpPipelinePolicy; @@ -19,12 +21,15 @@ import io.clientcore.core.http.pipeline.HttpRetryPolicy; import io.clientcore.core.models.CoreException; import io.clientcore.core.models.binarydata.BinaryData; +import io.clientcore.core.shared.LocalTestServer; import io.clientcore.core.utils.ProgressReporter; import io.clientcore.http.netty4.implementation.MockProxyServer; import io.clientcore.http.netty4.implementation.Netty4ProgressAndTimeoutHandler; import io.clientcore.http.netty4.implementation.NettyHttpClientLocalTestServer; import io.netty.channel.ChannelPipeline; import io.netty.handler.proxy.ProxyConnectException; +import org.eclipse.jetty.io.EndPoint; +import org.eclipse.jetty.server.HttpConnection; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.RepeatedTest; @@ -34,15 +39,21 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import javax.net.ssl.SSLException; +import javax.servlet.http.HttpServletResponse; import java.io.ByteArrayOutputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.PrintWriter; import java.io.StringWriter; +import java.net.SocketException; import java.net.URI; import java.nio.channels.Channels; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; import java.nio.channels.WritableByteChannel; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.time.Duration; @@ -50,11 +61,13 @@ import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -302,18 +315,24 @@ public void failedProxyAuthenticationReturnsCorrectError() { () -> httpClient.send(new HttpRequest().setMethod(HttpMethod.GET).setUri(uri(PROXY_TO_ADDRESS)))); Throwable exception = coreException.getCause(); - assertInstanceOf(ProxyConnectException.class, exception, () -> { - StringWriter stringWriter = new StringWriter(); - stringWriter.write(exception.toString()); - PrintWriter printWriter = new PrintWriter(stringWriter); - exception.printStackTrace(printWriter); - - return stringWriter.toString(); - }); - - assertTrue(coreException.getCause().getMessage().contains("Proxy Authentication Required"), - () -> "Expected exception message to contain \"Proxy Authentication Required\", it was: " - + coreException.getCause().getMessage()); + assertTrue(exception instanceof ProxyConnectException || exception instanceof ClosedChannelException, + "Exception was not of expected type ProxyConnectException or ClosedChannelException, but was " + + exception.getClass().getName()); + + if (exception instanceof ProxyConnectException) { + assertInstanceOf(ProxyConnectException.class, exception, () -> { + StringWriter stringWriter = new StringWriter(); + stringWriter.write(exception.toString()); + PrintWriter printWriter = new PrintWriter(stringWriter); + exception.printStackTrace(printWriter); + + return stringWriter.toString(); + }); + + assertTrue(coreException.getCause().getMessage().contains("Proxy Authentication Required"), + () -> "Expected exception message to contain \"Proxy Authentication Required\", it was: " + + coreException.getCause().getMessage()); + } } } @@ -356,6 +375,166 @@ public void progressAndTimeoutHandlerNotAdded() throws IOException { } } + @Test + public void sendWithServerSentEvents() throws InterruptedException { + LocalTestServer sseServer = new LocalTestServer(HttpProtocolVersion.HTTP_1_1, false, (req, res, body) -> { + res.setContentType("text/event-stream"); + res.setCharacterEncoding(StandardCharsets.UTF_8.name()); + res.setStatus(HttpServletResponse.SC_OK); + try (PrintWriter writer = res.getWriter()) { + writer.println("id: 1"); + writer.println("event: message"); + writer.println("data: event-1"); + writer.println(); + writer.flush(); + + writer.println("id: 2"); + writer.println("event: message"); + writer.println("data: event-2"); + writer.println(); + writer.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + try { + sseServer.start(); + final CountDownLatch latch = new CountDownLatch(2); + final AtomicReference lastEvent = new AtomicReference<>(); + + HttpClient client = new NettyHttpClientBuilder().build(); + HttpRequest request = new HttpRequest().setMethod(HttpMethod.GET).setUri(URI.create(sseServer.getUri())); + request.setServerSentEventListener(event -> { + lastEvent.set(event); + latch.countDown(); + }); + + try (Response response = client.send(request)) { + assertEquals(200, response.getStatusCode()); + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertNotNull(lastEvent.get()); + assertEquals("2", lastEvent.get().getId()); + assertEquals("message", lastEvent.get().getEvent()); + } + } finally { + sseServer.stop(); + } + } + + @Test + public void sendWithServerSentEventsAndNoListenerThrows() { + LocalTestServer sseServer = new LocalTestServer(HttpProtocolVersion.HTTP_1_1, false, (req, res, body) -> { + res.setContentType("text/event-stream"); + res.setStatus(HttpServletResponse.SC_OK); + try { + res.getWriter().println("data: event-1\n\n"); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + try { + sseServer.start(); + HttpClient client = new NettyHttpClientBuilder().build(); + HttpRequest request = new HttpRequest().setMethod(HttpMethod.GET).setUri(URI.create(sseServer.getUri())); + + IllegalStateException ex = assertThrows(IllegalStateException.class, () -> client.send(request)); + assertTrue(ex.getMessage().contains("No ServerSentEventListener attached")); + } finally { + sseServer.stop(); + } + } + + @Test + public void nonPooledClientSendsRequestSuccessfully() { + HttpClient client = new NettyHttpClientBuilder().connectionPoolSize(0).build(); + + try (Response response + = client.send(new HttpRequest().setMethod(HttpMethod.GET).setUri(uri(SHORT_BODY_PATH)))) { + assertEquals(200, response.getStatusCode()); + assertArraysEqual(SHORT_BODY, response.getValue().toBytes()); + } + } + + @Test + public void nonPooledConnectionFails() { + HttpClient client = null; + try { + client = new NettyHttpClientBuilder().connectionPoolSize(0).build(); + HttpRequest request = new HttpRequest().setMethod(HttpMethod.GET).setUri("http://localhost:1"); + + HttpClient finalClient = client; + assertThrows(CoreException.class, () -> finalClient.send(request)); + } finally { + if (client != null) { + ((NettyHttpClient) client).close(); + } + } + } + + @Test + public void sslHandshakeFails() { + LocalTestServer server = new LocalTestServer(HttpProtocolVersion.HTTP_1_1, false, + (req, res, body) -> res.setStatus(HttpServletResponse.SC_OK)); + HttpClient client = null; + try { + server.start(); + client = new NettyHttpClientBuilder().connectionPoolSize(0).build(); + + URI httpsUri = URI.create("https://localhost:" + server.getPort()); + + HttpRequest request = new HttpRequest().setMethod(HttpMethod.GET).setUri(httpsUri); + + HttpClient finalClient = client; + CoreException exception = assertThrows(CoreException.class, () -> finalClient.send(request)); + assertInstanceOf(SSLException.class, exception.getCause()); + } finally { + if (client != null) { + ((NettyHttpClient) client).close(); + } + server.stop(); + } + } + + @Test + public void requestWriteFailsWhenServerClosesConnection() { + LocalTestServer server = new LocalTestServer(HttpProtocolVersion.HTTP_1_1, false, (req, res, body) -> { + try { + // Get the underlying java.nio.SocketChannel from the Jetty connection + EndPoint endPoint = HttpConnection.getCurrentConnection().getEndPoint(); + SocketChannel channel = (SocketChannel) endPoint.getTransport(); + + // Set SO_LINGER to true with a timeout of 0 seconds. + // This forces the OS to send a TCP RST packet on close() instead of the normal FIN sequence. + channel.socket().setSoLinger(true, 0); + } catch (SocketException e) { + throw new RuntimeException(e); + } + + // Now, close the connection. This will trigger the RST. + HttpConnection.getCurrentConnection().getEndPoint().close(); + }); + + HttpClient client = null; + try { + server.start(); + client = new NettyHttpClientBuilder().connectionPoolSize(0).build(); + HttpRequest request = new HttpRequest().setMethod(HttpMethod.POST) + .setUri(URI.create(server.getUri())) + .setBody(BinaryData.fromString("test data")); + + HttpClient finalClient = client; + CoreException exception = assertThrows(CoreException.class, () -> finalClient.send(request)); + assertInstanceOf(IOException.class, exception.getCause()); + } finally { + if (client != null) { + ((NettyHttpClient) client).close(); + } + server.stop(); + } + } + private static Stream requestHeaderSupplier() { return Stream.of(Arguments.of(null, NULL_REPLACEMENT), Arguments.of("", ""), Arguments.of("aValue", "aValue")); } diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledConnectionsTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledConnectionsTests.java new file mode 100644 index 000000000000..604e2ad47587 --- /dev/null +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledConnectionsTests.java @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4; + +import io.clientcore.core.http.client.HttpClient; +import io.clientcore.core.http.client.HttpProtocolVersion; +import io.clientcore.core.shared.HttpClientTests; +import io.clientcore.core.shared.HttpClientTestsServer; +import io.clientcore.core.shared.LocalTestServer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.TimeUnit; + +/** + * Tests for {@link NettyHttpClient} with a connection pool over plain HTTP. + */ +@Timeout(value = 3, unit = TimeUnit.MINUTES) +public class NettyHttpClientWithPooledConnectionsTests extends HttpClientTests { + private static LocalTestServer server; + private static HttpClient client; + + @BeforeAll + public static void startTestServer() { + server = HttpClientTestsServer.getHttpClientTestsServer(HttpProtocolVersion.HTTP_1_1, false); + server.start(); + + client = new NettyHttpClientBuilder().build(); + } + + @AfterAll + public static void stopTestServer() { + if (client instanceof NettyHttpClient) { + ((NettyHttpClient) client).close(); + } + if (server != null) { + server.stop(); + } + } + + @Override + protected int getPort() { + return server.getPort(); + } + + @Override + protected String getServerUri(boolean secure) { + return server.getUri(); + } + + @Override + protected HttpClient getHttpClient() { + return client; + } +} diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledHttp2ConnectionsTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledHttp2ConnectionsTests.java new file mode 100644 index 000000000000..8e2eefc312f0 --- /dev/null +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledHttp2ConnectionsTests.java @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4; + +import io.clientcore.core.http.client.HttpClient; +import io.clientcore.core.http.client.HttpProtocolVersion; +import io.clientcore.core.http.models.HttpMethod; +import io.clientcore.core.http.models.HttpRequest; +import io.clientcore.core.http.models.Response; +import io.clientcore.core.models.binarydata.BinaryData; +import io.clientcore.core.shared.HttpClientTests; +import io.clientcore.core.shared.HttpClientTestsServer; +import io.clientcore.core.shared.InsecureTrustManager; +import io.clientcore.core.shared.LocalTestServer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.security.SecureRandom; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +/** + * Tests for {@link NettyHttpClient} with a connection pool using HTTP/2. + */ +@Timeout(value = 3, unit = TimeUnit.MINUTES) +public class NettyHttpClientWithPooledHttp2ConnectionsTests extends HttpClientTests { + private static LocalTestServer server; + private static HttpClient client; + + @BeforeAll + public static void startTestServer() { + server = HttpClientTestsServer.getHttpClientTestsServer(HttpProtocolVersion.HTTP_2, true); + server.start(); + + client = new NettyHttpClientBuilder().maximumHttpVersion(HttpProtocolVersion.HTTP_2) + .sslContextModifier( + builder -> builder.trustManager(new InsecureTrustManager()).secureRandom(new SecureRandom())) + .build(); + } + + @AfterAll + public static void stopTestServer() { + if (client instanceof NettyHttpClient) { + ((NettyHttpClient) client).close(); + } + if (server != null) { + server.stop(); + } + } + + @Override + protected boolean isHttp2() { + return true; + } + + @Override + protected boolean isSecure() { + return true; + } + + @Override + protected int getPort() { + return server.getPort(); + } + + @Override + protected String getServerUri(boolean secure) { + return server.getHttpsUri(); + } + + @Override + protected HttpClient getHttpClient() { + return client; + } + + @Test + public void canSendBinaryDataDebug() { + byte[] expectedBytes = new byte[1024 * 1024]; + ThreadLocalRandom.current().nextBytes(expectedBytes); + HttpRequest request = new HttpRequest().setMethod(HttpMethod.PUT) + .setUri(getRequestUri("echo")) + .setBody(BinaryData.fromBytes(expectedBytes)); + + try (Response response = getHttpClient().send(request)) { + assertArrayEquals(expectedBytes, response.getValue().toBytes()); + } + } +} diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledHttpsConnectionsTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledHttpsConnectionsTests.java new file mode 100644 index 000000000000..6bf1571b228f --- /dev/null +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/NettyHttpClientWithPooledHttpsConnectionsTests.java @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4; + +import io.clientcore.core.http.client.HttpClient; +import io.clientcore.core.http.client.HttpProtocolVersion; +import io.clientcore.core.shared.HttpClientTests; +import io.clientcore.core.shared.HttpClientTestsServer; +import io.clientcore.core.shared.InsecureTrustManager; +import io.clientcore.core.shared.LocalTestServer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Timeout; + +import java.security.SecureRandom; +import java.util.concurrent.TimeUnit; + +/** + * Tests for {@link NettyHttpClient} with a connection pool over HTTPS. + */ +@Timeout(value = 3, unit = TimeUnit.MINUTES) +public class NettyHttpClientWithPooledHttpsConnectionsTests extends HttpClientTests { + private static LocalTestServer server; + private static HttpClient client; + + @BeforeAll + public static void startTestServer() { + server = HttpClientTestsServer.getHttpClientTestsServer(HttpProtocolVersion.HTTP_1_1, true); + server.start(); + + client = new NettyHttpClientBuilder() + .sslContextModifier(ssl -> ssl.trustManager(new InsecureTrustManager()).secureRandom(new SecureRandom())) + .build(); + } + + @AfterAll + public static void stopTestServer() { + if (client instanceof NettyHttpClient) { + ((NettyHttpClient) client).close(); + } + if (server != null) { + server.stop(); + } + } + + @Override + protected int getPort() { + return server.getPort(); + } + + @Override + protected String getServerUri(boolean secure) { + return secure ? server.getHttpsUri() : server.getUri(); + } + + @Override + protected boolean isSecure() { + return true; + } + + @Override + protected HttpClient getHttpClient() { + return client; + } +} diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/TestUtils.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/TestUtils.java index 29b4f8ab0178..e807f15c6367 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/TestUtils.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/TestUtils.java @@ -2,6 +2,8 @@ // Licensed under the MIT License. package io.clientcore.http.netty4; +import io.clientcore.core.http.client.HttpProtocolVersion; +import io.clientcore.http.netty4.implementation.Netty4AlpnHandler; import io.clientcore.http.netty4.mocking.MockChannel; import io.netty.channel.Channel; import io.netty.channel.DefaultEventLoop; @@ -39,6 +41,19 @@ public static void assertArraysEqual(byte[] expected, byte[] actual) { * @return A {@link Channel}. */ public static Channel createChannelWithReadHandling(BiConsumer readHandler) { + return createChannelWithReadHandling(readHandler, null); + } + + /** + * Creates a {@link Channel} that is able to mock {@link Channel#read()} operations. + * + * @param readHandler A {@link BiConsumer} that takes the current read count and the channel and mocks reading + * operations. + * @param protocolVersion The HTTP protocol version to set on the channel's attributes. Can be null. + * @return A {@link Channel}. + */ + public static Channel createChannelWithReadHandling(BiConsumer readHandler, + HttpProtocolVersion protocolVersion) { EventLoop eventLoop = new DefaultEventLoop() { @Override public boolean inEventLoop(Thread thread) { @@ -61,6 +76,10 @@ public boolean isActive() { } }; + if (protocolVersion != null) { + channel.attr(Netty4AlpnHandler.HTTP_PROTOCOL_VERSION_KEY).set(protocolVersion); + } + try { eventLoop.register(channel).sync(); } catch (InterruptedException e) { diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/HttpResponseDrainsBufferTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/HttpResponseDrainsBufferTests.java index 5f38223d84d6..10a9e3769338 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/HttpResponseDrainsBufferTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/HttpResponseDrainsBufferTests.java @@ -4,10 +4,13 @@ package io.clientcore.http.netty4.implementation; import io.clientcore.core.http.client.HttpClient; +import io.clientcore.core.http.client.HttpProtocolVersion; import io.clientcore.core.http.models.HttpMethod; import io.clientcore.core.http.models.HttpRequest; import io.clientcore.core.http.models.Response; +import io.clientcore.core.models.CoreException; import io.clientcore.core.models.binarydata.BinaryData; +import io.clientcore.core.shared.LocalTestServer; import io.clientcore.core.utils.IOExceptionCheckedConsumer; import io.clientcore.core.utils.SharedExecutorService; import io.clientcore.http.netty4.NettyHttpClientProvider; @@ -24,6 +27,7 @@ import org.junit.jupiter.api.parallel.ExecutionMode; import org.junit.jupiter.api.parallel.Isolated; +import javax.servlet.ServletException; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; @@ -44,6 +48,8 @@ import static io.clientcore.http.netty4.implementation.NettyHttpClientLocalTestServer.LONG_BODY; import static io.clientcore.http.netty4.implementation.NettyHttpClientLocalTestServer.LONG_BODY_PATH; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; /** * Tests that closing the {@link Response} drains the network buffers. @@ -56,7 +62,8 @@ @Execution(ExecutionMode.SAME_THREAD) public class HttpResponseDrainsBufferTests { private static ResourceLeakDetector.Level originalLevel; - private static final String URL = NettyHttpClientLocalTestServer.getServer().getUri() + LONG_BODY_PATH; + private static String url; + private static LocalTestServer server; private ResourceLeakDetectorFactory originalLeakDetectorFactory; private final TestResourceLeakDetectorFactory testResourceLeakDetectorFactory @@ -64,6 +71,25 @@ public class HttpResponseDrainsBufferTests { @BeforeAll public static void startTestServer() { + server = new LocalTestServer(HttpProtocolVersion.HTTP_1_1, false, (req, resp, requestBody) -> { + if ("GET".equalsIgnoreCase(req.getMethod()) && LONG_BODY_PATH.equals(req.getServletPath())) { + resp.setStatus(200); + resp.setContentType("application/octet-stream"); + resp.setContentLength(LONG_BODY.length); + try { + resp.getOutputStream().write(LONG_BODY); + resp.flushBuffer(); + } catch (IOException e) { + throw new ServletException(e); + } + } else { + resp.sendError(404, "Endpoint not found."); + } + }); + + server.start(); + url = server.getUri() + LONG_BODY_PATH; + originalLevel = ResourceLeakDetector.getLevel(); ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID); } @@ -81,7 +107,13 @@ public void resetLeakDetectorFactory() { @AfterAll public static void stopTestServer() { - ResourceLeakDetector.setLevel(originalLevel); + if (server != null) { + server.stop(); + } + + if (originalLevel != null) { + ResourceLeakDetector.setLevel(originalLevel); + } } @Test @@ -99,7 +131,11 @@ public void closeHttpResponseWithConsumingPartialBody() { @Test public void closeHttpResponseWithConsumingPartialWrite() { - runScenario(response -> response.getValue().writeTo(new ThrowingWritableByteChannel())); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> runScenario(response -> response.getValue().writeTo(new ThrowingWritableByteChannel()))); + assertInstanceOf(ExecutionException.class, ex.getCause()); + assertInstanceOf(CoreException.class, ex.getCause().getCause()); + assertEquals(0, testResourceLeakDetectorFactory.getTotalReportedLeakCount()); } private static final class ThrowingWritableByteChannel implements WritableByteChannel { @@ -163,7 +199,7 @@ private void runScenario(IOExceptionCheckedConsumer> respon try { limiter.acquire(); responseConsumer - .accept(httpClient.send(new HttpRequest().setMethod(HttpMethod.GET).setUri(URL))); + .accept(httpClient.send(new HttpRequest().setMethod(HttpMethod.GET).setUri(url))); } finally { limiter.release(); } @@ -184,15 +220,13 @@ private void runScenario(IOExceptionCheckedConsumer> respon } catch (InterruptedException | ExecutionException ex) { throw new RuntimeException(ex); } - - assertEquals(0, testResourceLeakDetectorFactory.getTotalReportedLeakCount()); } @Test public void closingHttpResponseIsIdempotent() throws InterruptedException { HttpClient httpClient = new NettyHttpClientProvider().getSharedInstance(); - Response response = httpClient.send(new HttpRequest().setMethod(HttpMethod.GET).setUri(URL)); + Response response = httpClient.send(new HttpRequest().setMethod(HttpMethod.GET).setUri(url)); response.close(); Thread.sleep(1_000); response.close(); diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPoolTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPoolTests.java new file mode 100644 index 000000000000..ba84eb3e9035 --- /dev/null +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4ConnectionPoolTests.java @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4.implementation; + +import io.clientcore.core.http.client.HttpProtocolVersion; +import io.clientcore.core.models.CoreException; +import io.clientcore.core.shared.LocalTestServer; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.concurrent.Future; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * Tests for {@link Netty4ConnectionPool}. + */ +@Timeout(value = 1, unit = TimeUnit.MINUTES) +public class Netty4ConnectionPoolTests { + + private static LocalTestServer server; + private static EventLoopGroup eventLoopGroup; + private static Bootstrap bootstrap; + private static Netty4ConnectionPoolKey connectionPoolKey; + + @BeforeAll + public static void startTestServerAndEventLoopGroup() { + server = NettyHttpClientLocalTestServer.getServer(); + server.start(); + eventLoopGroup = new NioEventLoopGroup(2); + bootstrap = new Bootstrap().group(eventLoopGroup).channel(NioSocketChannel.class); + bootstrap.option(ChannelOption.AUTO_READ, false); + SocketAddress socketAddress = new InetSocketAddress("localhost", server.getPort()); + connectionPoolKey = new Netty4ConnectionPoolKey(socketAddress, socketAddress); + } + + @AfterAll + public static void stopTestServerAndEventLoopGroup() { + if (server != null) { + server.stop(); + } + if (eventLoopGroup != null && !eventLoopGroup.isShuttingDown()) { + eventLoopGroup.shutdownGracefully().awaitUninterruptibly(); + } + } + + private Netty4ConnectionPool createPool(int maxConnections, Duration idleTimeout, Duration maxLifetime, + Duration pendingAcquireTimeout, int maxPendingAcquires) { + return new Netty4ConnectionPool(bootstrap, new ChannelInitializationProxyHandler(null), null, // No SSL context modifier needed + maxConnections, idleTimeout, maxLifetime, pendingAcquireTimeout, maxPendingAcquires, + HttpProtocolVersion.HTTP_1_1); + } + + @Test + public void releaseNullChannelDoesNotThrow() throws IOException { + try (Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1)) { + assertDoesNotThrow(() -> pool.release(null)); + } + } + + @Test + public void releaseUnknownChannelClosesChannel() throws IOException { + Channel unknownChannel = new NioSocketChannel(); + eventLoopGroup.register(unknownChannel); + + try (Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1)) { + pool.release(unknownChannel); + // The pool doesn't know this channel, so it should close it. + unknownChannel.closeFuture().awaitUninterruptibly(); + assertFalse(unknownChannel.isOpen()); + } + } + + @Test + public void releaseToClosedPoolClosesChannel() throws IOException { + Bootstrap realBootstrap = bootstrap.clone().remoteAddress(connectionPoolKey.getConnectionTarget()); + Netty4ConnectionPool pool = new Netty4ConnectionPool(realBootstrap, new ChannelInitializationProxyHandler(null), + null, 1, null, null, null, 1, null); + + Channel channel = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + assertTrue(channel.isActive()); + + pool.close(); + pool.release(channel); + + channel.closeFuture().awaitUninterruptibly(); + assertFalse(channel.isOpen()); + } + + @Test + public void testAcquireAndRelease() throws IOException { + try (Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1)) { + Future future = pool.acquire(connectionPoolKey, false); + Channel channel = future.awaitUninterruptibly().getNow(); + assertNotNull(channel); + assertTrue(channel.isActive()); + pool.release(channel); + } + } + + @Test + public void closeIsIdempotent() throws IOException { + Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1); + pool.close(); + assertDoesNotThrow(pool::close); + } + + @Test + public void poolWithNoIdleTimeoutHasNoCleanupTask() + throws IOException, NoSuchFieldException, IllegalAccessException { + try (Netty4ConnectionPool pool = createPool(1, null, null, Duration.ofSeconds(10), 1)) { + Field cleanupTaskField = Netty4ConnectionPool.class.getDeclaredField("cleanupTask"); + cleanupTaskField.setAccessible(true); + assertNull(cleanupTaskField.get(pool)); + } + } + + @Test + public void pendingAcquireQueueIsFull() throws IOException { + try (Netty4ConnectionPool pool = createPool(1, null, null, Duration.ofSeconds(10), 1)) { + Channel channel = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + assertNotNull(channel); + + Future pendingFuture = pool.acquire(connectionPoolKey, false); + assertFalse(pendingFuture.isDone()); + + Future failedFuture = pool.acquire(connectionPoolKey, false); + assertTrue(failedFuture.isDone()); + assertFalse(failedFuture.isSuccess()); + assertInstanceOf(CoreException.class, failedFuture.cause()); + + pool.release(channel); + pendingFuture.awaitUninterruptibly(); + pool.release(pendingFuture.getNow()); + } + } + + @Test + public void cancelledPendingAcquireIsRemovedFromQueue() throws IOException, InterruptedException { + try (Netty4ConnectionPool pool = createPool(1, null, Duration.ofSeconds(5), Duration.ofSeconds(10), 1)) { + Channel channel1 = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + + Future pendingFuture = pool.acquire(connectionPoolKey, false); + pendingFuture.cancel(true); + + Thread.sleep(100); + + pool.release(channel1); + + Channel channel2 = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + assertSame(channel1, channel2); + + pool.release(channel2); + } + } + + @Test + public void testConnectionIsReusedForSameRemoteAddress() throws IOException { + try (Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1)) { + Future future1 = pool.acquire(connectionPoolKey, false); + Channel channel1 = future1.awaitUninterruptibly().getNow(); + pool.release(channel1); + + Future future2 = pool.acquire(connectionPoolKey, false); + Channel channel2 = future2.awaitUninterruptibly().getNow(); + assertSame(channel1, channel2); + pool.release(channel2); + } + } + + @Test + public void testConnectionPoolSizeEnforced() throws IOException, InterruptedException { + final int maxConnections = 5; + try (Netty4ConnectionPool pool + = createPool(maxConnections, Duration.ofSeconds(10), null, Duration.ofSeconds(10), maxConnections)) { + List channels = new ArrayList<>(); + for (int i = 0; i < maxConnections; i++) { + channels.add(pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow()); + } + assertEquals(maxConnections, channels.size()); + + Future pendingFuture = pool.acquire(connectionPoolKey, false); + Thread.sleep(100); + assertFalse(pendingFuture.isDone()); + + pool.release(channels.get(0)); + Channel pendingChannel = pendingFuture.awaitUninterruptibly().getNow(); + assertSame(channels.get(0), pendingChannel); + + for (int i = 1; i < channels.size(); i++) { + pool.release(channels.get(i)); + } + } + } + + @Test + public void testPendingAcquireTimeout() throws IOException, InterruptedException { + try (Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofMillis(100), 1)) { + Channel channel = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + + Future timeoutFuture = pool.acquire(connectionPoolKey, false); + + assertTrue(timeoutFuture.await(500, TimeUnit.MILLISECONDS)); + + assertFalse(timeoutFuture.isSuccess()); + assertInstanceOf(CoreException.class, timeoutFuture.cause()); + assertTrue(timeoutFuture.cause().getMessage().contains("Connection acquisition timed out")); + + pool.release(channel); + } + } + + @Test + public void testIdleConnectionIsCleanedUp() throws IOException, InterruptedException { + try (Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1)) { + Channel channel1 = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + pool.release(channel1); + Thread.sleep(31000); // Wait for cleanup task to run (interval is 30s) + assertFalse(channel1.isActive()); + + Channel channel2 = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + assertNotSame(channel1, channel2); + pool.release(channel2); + } + } + + @Test + public void testMaxConnectionLifetimeEnforced() throws IOException, InterruptedException { + try (Netty4ConnectionPool pool + = createPool(1, Duration.ofSeconds(10), Duration.ofMillis(500), Duration.ofSeconds(10), 1)) { + Channel channel1 = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + Thread.sleep(600); + pool.release(channel1); + Thread.sleep(100); // Give a moment for close to propagate + assertFalse(channel1.isActive()); + + Channel channel2 = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + assertNotSame(channel1, channel2); + pool.release(channel2); + } + } + + @Test + public void testUnhealthyConnectionIsDiscarded() throws IOException { + try (Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1)) { + Channel channel1 = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + pool.release(channel1); + channel1.close().awaitUninterruptibly(); + + Channel channel2 = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + assertNotNull(channel2); + assertTrue(channel2.isActive()); + assertNotSame(channel1, channel2); + pool.release(channel2); + } + } + + @Test + public void testAcquireOnClosedPoolFails() throws IOException, InterruptedException { + Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1); + pool.close(); + Future future = pool.acquire(connectionPoolKey, false); + future.await(); + + assertFalse(future.isSuccess()); + assertInstanceOf(IllegalStateException.class, future.cause()); + } + + @Test + public void testSeparatePoolsForSeparateRemoteAddresses() throws IOException { + LocalTestServer route1Server = new LocalTestServer(HttpProtocolVersion.HTTP_1_1, false, null); + LocalTestServer route2Server = new LocalTestServer(HttpProtocolVersion.HTTP_1_1, false, null); + + try { + route1Server.start(); + route2Server.start(); + + SocketAddress address1 = new InetSocketAddress("localhost", route1Server.getPort()); + SocketAddress address2 = new InetSocketAddress("localhost", route2Server.getPort()); + Netty4ConnectionPoolKey key1 = new Netty4ConnectionPoolKey(address1, address1); + Netty4ConnectionPoolKey key2 = new Netty4ConnectionPoolKey(address2, address2); + + try (Netty4ConnectionPool pool = createPool(1, Duration.ofSeconds(10), null, Duration.ofSeconds(10), 1)) { + Channel channel1 = pool.acquire(key1, false).awaitUninterruptibly().getNow(); + assertNotNull(channel1); + + Channel channel2 = pool.acquire(key2, false).awaitUninterruptibly().getNow(); + assertNotNull(channel2); + + assertNotSame(channel1, channel2); + + pool.release(channel1); + pool.release(channel2); + } + } finally { + route1Server.stop(); + route2Server.stop(); + } + } + + @Test + public void poolDoesNotDeadlockAndRecoversCleanlyUnderSaturation() throws InterruptedException, IOException { + final int poolSize = 10; + final int numThreads = 20; + final int numTasks = 100; + + final CountDownLatch latch = new CountDownLatch(numTasks); + final AtomicInteger successCounter = new AtomicInteger(0); + final Queue exceptions = new ConcurrentLinkedQueue<>(); + final ExecutorService executor = Executors.newFixedThreadPool(numThreads); + + try (Netty4ConnectionPool pool + = createPool(poolSize, Duration.ofSeconds(10), null, Duration.ofSeconds(2), numTasks)) { + for (int i = 0; i < numTasks; i++) { + executor.submit(() -> { + try { + Channel channel = pool.acquire(connectionPoolKey, false).awaitUninterruptibly().getNow(); + + // Hold the connection for a short, random time to simulate work. + Thread.sleep(ThreadLocalRandom.current().nextInt(10, 50)); + + pool.release(channel); + successCounter.incrementAndGet(); + } catch (Throwable t) { + exceptions.add(t); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS), "Test deadlocked, not all tasks completed."); + + executor.shutdown(); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS)); + + if (!exceptions.isEmpty()) { + fail("Test tasks threw exceptions: " + + exceptions.stream().map(Throwable::getMessage).collect(Collectors.joining(", "))); + } + assertEquals(numTasks, successCounter.get(), "Mismatch in the number of successful tasks."); + + // Use reflection to check the final state of the pool's queues. + assertDoesNotThrow(() -> { + Field poolField = Netty4ConnectionPool.class.getDeclaredField("pool"); + poolField.setAccessible(true); + @SuppressWarnings("unchecked") + ConcurrentMap routePools + = (ConcurrentMap) poolField.get(pool); + Netty4ConnectionPool.PerRoutePool perRoutePool = routePools.get(connectionPoolKey); + + Field idleField = Netty4ConnectionPool.PerRoutePool.class.getDeclaredField("idleConnections"); + idleField.setAccessible(true); + Deque idleConnections = (Deque) idleField.get(perRoutePool); + + Field pendingField = Netty4ConnectionPool.PerRoutePool.class.getDeclaredField("pendingAcquirers"); + pendingField.setAccessible(true); + Deque pendingAcquirers = (Deque) pendingField.get(perRoutePool); + + assertEquals(poolSize, idleConnections.size(), "Pool should be full of idle connections."); + assertTrue(pendingAcquirers.isEmpty(), "Pending acquirers queue should be empty."); + }); + } + } +} diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4EagerConsumeChannelHandlerTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4EagerConsumeChannelHandlerTests.java new file mode 100644 index 000000000000..6eeb35a06ca0 --- /dev/null +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4EagerConsumeChannelHandlerTests.java @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4.implementation; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.codec.http2.DefaultHttp2DataFrame; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for {@link Netty4EagerConsumeChannelHandler}. + */ +public class Netty4EagerConsumeChannelHandlerTests { + private static final byte[] HELLO_BYTES = "Hello".getBytes(StandardCharsets.UTF_8); + + @Test + public void syncDrainConsumesHttp1Content() throws InterruptedException { + ByteArrayOutputStream receivedBytes = new ByteArrayOutputStream(); + CountDownLatch latch = new CountDownLatch(1); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, + buf -> buf.readBytes(receivedBytes, buf.readableBytes()), false); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.writeInbound(new DefaultHttpContent(Unpooled.wrappedBuffer(HELLO_BYTES))); + assertFalse(latch.await(50, TimeUnit.MILLISECONDS), "Latch should not count down on partial content."); + + channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT); + assertTrue(latch.await(1, TimeUnit.SECONDS), "Latch should count down on last content."); + + assertArrayEquals(HELLO_BYTES, receivedBytes.toByteArray()); + assertNull(channel.pipeline().get(Netty4EagerConsumeChannelHandler.class)); + } + + @Test + public void syncDrainConsumesHttp2Content() throws InterruptedException { + ByteArrayOutputStream receivedBytes = new ByteArrayOutputStream(); + CountDownLatch latch = new CountDownLatch(1); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, + buf -> buf.readBytes(receivedBytes, buf.readableBytes()), true); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.writeInbound(new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(HELLO_BYTES), false)); + assertFalse(latch.await(50, TimeUnit.MILLISECONDS), "Latch should not count down on partial content."); + + channel.writeInbound(new DefaultHttp2DataFrame(true)); + assertTrue(latch.await(1, TimeUnit.SECONDS), "Latch should count down on last content."); + + assertArrayEquals(HELLO_BYTES, receivedBytes.toByteArray()); + assertNull(channel.pipeline().get(Netty4EagerConsumeChannelHandler.class)); + } + + @Test + public void asyncDrainCallsOnComplete() { + AtomicBoolean onCompleteCalled = new AtomicBoolean(false); + Runnable onComplete = () -> onCompleteCalled.set(true); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(onComplete, false); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT); + + channel.runPendingTasks(); + + assertTrue(onCompleteCalled.get(), "onComplete should have been called."); + assertNull(channel.pipeline().get(Netty4EagerConsumeChannelHandler.class)); + } + + @Test + public void consumerExceptionIsCapturedByHandler() { + IOException testException = new IOException("test"); + CountDownLatch latch = new CountDownLatch(1); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, buf -> { + throw testException; + }, false); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + ByteBuf content = Unpooled.wrappedBuffer(HELLO_BYTES); + + channel.writeInbound(content); + + Throwable capturedException = handler.channelException(); + + assertNotNull(capturedException); + assertEquals(testException, capturedException); + + assertNull(channel.pipeline().get(Netty4EagerConsumeChannelHandler.class)); + } + + @Test + public void exceptionCaughtSignalsCompletion() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, buf -> { + }, false); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.pipeline().fireExceptionCaught(new RuntimeException("test")); + + assertTrue(latch.await(1, TimeUnit.SECONDS), "Latch should count down on exceptionCaught."); + assertNotNull(handler.channelException()); + } + + @Test + public void channelInactiveSignalsCompletion() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, buf -> { + }, false); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + assertTrue(channel.isActive()); + + channel.close().awaitUninterruptibly(); + + assertTrue(latch.await(1, TimeUnit.SECONDS), "Latch should count down on channelInactive."); + } + + @Test + public void addingToInactiveChannelFiresException() { + EmbeddedChannel channel = new EmbeddedChannel(); + channel.close().awaitUninterruptibly(); + + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(() -> { + }, false); + + channel.pipeline().addLast(handler); + + assertThrows(ClosedChannelException.class, channel::checkException); + } + + @Test + public void handlesByteBufMessage() { + ByteArrayOutputStream receivedBytes = new ByteArrayOutputStream(); + CountDownLatch latch = new CountDownLatch(1); + Netty4EagerConsumeChannelHandler handler = new Netty4EagerConsumeChannelHandler(latch, + buf -> buf.readBytes(receivedBytes, buf.readableBytes()), false); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.writeInbound(Unpooled.wrappedBuffer(HELLO_BYTES)); + channel.finishAndReleaseAll(); + + assertEquals(0, latch.getCount()); + assertArrayEquals(HELLO_BYTES, receivedBytes.toByteArray()); + } +} diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4Http11ChannelBinaryDataTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4Http11ChannelBinaryDataTests.java index a51af8e6be19..02ccf7538c9b 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4Http11ChannelBinaryDataTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4Http11ChannelBinaryDataTests.java @@ -2,35 +2,68 @@ // Licensed under the MIT License. package io.clientcore.http.netty4.implementation; +import io.clientcore.core.models.CoreException; import io.clientcore.core.models.binarydata.BinaryData; import io.clientcore.core.models.binarydata.ByteArrayBinaryData; +import io.clientcore.core.serialization.ObjectSerializer; import io.clientcore.http.netty4.mocking.MockChannelHandlerContext; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.AbstractChannel; import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.EventLoop; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultHttpContent; import io.netty.handler.codec.http.LastHttpContent; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.mockito.Mockito; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.Type; +import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import static io.clientcore.http.netty4.TestUtils.assertArraysEqual; import static io.clientcore.http.netty4.TestUtils.createChannelWithReadHandling; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * Tests {@link Netty4ChannelBinaryData}. */ @Timeout(value = 3, unit = TimeUnit.MINUTES) public class Netty4Http11ChannelBinaryDataTests { + private static final byte[] HELLO_BYTES = "Hello".getBytes(StandardCharsets.UTF_8); + private static final byte[] WORLD_BYTES = " World!".getBytes(StandardCharsets.UTF_8); + private static final byte[] HELLO_WORLD_BYTES = "Hello World!".getBytes(StandardCharsets.UTF_8); + @Test public void toBytesWillThrowIsLengthIsTooLarge() { assertThrows(IllegalStateException.class, @@ -139,7 +172,7 @@ public void channelBinaryDataIsNeverReplayable() { @Test public void channelBinaryDataToReplayableReturnsAByteArrayBinaryData() throws IOException { - byte[] expected = "Hello world!".getBytes(StandardCharsets.UTF_8); + byte[] expected = HELLO_WORLD_BYTES; ByteArrayOutputStream eagerContent = new ByteArrayOutputStream(); eagerContent.write(expected); @@ -154,6 +187,360 @@ public void channelBinaryDataToReplayableReturnsAByteArrayBinaryData() throws IO assertArraysEqual(expected, replayable.toBytes()); } + @Test + public void toStreamReturnsNettyStreamWhenNotDrained() throws IOException { + ByteArrayOutputStream eagerContent = new ByteArrayOutputStream(); + eagerContent.write(HELLO_BYTES); + Channel channel = createChannelWithReadHandling((ignored, ch) -> { + ByteBuf content = Unpooled.wrappedBuffer(WORLD_BYTES); + ch.pipeline().fireChannelRead(content); + ch.pipeline().fireChannelRead(LastHttpContent.EMPTY_LAST_CONTENT); + ch.pipeline().fireChannelReadComplete(); + }); + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(eagerContent, channel, (long) HELLO_WORLD_BYTES.length, false); + + InputStream stream = binaryData.toStream(); + + assertInstanceOf(Netty4ChannelInputStream.class, stream); + + ByteArrayOutputStream result = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int length; + while ((length = stream.read(buffer)) != -1) { + result.write(buffer, 0, length); + } + assertArraysEqual(HELLO_WORLD_BYTES, result.toByteArray()); + } + + @Test + public void toBytesDrainsFromLiveChannel() throws IOException { + ByteArrayOutputStream eagerContent = new ByteArrayOutputStream(); + eagerContent.write(HELLO_BYTES); + + EmbeddedChannel channel = new EmbeddedChannel(); + + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(eagerContent, channel, (long) HELLO_WORLD_BYTES.length, false, null); + + Thread serverThread = new Thread(() -> { + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + channel.writeInbound(new DefaultHttpContent(Unpooled.wrappedBuffer(WORLD_BYTES))); + channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT); + }); + + serverThread.start(); + + byte[] result = binaryData.toBytes(); + + assertArrayEquals(HELLO_WORLD_BYTES, result); + assertTrue(channel.config().isAutoRead()); + } + + @Test + public void toBytesThrowsIfChannelErrors() { + IOException testException = new IOException("test error"); + Channel channel = createChannelWithReadHandling((ignored, ch) -> { + ch.pipeline().addLast(new ExceptionSuppressingHandler()); + ch.pipeline().fireExceptionCaught(testException); + }); + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), channel, 10L, false); + + CoreException exception = assertThrows(CoreException.class, binaryData::toBytes); + assertEquals(testException, exception.getCause()); + } + + @Test + public void closeAfterDrainingDisconnectsChannel() { + TestMockChannel realChannel = new TestMockChannel(); + new DefaultEventLoop().register(realChannel); + Channel spiedChannel = spy(realChannel); + Runnable cleanupTask = () -> { + spiedChannel.disconnect(); + spiedChannel.close(); + }; + + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), spiedChannel, 0L, false, cleanupTask); + + binaryData.toBytes(); + binaryData.close(); + + verify(spiedChannel).disconnect(); + verify(spiedChannel).close(); + } + + @Test + public void toObjectThrowsCoreExceptionOnSerializationError() throws IOException { + ObjectSerializer mockSerializer = Mockito.mock(ObjectSerializer.class); + when(mockSerializer.deserializeFromBytes(any(), any(Type.class))) + .thenThrow(new IOException("deserialization failed")); + + TestMockChannel channel = new TestMockChannel(); + new DefaultEventLoop().register(channel); + + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), channel, 0L, false); + + CoreException ex = assertThrows(CoreException.class, () -> binaryData.toObject(String.class, mockSerializer)); + assertInstanceOf(IOException.class, ex.getCause()); + } + + @Test + public void cleanupDoesNothingIfHandlerIsMissing() { + TestMockChannel realChannel = new TestMockChannel(); + new DefaultEventLoop().register(realChannel); + Channel spiedChannel = spy(realChannel); + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), spiedChannel, 0L, false); + + binaryData.toBytes(); + + verify(spiedChannel, never()).close(); + } + + @Test + public void toBytesOnInactiveChannelReturnsEagerContent() throws IOException { + byte[] eagerBytes = "eager".getBytes(StandardCharsets.UTF_8); + ByteArrayOutputStream eagerContent = new ByteArrayOutputStream(); + eagerContent.write(eagerBytes); + + TestMockChannel channel = new TestMockChannel(); + new DefaultEventLoop().register(channel); + + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(eagerContent, channel, (long) eagerBytes.length, false); + + channel.close().awaitUninterruptibly(); + byte[] result = binaryData.toBytes(); + + assertArraysEqual(eagerBytes, result); + } + + @Test + public void toBytesUsesEagerContentWhenSufficient() throws IOException { + byte[] fullBody = "Full body".getBytes(StandardCharsets.UTF_8); + ByteArrayOutputStream eagerContent = new ByteArrayOutputStream(); + eagerContent.write(fullBody); + + TestMockChannel realChannel = new TestMockChannel(); + new DefaultEventLoop().register(realChannel); + Channel spiedChannel = spy(realChannel); + + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(eagerContent, spiedChannel, (long) fullBody.length, false); + + byte[] result = binaryData.toBytes(); + + assertArraysEqual(fullBody, result); + verify(spiedChannel, never()).read(); + verify(spiedChannel, never()).config(); + } + + @Test + public void closeBeforeDrainingEventuallyCleansUp() throws InterruptedException { + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(channel.isActive()); + + CountDownLatch cleanupLatch = new CountDownLatch(1); + Runnable cleanupTask = cleanupLatch::countDown; + + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), channel, 1L, false, cleanupTask); + + binaryData.close(); + + channel.close().awaitUninterruptibly(); + + assertTrue(cleanupLatch.await(10, TimeUnit.SECONDS), + "Cleanup task was not called after the channel became inactive."); + } + + @Test + public void testBinaryDataWithoutOnClose() { + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), new EmbeddedChannel(), 10L, false); + assertEquals(10L, binaryData.getLength()); + } + + @Test + public void writeToAlreadyDrainedStreamThrowsException() { + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), channelWithNoData(), 0L, false); + + binaryData.writeTo(new ByteArrayOutputStream()); + + assertThrows(IllegalStateException.class, () -> binaryData.writeTo(new ByteArrayOutputStream())); + } + + @Test + public void writeToThrowsWhenChannelErrors() { + IOException testException = new IOException("test writeTo error"); + Channel channel = createChannelWithReadHandling((ignored, ch) -> { + ch.pipeline().addLast(new ExceptionSuppressingHandler()); + ch.pipeline().fireExceptionCaught(testException); + }); + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), channel, 10L, false); + + CoreException exception + = assertThrows(CoreException.class, () -> binaryData.writeTo(new ByteArrayOutputStream())); + assertEquals(testException, exception.getCause()); + } + + @Test + public void writeToThrowsWhenChannelThrowsError() { + // This test covers the 'instanceof Error' branch in writeTo(OutputStream). + AssertionError testError = new AssertionError("test writeTo error"); + Channel channel = createChannelWithReadHandling((ignored, ch) -> { + ch.pipeline().addLast(new ExceptionSuppressingHandler()); + ch.pipeline().fireExceptionCaught(testError); + }); + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), channel, 10L, false); + + AssertionError error + = assertThrows(AssertionError.class, () -> binaryData.writeTo(new ByteArrayOutputStream())); + assertEquals(testError, error); + } + + @Test + public void closeIsIdempotent() { + AtomicInteger closed = new AtomicInteger(0); + Runnable cleanupTask = closed::getAndIncrement; + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), channelWithNoData(), 0L, false, cleanupTask); + + binaryData.toBytes(); + + binaryData.close(); + binaryData.close(); + + assertEquals(1, closed.get(), "Close should have been called only once"); + } + + @Test + public void toBytesThrowsOnInactiveChannelWithIncompleteBody() throws IOException { + // This test covers the case where the channel is closed but the eager content is insufficient. + byte[] eagerBytes = "eager".getBytes(StandardCharsets.UTF_8); + ByteArrayOutputStream eagerContent = new ByteArrayOutputStream(); + eagerContent.write(eagerBytes); + + TestMockChannel channel = new TestMockChannel(); + new DefaultEventLoop().register(channel); + + // The Expected length is 10, but we only have 5 bytes. + Netty4ChannelBinaryData binaryData = new Netty4ChannelBinaryData(eagerContent, channel, 10L, false); + + channel.close().awaitUninterruptibly(); + + CoreException exception = assertThrows(CoreException.class, binaryData::toBytes); + + assertInstanceOf(IOException.class, exception.getCause()); + } + + @Test + public void toBytesThrowsWhenChannelThrowsError() { + // This test covers the 'instanceof Error' branch in drainStream() used by toBytes(). + AssertionError testError = new AssertionError("test toBytes error"); + Channel channel = createChannelWithReadHandling((ignored, ch) -> { + ch.pipeline().addLast(new ExceptionSuppressingHandler()); + ch.pipeline().fireExceptionCaught(testError); + }); + Netty4ChannelBinaryData binaryData + = new Netty4ChannelBinaryData(new ByteArrayOutputStream(), channel, 10L, false); + + AssertionError error = assertThrows(AssertionError.class, binaryData::toBytes); + assertEquals(testError, error); + } + + private static class TestMockChannel extends AbstractChannel { + private final AtomicBoolean disconnectCalled = new AtomicBoolean(false); + private final AtomicBoolean closeCalled = new AtomicBoolean(false); + + protected TestMockChannel() { + super(null); + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new AbstractUnsafe() { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + promise.setSuccess(); + } + }; + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return true; + } + + @Override + protected SocketAddress localAddress0() { + return null; + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doBind(SocketAddress localAddress) { + } + + @Override + protected void doDisconnect() { + disconnectCalled.set(true); + } + + @Override + protected void doClose() { + closeCalled.set(true); + } + + @Override + protected void doBeginRead() { + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) { + } + + @Override + public ChannelConfig config() { + return new DefaultChannelConfig(this); + } + + @Override + public boolean isOpen() { + return !closeCalled.get(); + } + + @Override + public boolean isActive() { + return !closeCalled.get(); + } + + @Override + public ChannelMetadata metadata() { + return new ChannelMetadata(false); + } + } + + private static final class ExceptionSuppressingHandler extends ChannelInboundHandlerAdapter { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + } + } + private static Channel channelWithNoData() { return createChannelWithReadHandling((ignored, channel) -> { Netty4EagerConsumeChannelHandler handler = channel.pipeline().get(Netty4EagerConsumeChannelHandler.class); diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4Http11ChannelInputStreamTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4Http11ChannelInputStreamTests.java index 67df3b7ac43e..3563d03c3813 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4Http11ChannelInputStreamTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4Http11ChannelInputStreamTests.java @@ -23,7 +23,6 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; import static io.clientcore.http.netty4.TestUtils.assertArraysEqual; @@ -31,7 +30,9 @@ import static io.netty.buffer.Unpooled.wrappedBuffer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * Tests {@link Netty4ChannelInputStream}. @@ -39,17 +40,17 @@ @Timeout(value = 3, unit = TimeUnit.MINUTES) public class Netty4Http11ChannelInputStreamTests { @Test - public void nullEagerContentResultsInEmptyInitialCurrentBuffer() { + public void nullEagerContentResultsInEmptyInitialCurrentBuffer() throws IOException { try (Netty4ChannelInputStream channelInputStream - = new Netty4ChannelInputStream(null, createCloseableChannel(), false)) { + = new Netty4ChannelInputStream(null, createCloseableChannel(), false, null)) { assertEquals(0, channelInputStream.getCurrentBuffer().length); } } @Test - public void emptyEagerContentResultsInEmptyInitialCurrentBuffer() { + public void emptyEagerContentResultsInEmptyInitialCurrentBuffer() throws IOException { try (Netty4ChannelInputStream channelInputStream - = new Netty4ChannelInputStream(new ByteArrayOutputStream(), createCloseableChannel(), false)) { + = new Netty4ChannelInputStream(new ByteArrayOutputStream(), createCloseableChannel(), false, null)) { assertEquals(0, channelInputStream.getCurrentBuffer().length); } } @@ -64,7 +65,7 @@ public void readConsumesCurrentBufferAndHasNoMoreData() throws IOException { // MockChannels aren't active by default, so once the eagerContent is consumed the stream will be done. Netty4ChannelInputStream channelInputStream - = new Netty4ChannelInputStream(eagerContent, new MockChannel(), false); + = new Netty4ChannelInputStream(eagerContent, new MockChannel(), false, null); // Make sure the Netty4ChannelInputStream copied the eager content correctly. assertArraysEqual(expected, channelInputStream.getCurrentBuffer()); @@ -96,7 +97,7 @@ public void readConsumesCurrentBufferAndRequestsMoreData() throws IOException { handler.channelRead(ctx, wrappedBuffer(expected, 16, 16)); handler.channelRead(ctx, LastHttpContent.EMPTY_LAST_CONTENT); handler.channelReadComplete(ctx); - }), false); + }), false, null); int index = 0; byte[] actual = new byte[32]; @@ -118,7 +119,7 @@ public void multipleSmallerSkips() throws IOException { // MockChannels aren't active by default, so once the eagerContent is consumed the stream will be done. try (Netty4ChannelInputStream channelInputStream - = new Netty4ChannelInputStream(eagerContent, createCloseableChannel(), false)) { + = new Netty4ChannelInputStream(eagerContent, createCloseableChannel(), false, null)) { long skipped = channelInputStream.skip(16); assertEquals(16, skipped); @@ -141,7 +142,7 @@ public void largeReadTriggersMultipleChannelReads() throws IOException { ThreadLocalRandom.current().nextBytes(expected); try (Netty4ChannelInputStream channelInputStream - = new Netty4ChannelInputStream(null, createChannelThatReads8Kb(expected), false)) { + = new Netty4ChannelInputStream(null, createChannelThatReads8Kb(expected), false, null)) { byte[] actual = new byte[8192]; int read = channelInputStream.read(actual); @@ -162,7 +163,7 @@ public void largeSkipTriggersMultipleChannelReads() throws IOException { ThreadLocalRandom.current().nextBytes(expected); try (Netty4ChannelInputStream channelInputStream - = new Netty4ChannelInputStream(null, createChannelThatReads8Kb(expected), false)) { + = new Netty4ChannelInputStream(null, createChannelThatReads8Kb(expected), false, null)) { long skipped = channelInputStream.skip(8192); assertEquals(8192, skipped); @@ -172,21 +173,22 @@ public void largeSkipTriggersMultipleChannelReads() throws IOException { } @Test - public void closingStreamClosesChannel() { - AtomicInteger closeCount = new AtomicInteger(); - AtomicInteger disconnectCount = new AtomicInteger(); + public void closingStreamTriggersOnCloseCallback() throws IOException { + AtomicBoolean onCloseCalled = new AtomicBoolean(false); - new Netty4ChannelInputStream(null, - createCloseableChannel(closeCount::incrementAndGet, disconnectCount::incrementAndGet), false).close(); + try (Netty4ChannelInputStream channelInputStream + = new Netty4ChannelInputStream(null, createCloseableChannel(), false, () -> onCloseCalled.set(true))) { + assertNotNull(channelInputStream); + } - assertEquals(1, closeCount.get()); + assertTrue(onCloseCalled.get()); } @ParameterizedTest @MethodSource("errorSupplier") public void streamPropagatesErrorFiredInChannel(Throwable expected) { InputStream inputStream - = new Netty4ChannelInputStream(null, createPartialReadThenErrorChannel(expected), false); + = new Netty4ChannelInputStream(null, createPartialReadThenErrorChannel(expected), false, null); Throwable actual = assertThrows(Throwable.class, () -> inputStream.read(new byte[8192])); diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupHandlerTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupHandlerTests.java new file mode 100644 index 000000000000..afd24f2ab29a --- /dev/null +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4PipelineCleanupHandlerTests.java @@ -0,0 +1,304 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package io.clientcore.http.netty4.implementation; + +import io.clientcore.core.http.client.HttpProtocolVersion; +import io.netty.channel.AbstractChannel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.EventLoop; +import io.netty.util.AttributeKey; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; + +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.HTTP_CODEC; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.HTTP_RESPONSE; +import static io.clientcore.http.netty4.implementation.Netty4HandlerNames.PROGRESS_AND_TIMEOUT; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link Netty4PipelineCleanupHandler}. + */ +public class Netty4PipelineCleanupHandlerTests { + + @Mock + private Netty4ConnectionPool connectionPool; + + private static final Object OBJECT = new Object(); + private TestMockChannel testChannel; + private AtomicReference errorReference; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + testChannel = new TestMockChannel(new MockEventLoop()); + testChannel.attr(AttributeKey.valueOf("channel-lock")).set(new ReentrantLock()); + testChannel.attr(AttributeKey.valueOf("pipeline-owner-token")).set(OBJECT); + testChannel.config.setAutoRead(false); + errorReference = new AtomicReference<>(); + } + + @Test + public void cleanupWhenPooledAndActiveReleasesChannel() { + testChannel.setActive(true); + testChannel.pipeline().addLast(HTTP_CODEC, new MockChannelHandler()); + Netty4PipelineCleanupHandler handler = new Netty4PipelineCleanupHandler(connectionPool, errorReference, OBJECT); + testChannel.pipeline().addLast(handler); + ChannelHandlerContext ctx = testChannel.pipeline().context(handler); + + handler.cleanup(ctx, false); + + verify(connectionPool).release(testChannel); + assertEquals(0, testChannel.getCloseCallCount()); + assertNull(testChannel.pipeline().get(HTTP_CODEC)); + assertFalse(testChannel.config().isAutoRead()); + } + + @Test + public void cleanupWhenForceCloseClosesChannel() { + testChannel.setActive(true); + Netty4PipelineCleanupHandler handler = new Netty4PipelineCleanupHandler(connectionPool, errorReference, OBJECT); + testChannel.pipeline().addLast(handler); + ChannelHandlerContext ctx = testChannel.pipeline().context(handler); + + handler.cleanup(ctx, true); + + assertEquals(1, testChannel.getCloseCallCount()); + verify(connectionPool, never()).release(testChannel); + } + + @Test + public void cleanupWhenNonPooledClosesChannel() { + testChannel.setActive(true); + Netty4PipelineCleanupHandler handler = new Netty4PipelineCleanupHandler(null, errorReference, OBJECT); + testChannel.pipeline().addLast(handler); + ChannelHandlerContext ctx = testChannel.pipeline().context(handler); + + handler.cleanup(ctx, false); + + assertEquals(1, testChannel.getCloseCallCount()); + } + + @Test + public void cleanupWhenChannelInactiveClosesChannel() { + testChannel.setActive(false); + Netty4PipelineCleanupHandler handler = new Netty4PipelineCleanupHandler(connectionPool, errorReference, OBJECT); + testChannel.pipeline().addLast(handler); + ChannelHandlerContext ctx = testChannel.pipeline().context(handler); + + handler.cleanup(ctx, false); + + assertEquals(1, testChannel.getCloseCallCount()); + verify(connectionPool, never()).release(testChannel); + } + + @Test + public void cleanupWhenHttp2PreservesHttpCodec() { + testChannel.setActive(true); + testChannel.attr(Netty4AlpnHandler.HTTP_PROTOCOL_VERSION_KEY).set(HttpProtocolVersion.HTTP_2); + Netty4PipelineCleanupHandler handler = new Netty4PipelineCleanupHandler(connectionPool, errorReference, OBJECT); + populatePipelineWithStandardHandlers(handler); + ChannelHandlerContext ctx = testChannel.pipeline().context(handler); + + handler.cleanup(ctx, false); + + assertNotNull(testChannel.pipeline().get(HTTP_CODEC)); + assertNull(testChannel.pipeline().get(HTTP_RESPONSE)); + verify(connectionPool).release(testChannel); + } + + @Test + public void cleanupIsIdempotent() { + testChannel.setActive(true); + Netty4PipelineCleanupHandler handler = new Netty4PipelineCleanupHandler(connectionPool, errorReference, OBJECT); + testChannel.pipeline().addLast(handler); + ChannelHandlerContext ctx = testChannel.pipeline().context(handler); + + handler.cleanup(ctx, true); + handler.cleanup(ctx, true); + + assertEquals(1, testChannel.getCloseCallCount()); + } + + @Test + public void exceptionCaughtSetsErrorAndClosesChannel() { + testChannel.setActive(true); + Netty4PipelineCleanupHandler handler = new Netty4PipelineCleanupHandler(connectionPool, errorReference, OBJECT); + testChannel.pipeline().addLast(handler); + ChannelHandlerContext ctx = testChannel.pipeline().context(handler); + Throwable testException = new IOException("Test Exception"); + + handler.exceptionCaught(ctx, testException); + + assertEquals(testException, errorReference.get()); + assertEquals(1, testChannel.getCloseCallCount()); + verify(connectionPool, never()).release(testChannel); + } + + @Test + public void exceptionCaughtStillClosesChannel() { + testChannel.setActive(true); + Netty4PipelineCleanupHandler handler + = new Netty4PipelineCleanupHandler(connectionPool, new AtomicReference<>(), OBJECT); + testChannel.pipeline().addLast(handler); + ChannelHandlerContext ctx = testChannel.pipeline().context(handler); + Throwable testException = new IOException("Test Exception"); + + handler.exceptionCaught(ctx, testException); + + assertEquals(1, testChannel.getCloseCallCount()); + verify(connectionPool, never()).release(testChannel); + } + + @Test + public void channelInactiveSchedulesAndExecutesCleanup() { + testChannel.setActive(true); + assertTrue(testChannel.isActive()); + Netty4PipelineCleanupHandler handler = new Netty4PipelineCleanupHandler(connectionPool, errorReference, OBJECT); + testChannel.pipeline().addLast(handler); + + testChannel.close(); + + assertEquals(1, testChannel.getCloseCallCount(), "close() should be called once."); + verify(connectionPool, never()).release(testChannel); + } + + private void populatePipelineWithStandardHandlers(Netty4PipelineCleanupHandler handler) { + testChannel.pipeline().addLast(PROGRESS_AND_TIMEOUT, new MockChannelHandler()); + testChannel.pipeline().addLast(HTTP_RESPONSE, new MockChannelHandler()); + testChannel.pipeline().addLast(HTTP_CODEC, new MockChannelHandler()); + testChannel.pipeline().addLast(handler); + } + + private static class MockChannelHandler extends ChannelHandlerAdapter { + } + + private static class TestMockChannel extends AbstractChannel { + private static final ChannelMetadata METADATA = new ChannelMetadata(false); + private final ChannelConfig config = new DefaultChannelConfig(this); + private final AtomicInteger closeCallCount = new AtomicInteger(0); + private final EventLoop eventLoop; + + private volatile boolean active; + private volatile boolean open = true; + + protected TestMockChannel(EventLoop eventLoop) { + super(null); + this.eventLoop = eventLoop; + } + + @Override + public EventLoop eventLoop() { + return this.eventLoop; + } + + @Override + public ChannelConfig config() { + return this.config; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public boolean isActive() { + return active; + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new AbstractUnsafe() { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + active = true; + promise.setSuccess(); + } + }; + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return loop == this.eventLoop; + } + + @Override + protected SocketAddress localAddress0() { + return null; + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doBind(SocketAddress localAddress) { + } + + @Override + protected void doDisconnect() { + active = false; + } + + @Override + protected void doClose() { + active = false; + open = false; + closeCallCount.incrementAndGet(); + } + + @Override + protected void doBeginRead() { + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) { + } + + public void setActive(boolean isActive) { + this.active = isActive; + } + + public int getCloseCallCount() { + return closeCallCount.get(); + } + } + + private static class MockEventLoop extends DefaultEventLoop { + @Override + public void execute(Runnable task) { + if (task == null) { + throw new NullPointerException("task"); + } + task.run(); + } + } +} diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4ResponseHandlerTests.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4ResponseHandlerTests.java index 8b01df09c2a8..d1808e26490a 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4ResponseHandlerTests.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/implementation/Netty4ResponseHandlerTests.java @@ -16,8 +16,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -66,98 +67,87 @@ public void firstReadIsFullHttpResponse() throws Exception { } @Test - public void incompleteIgnoredResponseBody() { - byte[] ignoredBodyData = new byte[32]; - ThreadLocalRandom.current().nextBytes(ignoredBodyData); + public void incompleteIgnoredResponseBody() throws InterruptedException { + CountDownLatch headersLatch = new CountDownLatch(1); + Netty4ResponseHandler responseHandler = new Netty4ResponseHandler(new HttpRequest().setMethod(HttpMethod.HEAD), + new AtomicReference<>(), new AtomicReference<>(), headersLatch); - HttpRequest request = new HttpRequest().setMethod(HttpMethod.HEAD); - AtomicReference responseReference = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); - - Netty4ResponseHandler responseHandler - = new Netty4ResponseHandler(request, responseReference, new AtomicReference<>(), latch); + CountDownLatch bodyLatch = new CountDownLatch(1); Channel ch = createChannelWithReadHandling((readCount, channel) -> { - if (readCount == 0) { - Netty4ResponseHandler handler = channel.pipeline().get(Netty4ResponseHandler.class); - MockChannelHandlerContext ctx = new MockChannelHandlerContext(channel); - try { - handler.channelRead(ctx, - new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, new DefaultHttpHeaders())); - handler.channelReadComplete(ctx); - } catch (Exception ex) { - ctx.fireExceptionCaught(ex); + try { + if (readCount == 0) { + responseHandler.channelRead(new MockChannelHandlerContext(channel), + new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)); + responseHandler.channelReadComplete(new MockChannelHandlerContext(channel)); + } else { + Netty4EagerConsumeChannelHandler eagerConsumer + = channel.pipeline().get(Netty4EagerConsumeChannelHandler.class); + eagerConsumer.channelRead(new MockChannelHandlerContext(channel), + LastHttpContent.EMPTY_LAST_CONTENT); + eagerConsumer.channelReadComplete(new MockChannelHandlerContext(channel)); } - } else { - Netty4EagerConsumeChannelHandler handler - = channel.pipeline().get(Netty4EagerConsumeChannelHandler.class); - MockChannelHandlerContext ctx = new MockChannelHandlerContext(channel); - handler.channelRead(ctx, Unpooled.wrappedBuffer(ignoredBodyData)); - handler.channelRead(ctx, Unpooled.wrappedBuffer(ignoredBodyData)); - handler.channelRead(ctx, Unpooled.wrappedBuffer(ignoredBodyData)); - handler.channelRead(ctx, Unpooled.wrappedBuffer(ignoredBodyData)); - handler.channelRead(ctx, LastHttpContent.EMPTY_LAST_CONTENT); - handler.channelReadComplete(ctx); + } catch (Exception e) { + channel.pipeline().fireExceptionCaught(e); } }); ch.pipeline().addLast(responseHandler); + ch.read(); + assertTrue(headersLatch.await(10, TimeUnit.SECONDS)); - assertEquals(0, latch.getCount()); + ch.pipeline().addLast(new Netty4EagerConsumeChannelHandler(bodyLatch, ignored -> { + }, false)); - ResponseStateInfo info = responseReference.get(); - assertNotNull(info); + ch.read(); + assertTrue(bodyLatch.await(10, TimeUnit.SECONDS)); } @Test - public void bufferedResponseBodyLargerThanInitialRead() { - byte[] bodyPieces = new byte[32]; - ThreadLocalRandom.current().nextBytes(bodyPieces); - - byte[] expectedBody = new byte[bodyPieces.length * 4]; - System.arraycopy(bodyPieces, 0, expectedBody, 0, bodyPieces.length); - System.arraycopy(bodyPieces, 0, expectedBody, bodyPieces.length, bodyPieces.length); - System.arraycopy(bodyPieces, 0, expectedBody, bodyPieces.length * 2, bodyPieces.length); - System.arraycopy(bodyPieces, 0, expectedBody, bodyPieces.length * 3, bodyPieces.length); - - HttpRequest request = new HttpRequest(); + public void bufferedResponseBodyLargerThanInitialRead() throws InterruptedException { AtomicReference responseReference = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); + CountDownLatch headersLatch = new CountDownLatch(1); Netty4ResponseHandler responseHandler - = new Netty4ResponseHandler(request, responseReference, new AtomicReference<>(), latch); + = new Netty4ResponseHandler(new HttpRequest(), responseReference, new AtomicReference<>(), headersLatch); + + CountDownLatch bodyLatch = new CountDownLatch(1); Channel ch = createChannelWithReadHandling((readCount, channel) -> { - if (readCount == 0) { - Netty4ResponseHandler handler = channel.pipeline().get(Netty4ResponseHandler.class); - MockChannelHandlerContext ctx = new MockChannelHandlerContext(channel); - try { - handler.channelRead(ctx, - new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, new DefaultHttpHeaders())); - handler.channelReadComplete(ctx); - } catch (Exception ex) { - ctx.fireExceptionCaught(ex); + try { + if (readCount == 0) { + responseHandler.channelRead(new MockChannelHandlerContext(channel), + new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)); + responseHandler.channelReadComplete(new MockChannelHandlerContext(channel)); + } else { + Netty4EagerConsumeChannelHandler eagerConsumer + = channel.pipeline().get(Netty4EagerConsumeChannelHandler.class); + eagerConsumer.channelRead(new MockChannelHandlerContext(channel), + LastHttpContent.EMPTY_LAST_CONTENT); + eagerConsumer.channelReadComplete(new MockChannelHandlerContext(channel)); } - } else { - Netty4EagerConsumeChannelHandler handler - = channel.pipeline().get(Netty4EagerConsumeChannelHandler.class); - MockChannelHandlerContext ctx = new MockChannelHandlerContext(channel); - handler.channelRead(ctx, Unpooled.wrappedBuffer(bodyPieces)); - handler.channelRead(ctx, Unpooled.wrappedBuffer(bodyPieces)); - handler.channelRead(ctx, Unpooled.wrappedBuffer(bodyPieces)); - handler.channelRead(ctx, Unpooled.wrappedBuffer(bodyPieces)); - handler.channelRead(ctx, LastHttpContent.EMPTY_LAST_CONTENT); - handler.channelReadComplete(ctx); + } catch (Exception e) { + channel.pipeline().fireExceptionCaught(e); } }); ch.pipeline().addLast(responseHandler); - ch.read(); - - assertEquals(0, latch.getCount()); + ch.read(); + assertTrue(headersLatch.await(10, TimeUnit.SECONDS)); ResponseStateInfo info = responseReference.get(); assertNotNull(info); + + ch.pipeline().addLast(new Netty4EagerConsumeChannelHandler(bodyLatch, buf -> { + try { + buf.readBytes(info.getEagerContent(), buf.readableBytes()); + } catch (IOException ex) { + throw new UncheckedIOException(ex); + } + }, false)); + + ch.read(); + assertTrue(bodyLatch.await(10, TimeUnit.SECONDS)); } } diff --git a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/mocking/MockChannel.java b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/mocking/MockChannel.java index 934fe564a5c9..f2b0e8057b81 100644 --- a/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/mocking/MockChannel.java +++ b/sdk/clientcore/http-netty4/src/test/java/io/clientcore/http/netty4/mocking/MockChannel.java @@ -11,12 +11,14 @@ import io.netty.channel.EventLoop; import io.netty.util.Attribute; import io.netty.util.AttributeKey; +import io.netty.util.DefaultAttributeMap; import java.net.SocketAddress; public class MockChannel extends AbstractChannel { private final ChannelConfig config; private final ChannelMetadata metadata = new ChannelMetadata(false); + private final DefaultAttributeMap attributes = new DefaultAttributeMap(); public MockChannel() { super(null); @@ -25,7 +27,7 @@ public MockChannel() { @Override public Attribute attr(AttributeKey key) { - return null; + return attributes.attr(key); } @Override