diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
index 7efab73726ef8..d3381ef6fb7f1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
@@ -33,7 +33,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
-import org.apache.spark.internal.config.MASTER_REST_SERVER_FILTERS
+import org.apache.spark.internal.config.{MASTER_REST_SERVER_FILTERS, MASTER_REST_SERVER_MAX_THREADS}
 import org.apache.spark.util.Utils
 
 /**
@@ -63,7 +63,8 @@ private[spark] abstract class RestSubmissionServer(
   protected val clearRequestServlet: ClearRequestServlet
   protected val readyzRequestServlet: ReadyzRequestServlet
 
-  private var _server: Option[Server] = None
+  // Visible for testing
+  private[rest] var _server: Option[Server] = None
 
   // A mapping from URL prefixes to servlets that serve them. Exposed for testing.
   protected val baseContext = s"/${RestSubmissionServer.PROTOCOL_VERSION}/submissions"
@@ -91,7 +92,7 @@ private[spark] abstract class RestSubmissionServer(
    * Return a 2-tuple of the started server and the bound port.
    */
   private def doStart(startPort: Int): (Server, Int) = {
-    val threadPool = new QueuedThreadPool
+    val threadPool = new QueuedThreadPool(masterConf.get(MASTER_REST_SERVER_MAX_THREADS))
     threadPool.setDaemon(true)
     val server = new Server(threadPool)
 
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index c58c371da20cf..c2b49d164ae3e 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1987,6 +1987,12 @@ package object config {
     .intConf
     .createWithDefault(6066)
 
+  private[spark] val MASTER_REST_SERVER_MAX_THREADS = ConfigBuilder("spark.master.rest.maxThreads")
+    .doc("Maximum number of threads to use in the Spark Master REST API Server.")
+    .version("4.0.0")
+    .intConf
+    .createWithDefault(200)
+
   private[spark] val MASTER_REST_SERVER_FILTERS = ConfigBuilder("spark.master.rest.filters")
     .doc("Comma separated list of filter class names to apply to the Spark Master REST API.")
     .version("4.0.0")
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index 4a05aab01cb50..075a15063c981 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -25,6 +25,7 @@ import java.util.Base64
 import scala.collection.mutable
 
 import jakarta.servlet.http.HttpServletResponse
+import org.eclipse.jetty.util.thread.ThreadPool.SizedThreadPool
 import org.json4s.JsonAST._
 import org.json4s.jackson.JsonMethods._
 
@@ -33,7 +34,7 @@ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
 import org.apache.spark.deploy.DeployMessages._
 import org.apache.spark.deploy.master.DriverState._
 import org.apache.spark.deploy.master.RecoveryState
-import org.apache.spark.internal.config.MASTER_REST_SERVER_FILTERS
+import org.apache.spark.internal.config.{MASTER_REST_SERVER_FILTERS, MASTER_REST_SERVER_MAX_THREADS}
 import org.apache.spark.rpc._
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils
@@ -545,6 +546,19 @@ class StandaloneRestSubmitSuite extends SparkFunSuite {
     }
   }
 
+  test("SPARK-50381: Support spark.master.rest.maxThreads") {
+    val conf = new SparkConf()
+    val localhost = Utils.localHostName()
+    val securityManager = new SecurityManager(conf)
+    rpcEnv = Some(RpcEnv.create("rest-with-maxThreads", localhost, 0, conf, securityManager))
+    val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get))
+    conf.set(MASTER_REST_SERVER_MAX_THREADS, 2000)
+    server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077"))
+    server.get.start()
+    val pool = server.get._server.get.getThreadPool.asInstanceOf[SizedThreadPool]
+    assert(pool.getMaxThreads === 2000)
+  }
+
   /* --------------------- *
    |  Helper methods       |
    * --------------------- */
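
Usage note (illustrative, not part of the patch): the new spark.master.rest.maxThreads key caps the Jetty QueuedThreadPool that serves the master's REST submission endpoint, defaulting to 200, which matches Jetty's own QueuedThreadPool default. The typed ConfigEntry setter used in the test above is private[spark], so code outside Spark would set the plain string key instead, e.g. --conf spark.master.rest.maxThreads=400 when launching the master. Below is a minimal Scala sketch of that string-key path, assuming only spark-core on the classpath; the value 400 and the object name MaxThreadsSketch are arbitrary examples.

// Sketch only, not part of the patch: sets the new key through SparkConf's
// public string API and reads it back with a fallback that mirrors the
// declared default of 200.
import org.apache.spark.SparkConf

object MaxThreadsSketch {
  def main(args: Array[String]): Unit = {
    // loadDefaults = false keeps system properties out of the demo.
    val conf = new SparkConf(false)
      .set("spark.master.rest.maxThreads", "400")

    // Prints 400; would print the 200 fallback if the key were unset.
    println(conf.getInt("spark.master.rest.maxThreads", 200))
  }
}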