@@ -85,6 +85,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](

private val conf = SparkEnv.get.conf
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
protected val gatewayConnectTimeout = conf.get(PYTHON_GATEWAY_CONNECT_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)

// All the Python functions should have the same exec, version and envvars.
@@ -140,6 +141,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (workerMemoryMb.isDefined) {
envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString)
}
envVars.put("SPARK_GATEWAY_CONNECT_TIMEOUT", gatewayConnectTimeout.toString)
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
// Whether is the worker released into idle pool or closed. When any codes try to release or
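The runner hands the timeout to the Python side purely through the worker process's environment (`SPARK_GATEWAY_CONNECT_TIMEOUT` above), which the worker later reads via `os.environ`. A rough standalone sketch of that JVM-to-child-process handoff (not the actual worker-launch code; assumes a `python3` binary on PATH):

```scala
import scala.sys.process._

object EnvHandoffSketch {
  def main(args: Array[String]): Unit = {
    // Launch a child Python process with the variable set on its environment,
    // analogous to how the runner's envVars map becomes visible to the worker.
    val cmd = Seq("python3", "-c",
      "import os; print(os.environ.get('SPARK_GATEWAY_CONNECT_TIMEOUT', '15'))")
    val exit = Process(cmd, None, "SPARK_GATEWAY_CONNECT_TIMEOUT" -> "30").!
    println(s"child exited with $exit") // the child prints 30, not the default 15
  }
}
```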
@@ -50,4 +50,9 @@ private[spark] object Python {
.version("2.4.0")
.bytesConf(ByteUnit.MiB)
.createOptional

val PYTHON_GATEWAY_CONNECT_TIMEOUT = ConfigBuilder("spark.python.gateway.connectTimeout")
.version("3.1.0")
.timeConf(TimeUnit.SECONDS)
.createWithDefaultString("15s")
}
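For reference, the new entry is set like any other Spark configuration; a minimal usage sketch, assuming the key and default defined in this patch (`spark.python.gateway.connectTimeout`, 15s):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Raise the gateway connect timeout to 60 seconds for environments where the
// Python worker is slow to connect back to the JVM.
val conf = new SparkConf()
  .setMaster("local[*]")
  .setAppName("gateway-timeout-example")
  .set("spark.python.gateway.connectTimeout", "60s")
val sc = new SparkContext(conf)
```

With spark-submit the equivalent would be `--conf spark.python.gateway.connectTimeout=60s`.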
@@ -25,6 +25,8 @@ import scala.concurrent.duration.Duration
import scala.util.Try

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Python.PYTHON_GATEWAY_CONNECT_TIMEOUT
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.{ThreadUtils, Utils}

@@ -34,29 +36,35 @@ import org.apache.spark.util.{ThreadUtils, Utils}
* handling one batch of data, with authentication and error handling.
*
* The socket server can only accept one connection, or close if no connection
* in 15 seconds.
* in configurable amount of seconds (default 15).
*/
private[spark] abstract class SocketAuthServer[T](
authHelper: SocketAuthHelper,
threadName: String) {
threadName: String) extends Logging {

def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
def this(threadName: String) = this(SparkEnv.get, threadName)

private val promise = Promise[T]()

private def startServer(): (Int, String) = {
logTrace("Creating listening socket")
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
// Close the socket if no connection in configured seconds
Review comment (Member): in configured -> in the configured?

val timeout = SparkEnv.get.conf.get(PYTHON_GATEWAY_CONNECT_TIMEOUT).toInt
logTrace(s"Setting timeout to $timeout sec")
serverSocket.setSoTimeout(timeout * 1000)

new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
var sock: Socket = null
try {
logTrace(s"Waiting for connection on port ${serverSocket.getLocalPort}")
sock = serverSocket.accept()
logTrace(s"Connection accepted from port ${sock.getLocalPort}")
authHelper.authClient(sock)
logTrace("Client authenticated")
promise.complete(Try(handleConnection(sock)))
} finally {
JavaUtils.closeQuietly(serverSocket)
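As background for the `setSoTimeout` call above: a `ServerSocket` with SO_TIMEOUT set throws `SocketTimeoutException` from `accept()` once the window elapses with no client, which is what shuts the server down when no Python process connects in time. A self-contained sketch of that behaviour (plain JDK APIs, not Spark code):

```scala
import java.net.{InetAddress, ServerSocket, SocketTimeoutException}

object AcceptTimeoutSketch {
  def main(args: Array[String]): Unit = {
    val server = new ServerSocket(0, 1, InetAddress.getLoopbackAddress)
    server.setSoTimeout(2000) // 2 seconds, analogous to timeout * 1000 above
    try {
      val sock = server.accept() // blocks; nothing connects in this sketch
      sock.close()
    } catch {
      case _: SocketTimeoutException =>
        println("no connection within the timeout, giving up")
    } finally {
      server.close()
    }
  }
}
```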
@@ -30,15 +30,18 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.mockito.Mockito.mock

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.rdd.{HadoopRDD, RDD}
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
import org.apache.spark.util.Utils

class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {

private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)

var tempDir: File = _

override def beforeAll(): Unit = {
@@ -76,12 +79,22 @@ class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
}

test("python server error handling") {
val authHelper = new SocketAuthHelper(new SparkConf())
val errorServer = new ExceptionPythonServer(authHelper)
val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
authHelper.authToServer(client)
val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
assert(ex.getCause().getMessage().contains("exception within handleConnection"))
val savedSparkEnv = SparkEnv.get
try {
val conf = new SparkConf()
val env = mock(classOf[SparkEnv])
doReturn(conf).when(env).conf
SparkEnv.set(env)
Review thread:

HyukjinKwon (Member), Nov 18, 2020:
@gaborgsomogyi, sorry for an ignorant question. How does this test the current patch? BTW, I personally think it's fine only with manual tests if it's difficult to write a test.

gaborgsomogyi (Contributor, Author):
It's not testing the newly added code in any way. The test blew up with an NPE because SparkEnv.get was introduced in the server-side code.

gaborgsomogyi (Contributor, Author):
I've considered adding an automated test for this, but measuring time and making assertions based on that is always scary to me. It could be the next flaky test, which could upset people.

Member:
@gaborgsomogyi In that case, shall we make a separate JIRA and PR for that?
> It's not testing the newly added code in any way.

Member:
Does this exist in branch-3.0/2.4?
> The test blew up with an NPE because SparkEnv.get was introduced in the server-side code.

val authHelper = new SocketAuthHelper(conf)
val errorServer = new ExceptionPythonServer(authHelper)
val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
authHelper.authToServer(client)
val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
assert(ex.getCause().getMessage().contains("exception within handleConnection"))
} finally {
SparkEnv.set(savedSparkEnv)
}
}

class ExceptionPythonServer(authHelper: SocketAuthHelper)
python/pyspark/java_gateway.py (2 changes: 1 addition & 1 deletion)
@@ -201,7 +201,7 @@ def local_connect_and_auth(port, auth_secret):
af, socktype, proto, _, sa = res
try:
sock = socket.socket(af, socktype, proto)
sock.settimeout(15)
sock.settimeout(int(os.environ.get("SPARK_GATEWAY_CONNECT_TIMEOUT", 15)))
sock.connect(sa)
sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
_do_server_auth(sockfile, auth_secret)