Changes from 3 commits
@@ -38,7 +38,7 @@ private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient private val
  /**
   * Ensure the RDD is fully cached so the partitions can be recovered later.
   */
-  protected override def doCheckpoint(): CheckpointRDD[T] = {
+  protected[spark] override def doCheckpoint(): CheckpointRDD[T] = {
    val level = rdd.getStorageLevel

    // Assume storage level uses disk; otherwise memory eviction may cause data loss
core/src/main/scala/org/apache/spark/rdd/RDD.scala (2 changes: 1 addition & 1 deletion)
@@ -1710,7 +1710,7 @@ abstract class RDD[T: ClassTag](
  }

  // Avoid handling doCheckpoint multiple times to prevent excessive recursion
-  @transient private var doCheckpointCalled = false
+  @transient private[spark] var doCheckpointCalled = false

  /**
   * Performs the checkpointing of this RDD by saving this. It is called after a job using this RDD
@@ -43,10 +43,10 @@ private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private
  import CheckpointState._

  // The checkpoint state of the associated RDD.
-  protected var cpState = Initialized
+  protected[spark] var cpState = Initialized

  // The RDD that contains our checkpointed data
-  private var cpRDD: Option[CheckpointRDD[T]] = None
+  private[spark] var cpRDD: Option[CheckpointRDD[T]] = None

  // TODO: are we sure we need to use a global lock in the following methods?

@@ -88,7 +88,7 @@ private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private
   * Subclasses should override this method to define custom checkpointing behavior.
   * @return the checkpoint RDD created in the process.
   */
-  protected def doCheckpoint(): CheckpointRDD[T]
+  protected[spark] def doCheckpoint(): CheckpointRDD[T]

  /**
   * Return the RDD that contains our checkpointed data.
@@ -54,7 +54,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
  /**
   * Materialize this RDD and write its content to a reliable DFS.
   * This is called immediately after the first action invoked on this RDD has completed.
   */
-  protected override def doCheckpoint(): CheckpointRDD[T] = {
+  protected[spark] override def doCheckpoint(): CheckpointRDD[T] = {
    val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)

    // Optionally clean our checkpoint files if the reference is out of scope
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala (27 changes: 18 additions & 9 deletions)
@@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.storage._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -1016,15 +1016,24 @@ class DAGScheduler(
    // might modify state of objects referenced in their closures. This is necessary in Hadoop
    // where the JobConf/Configuration object is not thread-safe.
    var taskBinary: Broadcast[Array[Byte]] = null
+    var partitions: Array[Partition] = null
    try {
      // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
      // For ResultTask, serialize and broadcast (rdd, func).
-      val taskBinaryBytes: Array[Byte] = stage match {
-        case stage: ShuffleMapStage =>
-          JavaUtils.bufferToArray(
-            closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
-        case stage: ResultStage =>
-          JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
-      }
+      var taskBinaryBytes: Array[Byte] = null
+      // taskBinaryBytes and partitions are both affected by the checkpoint status. We need
+      // this synchronization in case another concurrent job is checkpointing this RDD, so we get a
+      // consistent view of both variables.
+      RDDCheckpointData.synchronized {
+        taskBinaryBytes = stage match {
+          case stage: ShuffleMapStage =>
+            JavaUtils.bufferToArray(
+              closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
+          case stage: ResultStage =>
+            JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
+        }
+
+        partitions = stage.rdd.partitions
+      }

      taskBinary = sc.broadcast(taskBinaryBytes)
@@ -1049,7 +1058,7 @@
          stage.pendingPartitions.clear()
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
-            val part = stage.rdd.partitions(id)
+            val part = partitions(id)
            stage.pendingPartitions += id
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
              taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
@@ -1059,7 +1068,7 @@
        case stage: ResultStage =>
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
-            val part = stage.rdd.partitions(p)
+            val part = partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptNumber,
              taskBinary, part, locs, id, properties, serializedTaskMetrics,
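Why the synchronized block above matters: once a checkpoint completes, rdd.partitions no longer returns the RDD's original partitions but those of the internal CheckpointRDD, so a task binary serialized before the checkpoint finished can be paired with a partition of a type it cannot cast. Below is a minimal, self-contained sketch of that race using hypothetical stand-in types (CheckpointRaceSketch, RddLike, CustomPartition, CheckpointPartition); it is not the Spark scheduler code; only the interleaving is the same.

object CheckpointRaceSketch {
  // Hypothetical stand-ins: an RDD-defined partition type and the type exposed after checkpointing.
  trait Partition { def index: Int }
  case class CustomPartition(index: Int) extends Partition
  case class CheckpointPartition(index: Int) extends Partition

  // An RDD-like holder whose partitions change type once "checkpointing" completes,
  // mirroring how rdd.partitions starts returning the CheckpointRDD's partitions.
  class RddLike {
    @volatile private var checkpointed = false
    def markCheckpointed(): Unit = { checkpointed = true }
    def partitions: Array[Partition] =
      if (checkpointed) Array(CheckpointPartition(0)) else Array(CustomPartition(0))
    // The task body serialized before the checkpoint expects the RDD's own partition type.
    def compute(p: Partition): Int = p.asInstanceOf[CustomPartition].index
  }

  def main(args: Array[String]): Unit = {
    val rdd = new RddLike
    val partAtSerializationTime = rdd.partitions(0) // read together with the task binary
    rdd.markCheckpointed()                          // a concurrent job finishes checkpointing here
    val partReadAfterCheckpoint = rdd.partitions(0) // read after the checkpoint completed

    // Mixing the two views is what produces the ClassCastException described in SPARK-23053.
    try rdd.compute(partReadAfterCheckpoint)
    catch { case e: ClassCastException => println(s"inconsistent view: $e") }

    // Reading both under one lock, as the synchronized block above does, keeps the views
    // consistent; the task then computes the partition it was serialized with.
    println(rdd.compute(partAtSerializationTime)) // prints 0
  }
}

Taking the partitions snapshot inside the same RDDCheckpointData.synchronized block as the serialization is what removes the window between the two reads.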
@@ -17,20 +17,25 @@

package org.apache.spark.scheduler

+import java.io.File
+import java.nio.ByteBuffer
import java.util.Properties
+import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}

import scala.annotation.meta.param
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.language.reflectiveCalls
import scala.util.control.NonFatal

+import org.mockito.Mockito._
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
-import org.apache.spark.rdd.RDD
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.rdd.{CheckpointState, RDD, RDDCheckpointData}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -96,6 +101,26 @@ class MyRDD(
  override def toString: String = "DAGSchedulerSuiteRDD " + id
}

+/** A Partition that simply wraps another Partition. */
+class WrappedPartition(val partition: Partition) extends Partition {
+  def index: Int = partition.index
+}
+
+/**
+ * An RDD whose partitions are all of the custom type WrappedPartition.
+ * The compute method casts its split to WrappedPartition; this cast is what
+ * the test below relies on.
+ */
+class WrappedRDD(parent: RDD[Int]) extends RDD[Int](parent) {
+  protected def getPartitions: Array[Partition] = {
+    parent.partitions.map(p => new WrappedPartition(p))
+  }
+
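+  // The cast in compute below is the crux of this suite: it succeeds only for the
+  // WrappedPartitions produced by getPartitions above, and throws ClassCastException
+  // if it is handed a partition obtained from the RDD after checkpointing replaced them.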
+  def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+    parent.compute(split.asInstanceOf[WrappedPartition].partition, context)
[Member] I think this line is the key point of WrappedPartition and WrappedRDD; please add comments explaining your intention.

[Contributor Author] Thanks for the comment; I will work on this.
+  }
+}
+
class DAGSchedulerSuiteDummyException extends Exception

class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
@@ -2399,6 +2424,121 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
}
}

+  /**
+   * This test simulates two concurrent jobs using the same RDD, which is marked for
+   * checkpointing:
+   * job one has already finished its Spark job and starts the doCheckpoint process;
+   * job two is then submitted, and submitMissingTasks is called.
+   * In submitMissingTasks, if task serialization happens before doCheckpoint is done, while
+   * the partitions are read from stage.rdd.partitions after doCheckpoint is done, we may get
+   * a ClassCastException when the task executes, because some RDDs cast the incoming split
+   * to their own Partition type.
+   *
+   * This test case demonstrates that task serialization and the partitions lookup in
+   * submitMissingTasks must observe the same RDD checkpoint state.
+   */
+  test("SPARK-23053: avoid ClassCastException in concurrent execution with checkpoint") {
[Contributor] Hi @ivoson -- I'm really sorry, but I only just realized that this "test" is really just a repro, and it passes both before and after the actual code changes, since you've replicated the internal logic we're fixing. As such, I don't think it's actually useful as a test case -- perhaps it should be added to the JIRA as a repro.

I appreciate the work that went into writing this, as it helped make the issue clear to me. I am not sure if there is a good way to test this. If we can't come up with anything, we should just commit your actual fix, but give me a day or two to think about it.

[Contributor Author] @squito thanks for the reply. I understand; technically it may not be a unit test case, it just simulates the scenario that triggers the exception. I also wonder if there is a good way to test this.

[Contributor] Hi @ivoson -- I haven't come up with a better way to test this, so I think for now you should:
(1) change the PR to only include the changes to the DAGScheduler (and undo the protected[spark] changes elsewhere);
(2) put this repro on the JIRA, as it is pretty good for showing what is going on.
If we come up with a way to test it, we can always do that later on. Thanks, and sorry for the back and forth.

[Contributor Author] Hi @squito, it's fine. The PR and JIRA have been updated. Thanks for your patience and review.
+    // Set the checkpoint directory.
+    val checkpointDir = Utils.createTempDir()
+    sc.setCheckpointDir(checkpointDir.toString)
+
+    // Semaphores to control the process sequence for the two threads below.
+    val doCheckpointStarted = new Semaphore(0)
+    val taskBinaryBytesFinished = new Semaphore(0)
+    val checkpointStateUpdated = new Semaphore(0)
+
+    val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4))
+    rdd.checkpoint()
+
+    val checkpointRunnable = new Runnable {
+      override def run() = {
+        // Simulate what RDD.doCheckpoint() does here.
+        rdd.doCheckpointCalled = true
+        val checkpointData = rdd.checkpointData.get
+        RDDCheckpointData.synchronized {
+          if (checkpointData.cpState == CheckpointState.Initialized) {
+            checkpointData.cpState = CheckpointState.CheckpointingInProgress
+          }
+        }
+
+        val newRDD = checkpointData.doCheckpoint()
+
+        // Release doCheckpointStarted after the job triggered by the checkpoint has finished,
+        // so that taskBinary serialization can start.
+        doCheckpointStarted.release()
+        // Wait until taskBinary serialization has finished in the submitMissingTasks thread.
+        taskBinaryBytesFinished.acquire()
+
+        // Update our state and truncate the RDD lineage.
+        RDDCheckpointData.synchronized {
+          checkpointData.cpRDD = Some(newRDD)
+          checkpointData.cpState = CheckpointState.Checkpointed
+          rdd.markCheckpointed()
+        }
+        checkpointStateUpdated.release()
+      }
+    }
+
+    val submitMissingTasksRunnable = new Runnable {
+      override def run() = {
+        // Simulate the process of submitMissingTasks.
+        // Wait until the doCheckpoint job has finished, but the checkpoint state has not
+        // been updated yet.
+        doCheckpointStarted.acquire()
+
+        val ser = SparkEnv.get.closureSerializer.newInstance()
+
+        // Simulate task serialization in submitMissingTasks.
+        // The task is serialized while the RDD checkpoint has not yet finished.
+        val cleanedFunc = sc.clean(Utils.getIteratorSize _)
+        val func = (ctx: TaskContext, it: Iterator[Int]) => cleanedFunc(it)
+        val taskBinaryBytes = JavaUtils.bufferToArray(
+          ser.serialize((rdd, func): AnyRef))
+        // Because the partitions lookup is inside the synchronized block in the fixed code,
+        // the partition is computed here.
+        val correctPart = rdd.partitions(0)
+
+        // Release taskBinaryBytesFinished so that the checkpoint state can be changed to
+        // Checkpointed in the checkpoint thread.
+        taskBinaryBytesFinished.release()
+        // Wait until the checkpoint state has changed to Checkpointed in the checkpoint thread.
+        checkpointStateUpdated.acquire()
+
+        // Now we're done simulating the interleaving that might happen within the scheduler;
+        // we'll check that the final state is OK by simulating a couple of steps that
+        // normally happen on the executor.
+        // This partition is computed after the RDD checkpoint has already finished.
[Contributor] I'd add a comment above this:

"Now we're done simulating the interleaving that might happen within the scheduler -- we'll check to make sure the final state is OK by simulating a couple of steps that normally happen on the executor."

[Contributor Author] Thanks for the advice, it is really helpful for understanding; I will update this.
+        val errPart = rdd.partitions(0)
+
+        // The task binary will be deserialized when the task runs on the executor.
+        val (taskRdd, taskFunc) = ser.deserialize[(RDD[Int], (TaskContext, Iterator[Int]) => Unit)](
+          ByteBuffer.wrap(taskBinaryBytes), Thread.currentThread.getContextClassLoader)
+
+        val taskContext = mock(classOf[TaskContext])
+        doNothing().when(taskContext).killTaskIfInterrupted()
+
+        // Make sure our test case is set up correctly -- we expect a ClassCastException here
+        // if we use rdd.partitions taken after checkpointing was done, while our task binary
+        // bytes are from before it finished.
+        intercept[ClassCastException] {
[Member] I think this is not a "test"; it is just a reproduction of the problem you want to fix. We should prove that the code added in DAGScheduler.scala fixes the problem and that, with the original code base, a ClassCastException is raised.

[Contributor Author] It is a reproduction case; I will fix this.
+          // Triggered when the task runs on the executor.
+          taskRdd.iterator(errPart, taskContext)
+        }
+
+        // Executes successfully with correctPart.
+        taskRdd.iterator(correctPart, taskContext)
+      }
+    }
+
+    try {
+      new Thread(checkpointRunnable).start()
+      val submitMissingTasksThread = new Thread(submitMissingTasksRunnable)
+      submitMissingTasksThread.start()
+      submitMissingTasksThread.join()
+    } finally {
+      Utils.deleteRecursively(checkpointDir)
+    }
+  }

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
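For reference, the user-facing shape of the scenario this test simulates (two jobs running concurrently over one RDD that is marked for checkpointing) looks roughly like the sketch below. The local-mode settings, the checkpoint path, and the object name are illustrative assumptions and are not part of the PR.

import org.apache.spark.{SparkConf, SparkContext}

object ConcurrentCheckpointScenario {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[4]").setAppName("concurrent-checkpoint-scenario"))
    sc.setCheckpointDir("/tmp/spark-checkpoint-scenario") // hypothetical directory
    val rdd = sc.makeRDD(1 to 1000, 4).map(_ * 2)
    rdd.checkpoint()

    // Two threads submit jobs on the same RDD. The first action to finish triggers
    // doCheckpoint(); the other job may be inside submitMissingTasks at that moment,
    // which is exactly the window the DAGScheduler change synchronizes. Without the fix,
    // the job that is still being submitted can fail with a ClassCastException.
    val threads = (1 to 2).map(_ => new Thread(new Runnable {
      override def run(): Unit = rdd.count()
    }))
    threads.foreach(_.start())
    threads.foreach(_.join())
    sc.stop()
  }
}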