[SPARK-23053][CORE] taskBinarySerialization and task partitions calculate in DagScheduler.submitMissingTasks should keep the same RDD checkpoint status #20244
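For context: this change argues that DAGScheduler.submitMissingTasks should serialize the task binary and compute the stage's partitions under the same RDD checkpoint status; per the comments in the test below, the fixed code performs both steps inside one synchronized block on the checkpoint lock. The following is a minimal sketch of that pattern, not the actual patch: the object and method names are invented for illustration, it assumes code living in the org.apache.spark.scheduler package (so the Spark-private RDDCheckpointData lock and SparkEnv.get.closureSerializer are accessible), and the real submitMissingTasks additionally distinguishes ShuffleMapStage from ResultStage.

package org.apache.spark.scheduler

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.{RDD, RDDCheckpointData}

// Hypothetical helper, not part of Spark; shown only to illustrate the locking pattern.
private[spark] object ConsistentTaskSerialization {
  // Serialize the task binary and read the partition array while holding the checkpoint
  // lock, so both observe the same checkpoint status even if another job is concurrently
  // checkpointing this RDD.
  def serializeTaskAndPartitions[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U): (Array[Byte], Array[Partition]) = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    RDDCheckpointData.synchronized {
      val taskBinaryBytes = JavaUtils.bufferToArray(ser.serialize((rdd, func): AnyRef))
      val partitions = rdd.partitions
      (taskBinaryBytes, partitions)
    }
  }
}

The test added in this diff exercises exactly the interleaving this guards against: it serializes the task binary before the checkpoint state flips to Checkpointed, then shows that a partition read afterwards leads to a ClassCastException on the executor.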
@@ -17,20 +17,25 @@
package org.apache.spark.scheduler

import java.io.File
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}

import scala.annotation.meta.param
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.language.reflectiveCalls
import scala.util.control.NonFatal

import org.mockito.Mockito._
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.rdd.RDD
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.{CheckpointState, RDD, RDDCheckpointData}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}

@@ -96,6 +101,26 @@ class MyRDD(
  override def toString: String = "DAGSchedulerSuiteRDD " + id
}

/** A Partition that simply wraps another Partition, used to force a Partition cast in compute. */
class WrappedPartition(val partition: Partition) extends Partition {
  def index: Int = partition.index
}

/**
 * An RDD whose partitions are WrappedPartitions. The compute method casts the incoming
 * split to WrappedPartition; this cast is what this test suite relies on to detect a
 * partition that does not match the serialized task binary.
 */
class WrappedRDD(parent: RDD[Int]) extends RDD[Int](parent) {
  protected def getPartitions: Array[Partition] = {
    parent.partitions.map(p => new WrappedPartition(p))
  }

  def compute(split: Partition, context: TaskContext): Iterator[Int] = {
    parent.compute(split.asInstanceOf[WrappedPartition].partition, context)
  }
}

class DAGSchedulerSuiteDummyException extends Exception

class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {

@@ -2399,6 +2424,121 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
    }
  }

  /**
   * This test simulates two concurrent jobs using the same RDD that is marked for
   * checkpointing:
   * Job one has already finished its Spark job and starts the doCheckpoint process;
   * Job two is submitted, and submitMissingTasks is called.
   * In submitMissingTasks, if task serialization happens before doCheckpoint is done, while
   * the partitions are computed from stage.rdd.partitions after doCheckpoint is done, we may
   * get a ClassCastException when executing the task, because some RDDs cast the Partition
   * they receive.
   *
   * This test case is meant to show that task serialization and partition calculation in
   * submitMissingTasks must see the same RDD checkpoint status.
   */
| test("SPARK-23053: avoid ClassCastException in concurrent execution with checkpoint") { | ||
|
||
| // set checkpointDir. | ||
| val checkpointDir = Utils.createTempDir() | ||
| sc.setCheckpointDir(checkpointDir.toString) | ||
|
|
||
    // Semaphores to control the execution order of the two threads below.
    val doCheckpointStarted = new Semaphore(0)
    val taskBinaryBytesFinished = new Semaphore(0)
    val checkpointStateUpdated = new Semaphore(0)

    val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4))
    rdd.checkpoint()

    val checkpointRunnable = new Runnable {
      override def run() = {
        // Simulate what RDD.doCheckpoint() does here.
        rdd.doCheckpointCalled = true
        val checkpointData = rdd.checkpointData.get
        RDDCheckpointData.synchronized {
          if (checkpointData.cpState == CheckpointState.Initialized) {
            checkpointData.cpState = CheckpointState.CheckpointingInProgress
          }
        }

        val newRDD = checkpointData.doCheckpoint()

        // Release doCheckpointStarted after the job triggered by checkpointing has finished,
        // so that taskBinary serialization can start.
        doCheckpointStarted.release()
        // Wait until taskBinary serialization has finished in submitMissingTasksThread.
        taskBinaryBytesFinished.acquire()

        // Update our state and truncate the RDD lineage.
        RDDCheckpointData.synchronized {
          checkpointData.cpRDD = Some(newRDD)
          checkpointData.cpState = CheckpointState.Checkpointed
          rdd.markCheckpointed()
        }
        checkpointStateUpdated.release()
      }
    }

    val submitMissingTasksRunnable = new Runnable {
      override def run() = {
        // Simulate the process of submitMissingTasks.
        // Wait until the doCheckpoint job has finished running, but the checkpoint state has
        // not been updated yet.
        doCheckpointStarted.acquire()

        val ser = SparkEnv.get.closureSerializer.newInstance()

        // Simulate task serialization in submitMissingTasks.
        // The task is serialized while the RDD checkpoint is not finished yet.
        val cleanedFunc = sc.clean(Utils.getIteratorSize _)
        val func = (ctx: TaskContext, it: Iterator[Int]) => cleanedFunc(it)
        val taskBinaryBytes = JavaUtils.bufferToArray(
          ser.serialize((rdd, func): AnyRef))
        // Because partition calculation is inside a synchronized block in the fixed code,
        // the partition is calculated here.
        val correctPart = rdd.partitions(0)

        // Release taskBinaryBytesFinished so that checkpointThread can change the checkpoint
        // state to Checkpointed.
        taskBinaryBytesFinished.release()
        // Wait until the checkpoint state has been changed to Checkpointed in checkpointThread.
        checkpointStateUpdated.acquire()

        // Now that we're done simulating the interleaving that might happen within the
        // scheduler, we'll check that the final state is OK by simulating a couple of steps
        // that normally happen on the executor.
        // This partition is calculated after the RDD checkpoint has already finished.
        val errPart = rdd.partitions(0)

        // The task binary is deserialized when the task runs on the executor.
        val (taskRdd, taskFunc) = ser.deserialize[(RDD[Int], (TaskContext, Iterator[Int]) => Unit)](
          ByteBuffer.wrap(taskBinaryBytes), Thread.currentThread.getContextClassLoader)

        val taskContext = mock(classOf[TaskContext])
        doNothing().when(taskContext).killTaskIfInterrupted()

        // Make sure our test case is set up correctly -- we expect a ClassCastException here
        // if we use rdd.partitions obtained after checkpointing was done, but our binary
        // bytes are from before it finished.
        intercept[ClassCastException] {
          // Triggered when the task is run on the executor.
          taskRdd.iterator(errPart, taskContext)
        }

        // Executes successfully with correctPart.
        taskRdd.iterator(correctPart, taskContext)
      }
    }

    try {
      new Thread(checkpointRunnable).start()
      val submitMissingTasksThread = new Thread(submitMissingTasksRunnable)
      submitMissingTasksThread.start()
      submitMissingTasksThread.join()
    } finally {
      Utils.deleteRecursively(checkpointDir)
    }
  }

  /**
   * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
   * Note that this checks only the host and not the executor ID.

Review comment: I think this line is the key point for WrappedPartition and WrappedRDD; please add comments explaining your intention.

Author reply: Thanks for the comment, I will work on this.