core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
@@ -38,7 +38,7 @@ private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient private val
/**
* Ensure the RDD is fully cached so the partitions can be recovered later.
*/
-  protected override def doCheckpoint(): CheckpointRDD[T] = {
+  protected[spark] override def doCheckpoint(): CheckpointRDD[T] = {
val level = rdd.getStorageLevel

// Assume storage level uses disk; otherwise memory eviction may cause data loss
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1710,7 +1710,7 @@ abstract class RDD[T: ClassTag](
}

// Avoid handling doCheckpoint multiple times to prevent excessive recursion
-  @transient private var doCheckpointCalled = false
+  @transient private[spark] var doCheckpointCalled = false

/**
* Performs the checkpointing of this RDD by saving this. It is called after a job using this RDD
core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -43,10 +43,10 @@ private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private
import CheckpointState._

// The checkpoint state of the associated RDD.
-  protected var cpState = Initialized
+  protected[spark] var cpState = Initialized

// The RDD that contains our checkpointed data
-  private var cpRDD: Option[CheckpointRDD[T]] = None
+  private[spark] var cpRDD: Option[CheckpointRDD[T]] = None

// TODO: are we sure we need to use a global lock in the following methods?

@@ -88,7 +88,7 @@ private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private
* Subclasses should override this method to define custom checkpointing behavior.
* @return the checkpoint RDD created in the process.
*/
-  protected def doCheckpoint(): CheckpointRDD[T]
+  protected[spark] def doCheckpoint(): CheckpointRDD[T]

/**
* Return the RDD that contains our checkpointed data.
core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
@@ -54,7 +54,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
* Materialize this RDD and write its content to a reliable DFS.
* This is called immediately after the first action invoked on this RDD has completed.
*/
-  protected override def doCheckpoint(): CheckpointRDD[T] = {
+  protected[spark] override def doCheckpoint(): CheckpointRDD[T] = {
val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)

// Optionally clean our checkpoint files if the reference is out of scope
26 changes: 17 additions & 9 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.storage._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -1016,15 +1016,23 @@ class DAGScheduler(
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
+    var partitions: Array[Partition] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
-      val taskBinaryBytes: Array[Byte] = stage match {
-        case stage: ShuffleMapStage =>
-          JavaUtils.bufferToArray(
-            closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
-        case stage: ResultStage =>
-          JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
+      var taskBinaryBytes: Array[Byte] = null
+      // Add synchronized block to avoid rdd deserialized from taskBinaryBytes has diff checkpoint
+      // status with the rdd when create ShuffleMapTask or ResultTask.
Contributor: I'd reword this a bit:

    taskBinaryBytes and partitions are both affected by the checkpoint status. We need this
    synchronization in case another concurrent job is checkpointing this RDD, so we get a
    consistent view of both variables.

Contributor Author: Thanks for the advice.

+      RDDCheckpointData.synchronized {
+        taskBinaryBytes = stage match {
+          case stage: ShuffleMapStage =>
+            JavaUtils.bufferToArray(
+              closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
+          case stage: ResultStage =>
+            JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
+        }
+
+        partitions = stage.rdd.partitions
+      }

taskBinary = sc.broadcast(taskBinaryBytes)
@@ -1049,7 +1057,7 @@
stage.pendingPartitions.clear()
partitionsToCompute.map { id =>
val locs = taskIdToLocations(id)
-          val part = stage.rdd.partitions(id)
+          val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
@@ -1059,7 +1067,7 @@
case stage: ResultStage =>
partitionsToCompute.map { id =>
val p: Int = stage.partitions(id)
-          val part = stage.rdd.partitions(p)
+          val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -17,20 +17,25 @@

package org.apache.spark.scheduler

import java.io.File
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}

import scala.annotation.meta.param
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.language.reflectiveCalls
import scala.util.control.NonFatal

import org.mockito.Mockito._
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
-import org.apache.spark.rdd.RDD
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.rdd.{CheckpointState, RDD, RDDCheckpointData}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -96,6 +101,26 @@ class MyRDD(
override def toString: String = "DAGSchedulerSuiteRDD " + id
}

/** A Partition that wraps another Partition. */
class WrappedPartition(val partition: Partition) extends Partition {
  def index: Int = partition.index
}

/**
 * An RDD whose partitions are of the specific type WrappedPartition.
 * The compute method casts its split to WrappedPartition; that cast is what this
 * test suite relies on.
 */
class WrappedRDD(parent: RDD[Int]) extends RDD[Int](parent) {
protected def getPartitions: Array[Partition] = {
parent.partitions.map(p => new WrappedPartition(p))
}

def compute(split: Partition, context: TaskContext): Iterator[Int] = {
parent.compute(split.asInstanceOf[WrappedPartition].partition, context)
Member: I think this line is the key point of WrappedPartition and WrappedRDD; please add a comment explaining your intention.

Contributor Author: Thanks for the comment, I will work on this.

}
}
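
To address the request above, the compute method could carry an explanatory comment along these lines (assumed wording, a sketch rather than the author's final code):

    class WrappedRDD(parent: RDD[Int]) extends RDD[Int](parent) {
      protected def getPartitions: Array[Partition] = {
        parent.partitions.map(p => new WrappedPartition(p))
      }

      def compute(split: Partition, context: TaskContext): Iterator[Int] = {
        // This cast only succeeds if the Partition handed to the task really is a
        // WrappedPartition, i.e. if stage.rdd.partitions was read under the same checkpoint
        // status as the RDD serialized into the task binary. If the RDD is checkpointed in
        // between, the partition comes from the checkpoint RDD instead and the cast throws
        // the ClassCastException this suite looks for.
        parent.compute(split.asInstanceOf[WrappedPartition].partition, context)
      }
    }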

class DAGSchedulerSuiteDummyException extends Exception

class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
@@ -2399,6 +2424,115 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
}

/**
 * In this test, we simulate the scenario of two concurrent jobs using the same
 * RDD that is marked for checkpointing:
 * Job one has already finished its Spark job and starts the doCheckpoint process;
 * Job two is submitted, and submitMissingTasks is called.
 * In submitMissingTasks, if task serialization happens before doCheckpoint is done,
 * while the partitions are computed from stage.rdd.partitions after doCheckpoint is done,
 * we may get a ClassCastException when the task executes, because some RDDs cast their
 * Partitions to a specific type.
 *
 * This test case is meant to show that task serialization and the partition computation
 * in submitMissingTasks must be done under the same RDD checkpoint status.
 */
test("SPARK-23053: avoid ClassCastException in concurrent execution with checkpoint") {
Contributor: Hi @ivoson -- I'm really sorry, but I only just realized that this "test" is really just a repro, and it passes both before and after the actual code changes, since you've replicated the internal logic we're fixing. As such, I don't think it's actually useful as a test case -- perhaps it should be added to the jira as a repro.

I appreciate the work that went into writing this, as it helped make the issue clear to me. I am not sure if there is a good way to test this. If we can't come up with anything, we should just commit your actual fix, but give me a day or two to think about it...

Contributor Author: @squito thanks for the reply. I understand; technically it may not be a UT case, it just simulates the scenario that produces the exception. I also wonder if there is a good way to test this.

Contributor: Hi @ivoson -- I haven't come up with a better way to test this, so I think for now you should:

(1) change the PR to only include the changes to the DAGScheduler (and undo the protected[spark] changes elsewhere)
(2) put this repro on the jira, as it's pretty good for showing what's going on.

If we come up with a way to test it, we can always do that later on. Thanks, and sorry for the back and forth.

Contributor Author: Hi @squito, it's fine. The PR and jira have been updated. Thanks for your patience and review.

// set checkpointDir.
val tempDir = Utils.createTempDir()
val checkpointDir = File.createTempFile("temp", "", tempDir)
checkpointDir.delete()
Contributor: Why do you make a temp file for the checkpoint dir and then delete it? Why not just checkpointDir = new File(tempDir, "checkpointing")? Or even just checkpointDir = Utils.createTempDir()?

(CheckpointSuite does this so it can call sc.setCheckpointDir, but you're not doing that here.)

Contributor Author: Checked the code again, and yes, checkpointDir = Utils.createTempDir() is enough for this case; will fix this.
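
Following that suggestion, the setup could shrink to something like this (a sketch; the committed version may differ):

    // Use a temp directory directly as the checkpoint dir instead of creating and then
    // deleting a temp file inside it.
    val checkpointDir = Utils.createTempDir()
    sc.setCheckpointDir(checkpointDir.toString)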

sc.setCheckpointDir(checkpointDir.toString)

// Semaphores to control the process sequence for the two threads below.
val semaphore1 = new Semaphore(0)
val semaphore2 = new Semaphore(0)

val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4))
rdd.checkpoint()

val checkpointRunnable = new Runnable {
override def run() = {
// Simply simulate what RDD.doCheckpoint() do here.
Contributor: I'd remove "simply" here and elsewhere in comments. Also "do" -> "does".

Contributor Author: Will fix this.

rdd.doCheckpointCalled = true
val checkpointData = rdd.checkpointData.get
RDDCheckpointData.synchronized {
if (checkpointData.cpState == CheckpointState.Initialized) {
checkpointData.cpState = CheckpointState.CheckpointingInProgress
}
}

val newRDD = checkpointData.doCheckpoint()

// Release semaphore1 after the job triggered by the checkpoint has finished, so that
// taskBinary serialization can start.
semaphore1.release()
// Wait until taskBinary serialization finished in submitMissingTasksThread.
semaphore2.acquire()
Contributor: This would be a bit easier to follow if you renamed your semaphores a bit:

    semaphore1 -> doCheckpointStarted
    semaphore2 -> taskBinaryBytesFinished

Contributor Author: Thanks for the advice. Will fix this.


// Update our state and truncate the RDD lineage.
RDDCheckpointData.synchronized {
checkpointData.cpRDD = Some(newRDD)
checkpointData.cpState = CheckpointState.Checkpointed
rdd.markCheckpointed()
}
semaphore1.release()
Contributor: ...and then this would be another semaphore, checkpointStateUpdated.

Contributor Author: Thanks for the advice.
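
Taken together, the renaming suggestions would turn the semaphore declarations into something like the following (the names come from the review comments; this is a sketch, not the committed code):

    // One semaphore per hand-off between the two threads, named after the event it signals.
    val doCheckpointStarted = new Semaphore(0)      // released after the job triggered by doCheckpoint has run
    val taskBinaryBytesFinished = new Semaphore(0)  // released after task serialization is done
    val checkpointStateUpdated = new Semaphore(0)   // released after cpState becomes Checkpointed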

}
}

val submitMissingTasksRunnable = new Runnable {
override def run() = {
// Simply simulate the process of submitMissingTasks.
// Wait until the job triggered by doCheckpoint has finished, but the checkpoint status
// has not been changed yet.
semaphore1.acquire()

val ser = SparkEnv.get.closureSerializer.newInstance()

// Simply simulate task serialization while submitMissingTasks.
// Task serialized with rdd checkpoint not finished.
val cleanedFunc = sc.clean(Utils.getIteratorSize _)
val func = (ctx: TaskContext, it: Iterator[Int]) => cleanedFunc(it)
val taskBinaryBytes = JavaUtils.bufferToArray(
ser.serialize((rdd, func): AnyRef))
// Because the partition computation happens inside a synchronized block in the fixed code,
// the partition is computed here.
val correctPart = rdd.partitions(0)

// Release semaphore2 so changing checkpoint status to Checkpointed will be done in
// checkpointThread.
semaphore2.release()
// Wait until checkpoint status changed to Checkpointed in checkpointThread.
semaphore1.acquire()

// Part calculated with rdd checkpoint already finished.
Contributor: I'd add a comment above this:

    Now we're done simulating the interleaving that might happen within the scheduler -- we'll
    check to make sure the final state is OK by simulating a couple of steps that normally
    happen on the executor.

Contributor Author: Thanks for the advice, it is really helpful for understanding; will update this.

val errPart = rdd.partitions(0)

// The task binary will be deserialized when the task is run in the executor.
val (taskRdd, taskFunc) = ser.deserialize[(RDD[Int], (TaskContext, Iterator[Int]) => Unit)](
ByteBuffer.wrap(taskBinaryBytes), Thread.currentThread.getContextClassLoader)

val taskContext = mock(classOf[TaskContext])
doNothing().when(taskContext).killTaskIfInterrupted()

// ClassCastException is expected with errPart.
Contributor: I think this is a bit easier to follow if you say:

    Make sure our test case is set up correctly -- we expect a ClassCastException here if we use
    the rdd.partitions after checkpointing was done, but our binary bytes are from before it
    finished.

Contributor Author: Thanks for the advice, it is really helpful for understanding; will update this.

intercept[ClassCastException] {
Member: I think this is not a "test", just a reproduction of the problem you want to fix. We should prove that the code you added in DAGScheduler.scala fixes that problem, and that with the original code base a ClassCastException is raised.

Contributor Author: It is a reproduction case; I will fix this.

// Triggered when the task is run in the executor.
taskRdd.iterator(errPart, taskContext)
}

// Execute successfully with correctPart.
taskRdd.iterator(correctPart, taskContext)
}
}

new Thread(checkpointRunnable).start()
val submitMissingTasksThread = new Thread(submitMissingTasksRunnable)
submitMissingTasksThread.start()
submitMissingTasksThread.join()

Utils.deleteRecursively(tempDir)
Contributor: This should be done in a finally.

Contributor Author: Will fix this.
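
One way to address this is to wrap the body of the test in try/finally so the temp directory is always removed, roughly like this (a sketch assuming the rest of the test body stays as above):

    val tempDir = Utils.createTempDir()
    try {
      sc.setCheckpointDir(tempDir.toString)
      // ... set up the RDD, start the two threads and join them, as in the test above ...
    } finally {
      // Clean up even if an assertion above fails.
      Utils.deleteRecursively(tempDir)
    }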

}

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.