Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 95 additions & 30 deletions core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,90 @@ private[spark] class UninterruptibleThread(
this(null, name)
}

/** A monitor to protect "uninterruptible" and "interrupted" */
private val uninterruptibleLock = new Object
private class UninterruptibleLock {
/**
* Indicates if `this` thread are in the uninterruptible status. If so, interrupting
* "this" will be deferred until `this` enters into the interruptible status.
*/
@GuardedBy("uninterruptibleLock")
private var uninterruptible = false

/**
* Indicates if `this` thread are in the uninterruptible status. If so, interrupting
* "this" will be deferred until `this` enters into the interruptible status.
*/
@GuardedBy("uninterruptibleLock")
private var uninterruptible = false
/**
* Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
*/
@GuardedBy("uninterruptibleLock")
private var shouldInterruptThread = false

/**
* Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
*/
@GuardedBy("uninterruptibleLock")
private var shouldInterruptThread = false
/**
* Indicates that we should wait for interrupt() call before proceeding.
*/
@GuardedBy("uninterruptibleLock")
private var awaitInterruptThread = false

/**
* Set [[uninterruptible]] to given value and returns the previous value.
*/
def getAndSetUninterruptible(value: Boolean): Boolean = synchronized {
val uninterruptible = this.uninterruptible
this.uninterruptible = value
uninterruptible
}

def setShouldInterruptThread(value: Boolean): Unit = synchronized {
shouldInterruptThread = value
}

def setAwaitInterruptThread(value: Boolean): Unit = synchronized {
awaitInterruptThread = value
}

/**
* Is call to [[java.lang.Thread.interrupt()]] pending
*/
def isInterruptPending: Boolean = synchronized {
// Clear the interrupted status if it's set.
shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
// wait for super.interrupt() to be called
!shouldInterruptThread && awaitInterruptThread
}

/**
* Set [[uninterruptible]] back to false and call [[java.lang.Thread.interrupt()]] to
* recover interrupt state if necessary
*/
def recoverInterrupt(): Unit = synchronized {
uninterruptible = false
if (shouldInterruptThread) {
shouldInterruptThread = false
// Recover the interrupted status
UninterruptibleThread.super.interrupt()
}
}

/**
* Is it safe to call [[java.lang.Thread.interrupt()]] and interrupt the current thread
* @return true when there is no concurrent [[runUninterruptibly()]] call ([[uninterruptible]]
* is true) and no concurrent [[interrupt()]] call, otherwise false
*/
def isInterruptible: Boolean = synchronized {
shouldInterruptThread = uninterruptible
// as we are releasing uninterruptibleLock before calling super.interrupt() there is a
// possibility that runUninterruptibly() would be called after lock is released but before
// super.interrupt() is called. In this case to prevent runUninterruptibly() from being
// interrupted, we use awaitInterruptThread flag. We need to set it only if
// runUninterruptibly() is not yet set uninterruptible to true (!shouldInterruptThread) and
// there is no other threads that called interrupt (awaitInterruptThread is already true)
if (!shouldInterruptThread && !awaitInterruptThread) {
awaitInterruptThread = true
true
} else {
false
}
}
}

/** A monitor to protect "uninterruptible" and "interrupted" */
private val uninterruptibleLock = new UninterruptibleLock

/**
* Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
Expand All @@ -63,27 +132,23 @@ private[spark] class UninterruptibleThread(
s"Expected: $this but was ${Thread.currentThread()}")
}

if (uninterruptibleLock.synchronized { uninterruptible }) {
if (uninterruptibleLock.getAndSetUninterruptible(true)) {
// We are already in the uninterruptible status. So just run "f" and return
return f
}

uninterruptibleLock.synchronized {
// Clear the interrupted status if it's set.
shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
uninterruptible = true
while (uninterruptibleLock.isInterruptPending) {
try {
Thread.sleep(100)
} catch {
case _: InterruptedException => uninterruptibleLock.setShouldInterruptThread(true)
}
}

try {
f
} finally {
uninterruptibleLock.synchronized {
uninterruptible = false
if (shouldInterruptThread) {
// Recover the interrupted status
super.interrupt()
shouldInterruptThread = false
}
}
uninterruptibleLock.recoverInterrupt()
}
}

Expand All @@ -92,11 +157,11 @@ private[spark] class UninterruptibleThread(
* interrupted until it enters into the interruptible status.
*/
override def interrupt(): Unit = {
uninterruptibleLock.synchronized {
if (uninterruptible) {
shouldInterruptThread = true
} else {
if (uninterruptibleLock.isInterruptible) {
try {
super.interrupt()
} finally {
uninterruptibleLock.setAwaitInterruptThread(false)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.util

import java.nio.channels.spi.AbstractInterruptibleChannel
import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.util.Random
Expand Down Expand Up @@ -115,6 +116,46 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
assert(interruptStatusBeforeExit)
}

test("no runUninterruptibly") {
@volatile var hasInterruptedException = false
val latch = new CountDownLatch(1)
val t = new UninterruptibleThread("test") {
override def run(): Unit = {
latch.countDown()
hasInterruptedException = sleep(1)
}
}
t.start()
latch.await(10, TimeUnit.SECONDS)
t.interrupt()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks that Java 8 behaves differently when interrupt() is called on not started thread.

t.join()
assert(hasInterruptedException === true)
}

test("SPARK-51821 uninterruptibleLock deadlock") {
val latch = new CountDownLatch(1)
val task = new UninterruptibleThread("task thread") {
override def run(): Unit = {
val channel = new AbstractInterruptibleChannel() {
override def implCloseChannel(): Unit = {
begin()
latch.countDown()
try {
Thread.sleep(Long.MaxValue)
} catch {
case _: InterruptedException => Thread.currentThread().interrupt()
}
}
}
channel.close()
}
}
task.start()
assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
task.interrupt()
task.join()
}

test("stress test") {
@volatile var hasInterruptedException = false
val t = new UninterruptibleThread("test") {
Expand Down Expand Up @@ -148,9 +189,20 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
}
}
t.start()
for (i <- 0 until 400) {
Thread.sleep(Random.nextInt(10))
t.interrupt()
val threads = new Array[Thread](10)
for (j <- 0 until 10) {
threads(j) = new Thread() {
override def run(): Unit = {
for (i <- 0 until 400) {
Thread.sleep(Random.nextInt(10))
t.interrupt()
}
}
}
threads(j).start()
}
for (j <- 0 until 10) {
threads(j).join()
}
t.join()
assert(hasInterruptedException === false)
Expand Down