19 changes: 17 additions & 2 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -23,13 +23,15 @@ import java.lang.management.ManagementFactory
import java.net.{URI, URL}
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent._
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.control.NonFatal

import com.google.common.util.concurrent.ThreadFactoryBuilder

import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
@@ -84,7 +86,20 @@ private[spark] class Executor(
}

// Start worker thread pool
private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
private val threadPool = {
val threadFactory = new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("Executor task launch worker-%d")
.setThreadFactory(new ThreadFactory {
override def newThread(r: Runnable): Thread =
// Use UninterruptibleThread to run tasks so that we can allow running code without being
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894 and HADOOP-10622,
// will hang forever if certain methods are interrupted.
new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder
Member Author: Most of this code is copied from ThreadUtils; this line is the only difference that matters (a rough sketch of the ThreadUtils helper follows this file's diff).

})
.build()
Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
}
private val executorSource = new ExecutorSource(threadPool, executorId)
// Pool used for threads that supervise task killing / cancellation
private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
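For comparison with the reviewer's note above, here is a rough sketch of what the existing `ThreadUtils.newDaemonCachedThreadPool` helper does. This is reconstructed from memory rather than quoted from the Spark source, so treat the exact shape as approximate; the point is that the new `threadPool` block differs only in supplying a backing `ThreadFactory` that creates `UninterruptibleThread`s instead of plain `Thread`s.

```scala
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

import com.google.common.util.concurrent.ThreadFactoryBuilder

// Approximate shape of ThreadUtils.newDaemonCachedThreadPool: a daemon ThreadFactory with a
// numbered name format, handed to Executors.newCachedThreadPool.
def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = {
  val threadFactory: ThreadFactory = new ThreadFactoryBuilder()
    .setDaemon(true)
    .setNameFormat(prefix + "-%d")
    .build()
  Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
}
```

Without a `setThreadFactory(...)` call, `ThreadFactoryBuilder` falls back to a default factory that produces plain `Thread` instances, which is why the Executor cannot simply reuse this helper.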
core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
@@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy
*
* Note: "runUninterruptibly" should be called only in `this` thread.
*/
private[spark] class UninterruptibleThread(name: String) extends Thread(name) {
private[spark] class UninterruptibleThread(
target: Runnable,
name: String) extends Thread(target, name) {

def this(name: String) {
this(null, name)
}

/** A monitor to protect "uninterruptible" and "interrupted" */
private val uninterruptibleLock = new Object
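To make the new `Runnable`-based constructor and the interrupt-deferral behaviour concrete, here is a minimal usage sketch. It is illustrative only: the package placement, object name, thread name, and printed message are invented here, and the deferral behaviour (an interrupt arriving inside `runUninterruptibly` is held back and re-delivered once the block exits) is assumed from the class's documented contract rather than copied from a test.

```scala
// Assumes placement under the org.apache.spark package tree, since
// UninterruptibleThread is private[spark].
package org.apache.spark

import org.apache.spark.util.UninterruptibleThread

// Illustrative sketch: an UninterruptibleThread built from a Runnable, as the Executor's
// ThreadFactory now does, protecting a section of work from Thread.interrupt().
object UninterruptibleThreadDemo {
  def main(args: Array[String]): Unit = {
    val worker = new UninterruptibleThread(new Runnable {
      override def run(): Unit = {
        val self = Thread.currentThread.asInstanceOf[UninterruptibleThread]
        self.runUninterruptibly {
          // A blocking call (e.g. a Kafka poll) would go here; an interrupt delivered while
          // this block runs should be deferred instead of breaking the call off.
          Thread.sleep(500)
        }
        // After the block, a deferred interrupt is expected to be re-delivered to this thread.
        println(s"interrupt pending after the block: ${Thread.currentThread.isInterrupted}")
      }
    }, "demo-uninterruptible-thread")

    worker.start()
    Thread.sleep(100)
    worker.interrupt() // arrives while the worker is inside runUninterruptibly
    worker.join()
  }
}
```

Note that `runUninterruptibly` is invoked from the worker thread itself, which satisfies the class's requirement that it be called only in `this` thread.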
13 changes: 13 additions & 0 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -44,6 +44,7 @@ import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.UninterruptibleThread

class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {

@@ -158,6 +159,18 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
assert(failReason.isInstanceOf[FetchFailed])
}

test("Executor's worker threads should be UninterruptibleThread") {
val conf = new SparkConf()
.setMaster("local")
.setAppName("executor thread test")
.set("spark.ui.enabled", "false")
sc = new SparkContext(conf)
val executorThread = sc.parallelize(Seq(1), 1).map { _ =>
Thread.currentThread.getClass.getName
}.collect().head
assert(executorThread === classOf[UninterruptibleThread].getName)
}

test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
// when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
// may be a false positive. And we should call the uncaught exception handler.
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
@@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.kafka010.KafkaSource._
import org.apache.spark.util.UninterruptibleThread


/**
@@ -62,11 +63,20 @@ private[kafka010] case class CachedKafkaConsumer private(

case class AvailableOffsetRange(earliest: Long, latest: Long)

private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match {
case ut: UninterruptibleThread =>
ut.runUninterruptibly(body)
case _ =>
logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " +
"It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894")
body
}

/**
* Return the available offset range of the current partition. It's a pair of the earliest offset
* and the latest offset.
*/
def getAvailableOffsetRange(): AvailableOffsetRange = {
def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible {
consumer.seekToBeginning(Set(topicPartition).asJava)
val earliestOffset = consumer.position(topicPartition)
consumer.seekToEnd(Set(topicPartition).asJava)
@@ -92,7 +102,8 @@ private[kafka010] case class CachedKafkaConsumer private(
offset: Long,
untilOffset: Long,
pollTimeoutMs: Long,
failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = {
failOnDataLoss: Boolean):
ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible {
require(offset < untilOffset,
s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]")
logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset")
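One more note on the fallback branch above: a caller that is not already on an `UninterruptibleThread` only gets a warning and runs unprotected. On executors this is handled by the worker-thread change earlier in this diff; other code inside Spark that drives a `CachedKafkaConsumer` from a thread of its own could opt in along the lines of the hedged sketch below. `ProtectedRunner` and `runProtected` are hypothetical and not part of the PR, error propagation is omitted, and because `UninterruptibleThread` is `private[spark]` this only applies to code in the `org.apache.spark` package tree.

```scala
package org.apache.spark

import org.apache.spark.util.UninterruptibleThread

object ProtectedRunner {
  // Hypothetical helper: run `body` on a dedicated UninterruptibleThread so that
  // runUninterruptiblyIfPossible inside CachedKafkaConsumer takes its protected branch.
  def runProtected[T](name: String)(body: => T): T = {
    var result: Option[T] = None
    val t = new UninterruptibleThread(name) {
      override def run(): Unit = {
        result = Some(body) // exceptions thrown by body are not propagated in this sketch
      }
    }
    t.start()
    t.join() // join() makes the worker's write to `result` visible here
    result.get
  }
}
```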