From d7f6dc77b7e16e351e4ed5932250300adf4e090d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 21 Nov 2024 10:17:36 -0800 Subject: [PATCH] [SPARK-50383][CORE] Support Virtual Threads in REST Submission API --- .../deploy/rest/RestSubmissionServer.scala | 9 ++++++++- .../spark/internal/config/package.scala | 7 +++++++ .../rest/StandaloneRestSubmitSuite.scala | 20 ++++++++++++++++++- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index d3381ef6fb7f..877349da18dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.rest import java.util.EnumSet +import java.util.concurrent.{Executors, ExecutorService} import scala.io.Source @@ -33,7 +34,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ -import org.apache.spark.internal.config.{MASTER_REST_SERVER_FILTERS, MASTER_REST_SERVER_MAX_THREADS} +import org.apache.spark.internal.config.{MASTER_REST_SERVER_FILTERS, MASTER_REST_SERVER_MAX_THREADS, MASTER_REST_SERVER_VIRTUAL_THREADS} import org.apache.spark.util.Utils /** @@ -93,6 +94,12 @@ private[spark] abstract class RestSubmissionServer( */ private def doStart(startPort: Int): (Server, Int) = { val threadPool = new QueuedThreadPool(masterConf.get(MASTER_REST_SERVER_MAX_THREADS)) + if (Utils.isJavaVersionAtLeast21 && masterConf.get(MASTER_REST_SERVER_VIRTUAL_THREADS)) { + val newVirtualThreadPerTaskExecutor = + classOf[Executors].getMethod("newVirtualThreadPerTaskExecutor") + val service = newVirtualThreadPerTaskExecutor.invoke(null).asInstanceOf[ExecutorService] + threadPool.setVirtualThreadsExecutor(service) + } threadPool.setDaemon(true) val server = new Server(threadPool) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index c2b49d164ae3..324ef701c426 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2000,6 +2000,13 @@ package object config { .toSequence .createWithDefault(Nil) + private[spark] val MASTER_REST_SERVER_VIRTUAL_THREADS = + ConfigBuilder("spark.master.rest.virtualThread.enabled") + .doc("If true, Spark master tries to use Java 21 virtual thread for REST API.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + private[spark] val MASTER_UI_PORT = ConfigBuilder("spark.master.ui.port") .version("1.1.0") .intConf diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 075a15063c98..a155e4cc3ac9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -25,6 +25,7 @@ import java.util.Base64 import scala.collection.mutable import jakarta.servlet.http.HttpServletResponse +import org.eclipse.jetty.util.thread.QueuedThreadPool import org.eclipse.jetty.util.thread.ThreadPool.SizedThreadPool import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ @@ -34,7 +35,7 @@ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState._ import org.apache.spark.deploy.master.RecoveryState -import org.apache.spark.internal.config.{MASTER_REST_SERVER_FILTERS, MASTER_REST_SERVER_MAX_THREADS} +import org.apache.spark.internal.config.{MASTER_REST_SERVER_FILTERS, MASTER_REST_SERVER_MAX_THREADS, MASTER_REST_SERVER_VIRTUAL_THREADS} import org.apache.spark.rpc._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -559,6 +560,23 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { assert(pool.getMaxThreads === 2000) } + test("SPARK-50383: Support spark.master.rest.virtualThread.enabled") { + val conf = new SparkConf() + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + rpcEnv = Some(RpcEnv.create("rest-with-virtualThreads", localhost, 0, conf, securityManager)) + val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get)) + conf.set(MASTER_REST_SERVER_VIRTUAL_THREADS, true) + server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077")) + server.get.start() + val pool = server.get._server.get.getThreadPool.asInstanceOf[QueuedThreadPool] + if (Utils.isJavaVersionAtLeast21) { + assert(pool.getVirtualThreadsExecutor != null) + } else { + assert(pool.getVirtualThreadsExecutor == null) + } + } + /* --------------------- * | Helper methods | * --------------------- */