From 0f59a6a3a4d2a307131e3d17e65938529007bafc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 27 Mar 2024 09:42:17 +0000 Subject: [PATCH] squash --- .../api/python/PythonWorkerFactory.scala | 29 +++++++++++++------ python/pyspark/tests/test_worker.py | 16 ++++++++++ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 4ae1e3c92311..eb740b72987c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -77,7 +77,7 @@ private[spark] class PythonWorkerFactory( @GuardedBy("self") private var daemonPort: Int = 0 @GuardedBy("self") - private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Long]() + private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, ProcessHandle]() @GuardedBy("self") private val idleWorkers = new mutable.Queue[PythonWorker]() @GuardedBy("self") @@ -95,10 +95,20 @@ private[spark] class PythonWorkerFactory( def create(): (PythonWorker, Option[Long]) = { if (useDaemon) { self.synchronized { - if (idleWorkers.nonEmpty) { + // Pull from idle workers until we find one that is alive, otherwise create a new one. 
+ while (idleWorkers.nonEmpty) { val worker = idleWorkers.dequeue() - worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE) - return (worker, daemonWorkers.get(worker)) + val workerHandle = daemonWorkers(worker) + if (workerHandle.isAlive()) { + try { + worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE) + return (worker, Some(workerHandle.pid())) + } catch { + case c: CancelledKeyException => /* pass */ + } + } + logWarning(s"Worker ${worker} process from idle queue is dead, discarding.") + stopWorker(worker) } } createThroughDaemon() @@ -121,15 +131,16 @@ private[spark] class PythonWorkerFactory( if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } - + val processHandle = ProcessHandle.of(pid).orElseThrow( + () => new IllegalStateException("Python daemon failed to launch worker.") + ) authHelper.authToServer(socketChannel.socket()) socketChannel.configureBlocking(false) val selector = Selector.open() val selectionKey = socketChannel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE) val worker = PythonWorker(socketChannel, selector, selectionKey) - - daemonWorkers.put(worker, pid) + daemonWorkers.put(worker, processHandle) (worker, Some(pid)) } @@ -391,10 +402,10 @@ private[spark] class PythonWorkerFactory( self.synchronized { if (useDaemon) { if (daemon != null) { - daemonWorkers.get(worker).foreach { pid => + daemonWorkers.get(worker).foreach { processHandle => // tell daemon to kill worker by pid val output = new DataOutputStream(daemon.getOutputStream) - output.writeLong(pid) + output.writeLong(processHandle.pid()) output.flush() daemon.getOutputStream.flush() } diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index e51b030f4574..3961997120ed 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -16,6 +16,7 @@ # limitations under the License. 
# import os +import signal import sys import tempfile import threading @@ -256,6 +257,21 @@ def conf(cls): return _conf +class WorkerPoolCrashTest(PySparkTestCase): + def test_worker_crash(self): + # SPARK-47565: Kill a worker that is currently idling + rdd = self.sc.parallelize(range(20), 4) + # first ensure that workers are reused + worker_pids1 = set(rdd.map(lambda x: os.getpid()).collect()) + worker_pids2 = set(rdd.map(lambda x: os.getpid()).collect()) + self.assertEqual(worker_pids1, worker_pids2) + for pid in list(worker_pids1)[1:]: # kill all workers except for one + os.kill(pid, signal.SIGTERM) + # give things a moment to settle + time.sleep(5) + rdd.map(lambda x: os.getpid()).collect() + + if __name__ == "__main__": import unittest from pyspark.tests.test_worker import * # noqa: F401