@@ -77,7 +77,7 @@ private[spark] class PythonWorkerFactory(
7777 @ GuardedBy (" self" )
7878 private var daemonPort : Int = 0
7979 @ GuardedBy (" self" )
80- private val daemonWorkers = new mutable.WeakHashMap [PythonWorker , Long ]()
80+ private val daemonWorkers = new mutable.WeakHashMap [PythonWorker , ProcessHandle ]()
8181 @ GuardedBy (" self" )
8282 private val idleWorkers = new mutable.Queue [PythonWorker ]()
8383 @ GuardedBy (" self" )
@@ -95,10 +95,20 @@ private[spark] class PythonWorkerFactory(
9595 def create (): (PythonWorker , Option [Long ]) = {
9696 if (useDaemon) {
9797 self.synchronized {
98- if (idleWorkers.nonEmpty) {
98+ // Pull from idle workers until we one that is alive, otherwise create a new one.
99+ while (idleWorkers.nonEmpty) {
99100 val worker = idleWorkers.dequeue()
100- worker.selectionKey.interestOps(SelectionKey .OP_READ | SelectionKey .OP_WRITE )
101- return (worker, daemonWorkers.get(worker))
101+ val workerHandle = daemonWorkers(worker)
102+ if (workerHandle.isAlive()) {
103+ try {
104+ worker.selectionKey.interestOps(SelectionKey .OP_READ | SelectionKey .OP_WRITE )
105+ return (worker, Some (workerHandle.pid()))
106+ } catch {
107+ case c : CancelledKeyException => /* pass */
108+ }
109+ }
110+ logWarning(s " Worker ${worker} process from idle queue is dead, discarding. " )
111+ stopWorker(worker)
102112 }
103113 }
104114 createThroughDaemon()
@@ -121,15 +131,16 @@ private[spark] class PythonWorkerFactory(
121131 if (pid < 0 ) {
122132 throw new IllegalStateException (" Python daemon failed to launch worker with code " + pid)
123133 }
124-
134+ val processHandle = ProcessHandle .of(pid).orElseThrow(
135+ () => new IllegalStateException (" Python daemon failed to launch worker." )
136+ )
125137 authHelper.authToServer(socketChannel.socket())
126138 socketChannel.configureBlocking(false )
127139 val selector = Selector .open()
128140 val selectionKey = socketChannel.register(selector,
129141 SelectionKey .OP_READ | SelectionKey .OP_WRITE )
130142 val worker = PythonWorker (socketChannel, selector, selectionKey)
131-
132- daemonWorkers.put(worker, pid)
143+ daemonWorkers.put(worker, processHandle)
133144 (worker, Some (pid))
134145 }
135146
@@ -391,10 +402,10 @@ private[spark] class PythonWorkerFactory(
391402 self.synchronized {
392403 if (useDaemon) {
393404 if (daemon != null ) {
394- daemonWorkers.get(worker).foreach { pid =>
405+ daemonWorkers.get(worker).foreach { processHandle =>
395406 // tell daemon to kill worker by pid
396407 val output = new DataOutputStream (daemon.getOutputStream)
397- output.writeLong(pid)
408+ output.writeLong(processHandle. pid() )
398409 output.flush()
399410 daemon.getOutputStream.flush()
400411 }
0 commit comments