From 09b1235da280152ea3c3ecaa1a5cbe0386b7bf51 Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Mon, 16 Nov 2020 16:05:27 +0100
Subject: [PATCH 01/14] [SPARK-33143][PYTHON] Add configurable timeout to
 python server and client

---
 .../apache/spark/api/python/PythonRunner.scala    |  2 ++
 .../apache/spark/internal/config/Python.scala     |  5 +++++
 .../apache/spark/security/SocketAuthServer.scala  | 16 ++++++++++++----
 python/pyspark/java_gateway.py                    |  2 +-
 4 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index d7a09b599794..7da1f50a0b3e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -85,6 +85,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   private val conf = SparkEnv.get.conf
 
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val gatewayConnectTimeout = conf.get(PYTHON_GATEWAY_CONNECT_TIMEOUT)
   private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
 
   // All the Python functions should have the same exec, version and envvars.
@@ -140,6 +141,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     if (workerMemoryMb.isDefined) {
       envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString)
     }
+    envVars.put("SPARK_GATEWAY_CONNECT_TIMEOUT", gatewayConnectTimeout.toString)
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
     // Whether is the worker released into idle pool or closed. When any codes try to release or
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index 188d88431964..b3542e87c634 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -50,4 +50,9 @@ private[spark] object Python {
     .version("2.4.0")
     .bytesConf(ByteUnit.MiB)
     .createOptional
+
+  val PYTHON_GATEWAY_CONNECT_TIMEOUT = ConfigBuilder("spark.python.gateway.connectTimeout")
+    .version("3.1.0")
+    .timeConf(TimeUnit.SECONDS)
+    .createWithDefaultString("15s")
 }
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index 548fd1b07ddc..cc3803783811 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -25,6 +25,8 @@ import scala.concurrent.duration.Duration
 import scala.util.Try
 
 import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Python.PYTHON_GATEWAY_CONNECT_TIMEOUT
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.util.{ThreadUtils, Utils}
 
@@ -34,11 +36,11 @@ import org.apache.spark.util.{ThreadUtils, Utils}
  * handling one batch of data, with authentication and error handling.
  *
  * The socket server can only accept one connection, or close if no connection
- * in 15 seconds.
+ * in a configurable number of seconds (default 15).
  */
 private[spark] abstract class SocketAuthServer[T](
     authHelper: SocketAuthHelper,
-    threadName: String) {
+    threadName: String) extends Logging {
 
   def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
   def this(threadName: String) = this(SparkEnv.get, threadName)
@@ -46,17 +48,23 @@ private[spark] abstract class SocketAuthServer[T](
   private val promise = Promise[T]()
 
   private def startServer(): (Int, String) = {
+    logTrace("Creating listening socket")
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
+    // Close the socket if no connection in configured seconds
+    val timeout = SparkEnv.get.conf.get(PYTHON_GATEWAY_CONNECT_TIMEOUT).toInt * 1000
+    logTrace(s"Setting timeout to $timeout ms")
+    serverSocket.setSoTimeout(timeout)
 
     new Thread(threadName) {
       setDaemon(true)
       override def run(): Unit = {
         var sock: Socket = null
         try {
+          logTrace(s"Waiting for connection on port ${serverSocket.getLocalPort}")
           sock = serverSocket.accept()
+          logTrace(s"Connection accepted from port ${sock.getLocalPort}")
           authHelper.authClient(sock)
+          logTrace("Client authenticated")
           promise.complete(Try(handleConnection(sock)))
         } finally {
           JavaUtils.closeQuietly(serverSocket)
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index eafa5d90f9ff..ff3b56a9c9ad 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -201,7 +201,7 @@ def local_connect_and_auth(port, auth_secret):
         af, socktype, proto, _, sa = res
         try:
             sock = socket.socket(af, socktype, proto)
-            sock.settimeout(15)
+            sock.settimeout(int(os.environ.get("SPARK_GATEWAY_CONNECT_TIMEOUT", 15))
             sock.connect(sa)
             sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
             _do_server_auth(sockfile, auth_secret)

From 17d357b92038f32d9ad36e735f8e732b8b324bba Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Mon, 16 Nov 2020 16:32:59 +0100
Subject: [PATCH 02/14] compile fix

---
 python/pyspark/java_gateway.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index ff3b56a9c9ad..1aba006a374a 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -201,7 +201,7 @@ def local_connect_and_auth(port, auth_secret):
         af, socktype, proto, _, sa = res
         try:
             sock = socket.socket(af, socktype, proto)
-            sock.settimeout(int(os.environ.get("SPARK_GATEWAY_CONNECT_TIMEOUT", 15))
+            sock.settimeout(int(os.environ.get("SPARK_GATEWAY_CONNECT_TIMEOUT", 15)))
             sock.connect(sa)
             sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
             _do_server_auth(sockfile, auth_secret)

From 424be6470ad352707bd1b9641ea870d9e3ca922c Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Tue, 17 Nov 2020 11:08:37 +0100
Subject: [PATCH 03/14] PythonRDDSuite fix

---
 .../spark/api/python/PythonRDDSuite.scala | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index aae5fb002e1e..e6c28a26614f 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -30,8 +30,9 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.mockito.Mockito.mock
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
 import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.rdd.{HadoopRDD, RDD}
 import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
@@ -39,6 +40,8 @@ import org.apache.spark.util.Utils
 
 class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
 
+  private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)
+
   var tempDir: File = _
 
   override def beforeAll(): Unit = {
@@ -76,12 +79,22 @@ class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("python server error handling") {
-    val authHelper = new SocketAuthHelper(new SparkConf())
-    val errorServer = new ExceptionPythonServer(authHelper)
-    val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
-    authHelper.authToServer(client)
-    val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
-    assert(ex.getCause().getMessage().contains("exception within handleConnection"))
+    val savedSparkEnv = SparkEnv.get
+    try {
+      val conf = new SparkConf()
+      val env = mock(classOf[SparkEnv])
+      doReturn(conf).when(env).conf
+      SparkEnv.set(env)
+
+      val authHelper = new SocketAuthHelper(conf)
+      val errorServer = new ExceptionPythonServer(authHelper)
+      val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
+      authHelper.authToServer(client)
+      val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
+      assert(ex.getCause().getMessage().contains("exception within handleConnection"))
+    } finally {
+      SparkEnv.set(savedSparkEnv)
+    }
   }
 
   class ExceptionPythonServer(authHelper: SocketAuthHelper)

From d9feed8a56087b28c18e7df0bd4978c4c5390312 Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Tue, 17 Nov 2020 15:18:04 +0100
Subject: [PATCH 04/14] Some cleanup

---
 .../scala/org/apache/spark/security/SocketAuthServer.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index cc3803783811..a99667cfedc8 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -51,9 +51,9 @@ private[spark] abstract class SocketAuthServer[T](
     logTrace("Creating listening socket")
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
     // Close the socket if no connection in configured seconds
-    val timeout = SparkEnv.get.conf.get(PYTHON_GATEWAY_CONNECT_TIMEOUT).toInt * 1000
-    logTrace(s"Setting timeout to $timeout ms")
-    serverSocket.setSoTimeout(timeout)
+    val timeout = SparkEnv.get.conf.get(PYTHON_GATEWAY_CONNECT_TIMEOUT).toInt
+    logTrace(s"Setting timeout to $timeout sec")
+    serverSocket.setSoTimeout(timeout * 1000)
 
     new Thread(threadName) {
       setDaemon(true)

From 6e8e194bdc907ba47a0dcb85de1fc42cdd713863 Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Tue, 17 Nov 2020 15:30:57 +0100
Subject: [PATCH 05/14] Additional trace

---
 .../main/scala/org/apache/spark/security/SocketAuthServer.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index a99667cfedc8..ee85839f0eca 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -67,6 +67,7 @@ private[spark] abstract class SocketAuthServer[T](
           logTrace("Client authenticated")
           promise.complete(Try(handleConnection(sock)))
         } finally {
+          logTrace("Closing server")
           JavaUtils.closeQuietly(serverSocket)
           JavaUtils.closeQuietly(sock)
         }

From f504af38ac0e855c2a97e8341ea53ea971cfe5b5 Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Wed, 18 Nov 2020 09:14:47 +0100
Subject: [PATCH 06/14] Partial review fix

---
 .../main/scala/org/apache/spark/api/python/PythonRunner.scala | 4 ++--
 .../main/scala/org/apache/spark/internal/config/Python.scala  | 3 ++-
 .../scala/org/apache/spark/security/SocketAuthServer.scala    | 4 ++--
 python/pyspark/java_gateway.py                                | 2 +-
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 7da1f50a0b3e..3a217a2753f7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -85,7 +85,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   private val conf = SparkEnv.get.conf
 
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
-  protected val gatewayConnectTimeout = conf.get(PYTHON_GATEWAY_CONNECT_TIMEOUT)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
   private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
 
   // All the Python functions should have the same exec, version and envvars.
@@ -141,7 +141,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     if (workerMemoryMb.isDefined) {
       envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString)
     }
-    envVars.put("SPARK_GATEWAY_CONNECT_TIMEOUT", gatewayConnectTimeout.toString)
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
     // Whether is the worker released into idle pool or closed. When any codes try to release or
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index b3542e87c634..348a33e129d6 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -51,7 +51,8 @@ private[spark] object Python {
     .bytesConf(ByteUnit.MiB)
     .createOptional
 
-  val PYTHON_GATEWAY_CONNECT_TIMEOUT = ConfigBuilder("spark.python.gateway.connectTimeout")
+  val PYTHON_AUTH_SOCKET_TIMEOUT = ConfigBuilder("spark.python.authenticate.socketTimeout")
+    .internal()
     .version("3.1.0")
     .timeConf(TimeUnit.SECONDS)
     .createWithDefaultString("15s")
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index ee85839f0eca..99b772308542 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -26,7 +26,7 @@ import scala.util.Try
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.Python.PYTHON_GATEWAY_CONNECT_TIMEOUT
+import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.util.{ThreadUtils, Utils}
 
@@ -51,7 +51,7 @@ private[spark] abstract class SocketAuthServer[T](
     logTrace("Creating listening socket")
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
     // Close the socket if no connection in configured seconds
-    val timeout = SparkEnv.get.conf.get(PYTHON_GATEWAY_CONNECT_TIMEOUT).toInt
+    val timeout = SparkEnv.get.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt
     logTrace(s"Setting timeout to $timeout sec")
     serverSocket.setSoTimeout(timeout * 1000)
 
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 1aba006a374a..fe2e326dff8b 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -201,7 +201,7 @@ def local_connect_and_auth(port, auth_secret):
         af, socktype, proto, _, sa = res
         try:
             sock = socket.socket(af, socktype, proto)
-            sock.settimeout(int(os.environ.get("SPARK_GATEWAY_CONNECT_TIMEOUT", 15)))
+            sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15)))
             sock.connect(sa)
             sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
             _do_server_auth(sockfile, auth_secret)

From 15955810e50d63fdd1b747bbfb0fdfe7859f6833 Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Thu, 19 Nov 2020 15:27:28 +0100
Subject: [PATCH 07/14] Driver side coverage

---
 .../main/scala/org/apache/spark/api/python/PythonUtils.scala | 4 ++++
 .../scala/org/apache/spark/security/SocketAuthServer.scala   | 2 +-
 python/pyspark/context.py                                    | 1 +
 3 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index 527d0d6d3a48..33849f6fcb65 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -85,4 +85,8 @@ private[spark] object PythonUtils {
   def getBroadcastThreshold(sc: JavaSparkContext): Long = {
     sc.conf.get(org.apache.spark.internal.config.BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD)
   }
+
+  def getPythonAuthSocketTimeout(sc: JavaSparkContext): Long = {
+    sc.conf.get(org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT)
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index 99b772308542..6f5f03f6c945 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -62,7 +62,7 @@ private[spark] abstract class SocketAuthServer[T](
         try {
           logTrace(s"Waiting for connection on port ${serverSocket.getLocalPort}")
           sock = serverSocket.accept()
-          logTrace(s"Connection accepted from port ${sock.getLocalPort}")
+          logTrace(s"Connection accepted from address ${sock.getRemoteSocketAddress}")
           authHelper.authClient(sock)
           logTrace("Client authenticated")
           promise.complete(Try(handleConnection(sock)))
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 9c9e3f4b3c88..34d677c09fb7 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -222,6 +222,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         # data via a socket.
         # scala's mangled names w/ $ in them require special treatment.
         self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
+        os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc))
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]

From a67acd7c23e3ed21e59aaa2c412f271219bd3660 Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Fri, 20 Nov 2020 10:24:04 +0100
Subject: [PATCH 08/14] Style fix

---
 python/pyspark/context.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 34d677c09fb7..1bd5961e0525 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -222,7 +222,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         # data via a socket.
         # scala's mangled names w/ $ in them require special treatment.
         self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
-        os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc))
+        os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = \
+            str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc))
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]

From 0f9e587ac277f6da329da26d2db23c7d2ebbf768 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Mon, 23 Nov 2020 09:51:39 +0900
Subject: [PATCH 09/14] Update core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala

---
 .../test/scala/org/apache/spark/api/python/PythonRDDSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index e6c28a26614f..63eb02d1a606 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -30,7 +30,6 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
-import org.mockito.Mockito.mock
 
 import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
 import org.apache.spark.api.java.JavaSparkContext

From cd2d5951ba4126a11fc99419d7d1e4ff28b31122 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Mon, 23 Nov 2020 09:51:51 +0900
Subject: [PATCH 10/14] Update core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala

---
 .../test/scala/org/apache/spark/api/python/PythonRDDSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 63eb02d1a606..b171cd84e43b 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.rdd.{HadoopRDD, RDD}
 import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}

From d9f0a1bfb08582acaa372f2d1a609c34a115401e Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Mon, 23 Nov 2020 09:52:00 +0900
Subject: [PATCH 11/14] Update core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala

---
 .../main/scala/org/apache/spark/security/SocketAuthServer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index 6f5f03f6c945..934dabee9499 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -50,7 +50,7 @@ private[spark] abstract class SocketAuthServer[T](
   private def startServer(): (Int, String) = {
     logTrace("Creating listening socket")
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
-    // Close the socket if no connection in configured seconds
+    // Close the socket if no connection in the configured seconds
     val timeout = SparkEnv.get.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt
     logTrace(s"Setting timeout to $timeout sec")
     serverSocket.setSoTimeout(timeout * 1000)

From 2913fb1eadb8a43cd7f2c3e429929740ff2e9841 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Mon, 23 Nov 2020 09:52:09 +0900
Subject: [PATCH 12/14] Update core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala

---
 .../spark/api/python/PythonRDDSuite.scala | 22 ++++++----------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index b171cd84e43b..b8f3be0387d9 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -78,22 +78,12 @@ class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("python server error handling") {
-    val savedSparkEnv = SparkEnv.get
-    try {
-      val conf = new SparkConf()
-      val env = mock(classOf[SparkEnv])
-      doReturn(conf).when(env).conf
-      SparkEnv.set(env)
-
-      val authHelper = new SocketAuthHelper(conf)
-      val errorServer = new ExceptionPythonServer(authHelper)
-      val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
-      authHelper.authToServer(client)
-      val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
-      assert(ex.getCause().getMessage().contains("exception within handleConnection"))
-    } finally {
-      SparkEnv.set(savedSparkEnv)
-    }
+    val authHelper = new SocketAuthHelper(new SparkConf())
+    val errorServer = new ExceptionPythonServer(authHelper)
+    val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
+    authHelper.authToServer(client)
+    val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
+    assert(ex.getCause().getMessage().contains("exception within handleConnection"))
   }
 
   class ExceptionPythonServer(authHelper: SocketAuthHelper)

From ef137b6664f5a456001af9e70f75d0a5b3e8aad3 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Mon, 23 Nov 2020 09:52:19 +0900
Subject: [PATCH 13/14] Update core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala

---
 .../test/scala/org/apache/spark/api/python/PythonRDDSuite.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index b8f3be0387d9..aae5fb002e1e 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -39,8 +39,6 @@ import org.apache.spark.util.Utils
 
 class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
 
-  private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)
-
   var tempDir: File = _
 
   override def beforeAll(): Unit = {

From 363f3bb07ccd34cf3d70303a5512595cc59b603b Mon Sep 17 00:00:00 2001
From: HyukjinKwon
Date: Mon, 23 Nov 2020 11:31:52 +0900
Subject: [PATCH 14/14] Fix test in PythonRDDSuite

---
 .../main/scala/org/apache/spark/security/SocketAuthHelper.scala | 2 +-
 .../main/scala/org/apache/spark/security/SocketAuthServer.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index dbcb37690533..f800553c5388 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -34,7 +34,7 @@ import org.apache.spark.util.Utils
  *
 * There's no secrecy, so this relies on the sockets being either local or somehow encrypted.
  */
-private[spark] class SocketAuthHelper(conf: SparkConf) {
+private[spark] class SocketAuthHelper(val conf: SparkConf) {
 
   val secret = Utils.createSecret(conf)
 
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index 934dabee9499..35990b5a5928 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -51,7 +51,7 @@ private[spark] abstract class SocketAuthServer[T](
     logTrace("Creating listening socket")
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
     // Close the socket if no connection in the configured seconds
-    val timeout = SparkEnv.get.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt
+    val timeout = authHelper.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt
     logTrace(s"Setting timeout to $timeout sec")
    serverSocket.setSoTimeout(timeout * 1000)
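For reference, the end-to-end pattern this series converges on can be sketched outside of Spark. The snippet below is a minimal, hypothetical standalone example (not PySpark code): both the accepting and the connecting side read the same SPARK_AUTH_SOCKET_TIMEOUT environment variable (which the driver exports from spark.python.authenticate.socketTimeout) and fall back to the previously hard-coded 15 seconds.

```python
import os
import socket

# Timeout in seconds. The driver exports SPARK_AUTH_SOCKET_TIMEOUT from
# spark.python.authenticate.socketTimeout; 15 is the hard-coded default
# that this patch series makes configurable.
timeout_sec = int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15))

# Accepting side: a one-connection listener on an ephemeral loopback port,
# analogous to SocketAuthServer.startServer() and its setSoTimeout call.
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 0))    # ephemeral port on loopback
server.listen(1)                 # backlog of 1, like the ServerSocket above
server.settimeout(timeout_sec)   # accept() now gives up after timeout_sec
port = server.getsockname()[1]

# Connecting side: the same timeout before connect(), analogous to
# local_connect_and_auth() in python/pyspark/java_gateway.py.
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.settimeout(timeout_sec)
try:
    client.connect(("127.0.0.1", port))
    conn, addr = server.accept()
    print("Connection accepted from %s" % (addr,))
    conn.close()
except socket.timeout:
    print("No connection in %d seconds, closing" % timeout_sec)
finally:
    client.close()
    server.close()
```

Note the unit mismatch the two sides have to bridge: ServerSocket.setSoTimeout takes milliseconds while Python's socket.settimeout takes seconds, which is why the Scala code in the patches carries the `* 1000` and the Python side passes the configured seconds through unchanged.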