diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index d7a09b599794..3a217a2753f7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -85,6 +85,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private val conf = SparkEnv.get.conf protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) // All the Python functions should have the same exec, version and envvars. @@ -140,6 +141,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( if (workerMemoryMb.isDefined) { envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString) } + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. When any codes try to release or diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 527d0d6d3a48..33849f6fcb65 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -85,4 +85,8 @@ private[spark] object PythonUtils { def getBroadcastThreshold(sc: JavaSparkContext): Long = { sc.conf.get(org.apache.spark.internal.config.BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD) } + + def getPythonAuthSocketTimeout(sc: JavaSparkContext): Long = { + sc.conf.get(org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala index 188d88431964..348a33e129d6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala @@ -50,4 +50,10 @@ private[spark] object Python { .version("2.4.0") .bytesConf(ByteUnit.MiB) .createOptional + + val PYTHON_AUTH_SOCKET_TIMEOUT = ConfigBuilder("spark.python.authenticate.socketTimeout") + .internal() + .version("3.1.0") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("15s") } diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala index dbcb37690533..f800553c5388 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils * * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. */ -private[spark] class SocketAuthHelper(conf: SparkConf) { +private[spark] class SocketAuthHelper(val conf: SparkConf) { val secret = Utils.createSecret(conf) diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala index 548fd1b07ddc..35990b5a5928 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala @@ -25,6 +25,8 @@ import scala.concurrent.duration.Duration import scala.util.Try import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.{ThreadUtils, Utils} @@ -34,11 +36,11 @@ import org.apache.spark.util.{ThreadUtils, Utils} * handling one batch of data, with authentication and error handling. * * The socket server can only accept one connection, or close if no connection - * in 15 seconds. + * in configurable amount of seconds (default 15). */ private[spark] abstract class SocketAuthServer[T]( authHelper: SocketAuthHelper, - threadName: String) { + threadName: String) extends Logging { def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName) def this(threadName: String) = this(SparkEnv.get, threadName) @@ -46,19 +48,26 @@ private[spark] abstract class SocketAuthServer[T]( private val promise = Promise[T]() private def startServer(): (Int, String) = { + logTrace("Creating listening socket") val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) - // Close the socket if no connection in 15 seconds - serverSocket.setSoTimeout(15000) + // Close the socket if no connection in the configured seconds + val timeout = authHelper.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt + logTrace(s"Setting timeout to $timeout sec") + serverSocket.setSoTimeout(timeout * 1000) new Thread(threadName) { setDaemon(true) override def run(): Unit = { var sock: Socket = null try { + logTrace(s"Waiting for connection on port ${serverSocket.getLocalPort}") sock = serverSocket.accept() + logTrace(s"Connection accepted from address ${sock.getRemoteSocketAddress}") authHelper.authClient(sock) + logTrace("Client authenticated") promise.complete(Try(handleConnection(sock))) } finally { + logTrace("Closing server") JavaUtils.closeQuietly(serverSocket) JavaUtils.closeQuietly(sock) } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 9c9e3f4b3c88..1bd5961e0525 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -222,6 +222,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # data via a socket. # scala's mangled names w/ $ in them require special treatment. self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc) + os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = \ + str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index eafa5d90f9ff..fe2e326dff8b 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -201,7 +201,7 @@ def local_connect_and_auth(port, auth_secret): af, socktype, proto, _, sa = res try: sock = socket.socket(af, socktype, proto) - sock.settimeout(15) + sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15))) sock.connect(sa) sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))) _do_server_auth(sockfile, auth_secret)