Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
refactored methods to UninterruptibleLock
  • Loading branch information
vrozov committed May 6, 2025
commit b7e64931e7ff02e3e6b6e09a566d4980e7940db3
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,66 @@ private[spark] class UninterruptibleThread(
this(null, name)
}

class UninterruptibleLock {
def awaitInterrupt(): Boolean = synchronized {
private class UninterruptibleLock {
/**
* Indicates if `this` thread are in the uninterruptible status. If so, interrupting
* "this" will be deferred until `this` enters into the interruptible status.
*/
@GuardedBy("uninterruptibleLock")
private var uninterruptible = false

/**
* Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
*/
@GuardedBy("uninterruptibleLock")
private var shouldInterruptThread = false

/**
* Indicates that we should wait for interrupt() call before proceeding.
*/
@GuardedBy("uninterruptibleLock")
private var awaitInterruptThread = false

/**
* Set [[uninterruptible]] to given value and returns the previous value.
*/
def getAndSetUninterruptible(value: Boolean): Boolean = synchronized {
val uninterruptible = this.uninterruptible
this.uninterruptible = value
uninterruptible
}

def setShouldInterruptThread(value: Boolean): Unit = synchronized {
shouldInterruptThread = value
}

def setAwaitInterruptThread(value: Boolean): Unit = synchronized {
awaitInterruptThread = value
}

/**
* Is call to [[java.lang.Thread.interrupt()]] pending
*/
def isInterruptPending: Boolean = synchronized {
// Clear the interrupted status if it's set.
shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
// wait for super.interrupt() to be called
!shouldInterruptThread && awaitInterruptThread
}

/**
* Set [[uninterruptible]] back to false and call [[java.lang.Thread.interrupt()]] to
* recover interrupt state if necessary
*/
def recoverInterrupt(): Unit = synchronized {
uninterruptible = false
if (shouldInterruptThread) {
shouldInterruptThread = false
// Recover the interrupted status
UninterruptibleThread.super.interrupt()
}
}

/**
* Is it safe to call [[java.lang.Thread.interrupt()]] and interrupt the current thread
* @return true when there is no concurrent [[runUninterruptibly()]] call ([[uninterruptible]]
Expand All @@ -68,25 +120,6 @@ private[spark] class UninterruptibleThread(
/** A monitor to protect "uninterruptible" and "interrupted" */
private val uninterruptibleLock = new UninterruptibleLock

/**
* Indicates if `this` thread are in the uninterruptible status. If so, interrupting
* "this" will be deferred until `this` enters into the interruptible status.
*/
@GuardedBy("uninterruptibleLock")
private var uninterruptible = false

/**
* Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
*/
@GuardedBy("uninterruptibleLock")
private var shouldInterruptThread = false

/**
* Indicates that we should wait for interrupt() call before proceeding.
*/
@GuardedBy("uninterruptibleLock")
private var awaitInterruptThread = false

/**
* Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
* from `f`.
Expand All @@ -99,35 +132,23 @@ private[spark] class UninterruptibleThread(
s"Expected: $this but was ${Thread.currentThread()}")
}

if (uninterruptibleLock.synchronized { uninterruptible }) {
if (uninterruptibleLock.getAndSetUninterruptible(true)) {
// We are already in the uninterruptible status. So just run "f" and return
return f
}

uninterruptibleLock.synchronized {
uninterruptible = true
}

while (uninterruptibleLock.awaitInterrupt()) {
while (uninterruptibleLock.isInterruptPending) {
try {
Thread.sleep(100)
} catch {
case _: InterruptedException =>
uninterruptibleLock.synchronized { shouldInterruptThread = true }
case _: InterruptedException => uninterruptibleLock.setShouldInterruptThread(true)
}
}

try {
f
} finally {
uninterruptibleLock.synchronized {
uninterruptible = false
if (shouldInterruptThread) {
// Recover the interrupted status
super.interrupt()
shouldInterruptThread = false
}
}
uninterruptibleLock.recoverInterrupt()
}
}

Expand All @@ -140,9 +161,7 @@ private[spark] class UninterruptibleThread(
try {
super.interrupt()
} finally {
uninterruptibleLock.synchronized {
awaitInterruptThread = false
}
uninterruptibleLock.setAwaitInterruptThread(false)
}
}
}
Expand Down