Commit 0f59a6a

Author: Ubuntu
Commit message: squash

1 parent 3e83432 commit 0f59a6a

File tree: 2 files changed, +36 −9 lines


core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 20 additions & 9 deletions
@@ -77,7 +77,7 @@ private[spark] class PythonWorkerFactory(
   @GuardedBy("self")
   private var daemonPort: Int = 0
   @GuardedBy("self")
-  private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Long]()
+  private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, ProcessHandle]()
   @GuardedBy("self")
   private val idleWorkers = new mutable.Queue[PythonWorker]()
   @GuardedBy("self")
@@ -95,10 +95,20 @@ private[spark] class PythonWorkerFactory(
   def create(): (PythonWorker, Option[Long]) = {
     if (useDaemon) {
       self.synchronized {
-        if (idleWorkers.nonEmpty) {
+        // Pull from idle workers until we find one that is alive, otherwise create a new one.
+        while (idleWorkers.nonEmpty) {
           val worker = idleWorkers.dequeue()
-          worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE)
-          return (worker, daemonWorkers.get(worker))
+          val workerHandle = daemonWorkers(worker)
+          if (workerHandle.isAlive()) {
+            try {
+              worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE)
+              return (worker, Some(workerHandle.pid()))
+            } catch {
+              case c: CancelledKeyException => /* pass */
+            }
+          }
+          logWarning(s"Worker ${worker} process from idle queue is dead, discarding.")
+          stopWorker(worker)
         }
       }
       createThroughDaemon()
@@ -121,15 +131,16 @@ private[spark] class PythonWorkerFactory(
       if (pid < 0) {
         throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
       }
-
+      val processHandle = ProcessHandle.of(pid).orElseThrow(
+        () => new IllegalStateException("Python daemon failed to launch worker.")
+      )
       authHelper.authToServer(socketChannel.socket())
       socketChannel.configureBlocking(false)
       val selector = Selector.open()
       val selectionKey = socketChannel.register(selector,
         SelectionKey.OP_READ | SelectionKey.OP_WRITE)
       val worker = PythonWorker(socketChannel, selector, selectionKey)
-
-      daemonWorkers.put(worker, pid)
+      daemonWorkers.put(worker, processHandle)
       (worker, Some(pid))
     }

@@ -391,10 +402,10 @@ private[spark] class PythonWorkerFactory(
     self.synchronized {
       if (useDaemon) {
         if (daemon != null) {
-          daemonWorkers.get(worker).foreach { pid =>
+          daemonWorkers.get(worker).foreach { processHandle =>
            // tell daemon to kill worker by pid
            val output = new DataOutputStream(daemon.getOutputStream)
-           output.writeLong(pid)
+           output.writeLong(processHandle.pid())
            output.flush()
            daemon.getOutputStream.flush()
          }
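
Note on the change above: the factory now keys daemonWorkers on java.lang.ProcessHandle (JDK 9+) instead of a raw pid, so create() can probe whether an idle worker's process is still alive before reusing it. A minimal standalone sketch of that JDK API, under the assumption it helps to see it outside the diff (the object name and messages are illustrative, not Spark code):

object ProcessHandleSketch {
  def main(args: Array[String]): Unit = {
    // Look up this JVM's own pid. ProcessHandle.of returns Optional.empty
    // when no process with the given pid exists, which is why the patch
    // pairs it with orElseThrow once the daemon reports a worker pid.
    val pid: Long = ProcessHandle.current().pid()
    val handle: ProcessHandle = ProcessHandle.of(pid).orElseThrow(
      () => new IllegalStateException(s"no process with pid $pid"))

    // isAlive() is the same probe the new while-loop in create() performs
    // on each idle worker before handing it out.
    println(s"pid=${handle.pid()} alive=${handle.isAlive()}")
  }
}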

python/pyspark/tests/test_worker.py

Lines changed: 16 additions & 0 deletions
@@ -16,6 +16,7 @@
 # limitations under the License.
 #
 import os
+import signal
 import sys
 import tempfile
 import threading
@@ -256,6 +257,21 @@ def conf(cls):
         return _conf
 
 
+class WorkerPoolCrashTest(PySparkTestCase):
+    def test_worker_crash(self):
+        # SPARK-47565: Kill a worker that is currently idling
+        rdd = self.sc.parallelize(range(20), 4)
+        # first ensure that workers are reused
+        worker_pids1 = set(rdd.map(lambda x: os.getpid()).collect())
+        worker_pids2 = set(rdd.map(lambda x: os.getpid()).collect())
+        self.assertEqual(worker_pids1, worker_pids2)
+        for pid in list(worker_pids1)[1:]:  # kill all workers except for one
+            os.kill(pid, signal.SIGTERM)
+        # give things a moment to settle
+        time.sleep(5)
+        rdd.map(lambda x: os.getpid()).collect()
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_worker import *  # noqa: F401
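
The test first runs two jobs and compares pid sets to confirm that daemon workers are reused, then SIGTERMs every idle worker but one and runs a third job: with the fix, create() detects the dead processes via ProcessHandle.isAlive(), discards them, and the job completes instead of handing out sockets to dead workers. Assuming the standard PySpark dev setup, the suite can likely be run on its own with something like python/run-tests --testnames 'pyspark.tests.test_worker WorkerPoolCrashTest'.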
