[SPARK-23053][CORE] taskBinarySerialization and task partitions calculate in DagScheduler.submitMissingTasks should keep the same RDD checkpoint status #20244
@@ -17,20 +17,25 @@

package org.apache.spark.scheduler

import java.io.File
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}

import scala.annotation.meta.param
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.language.reflectiveCalls
import scala.util.control.NonFatal

import org.mockito.Mockito._
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.rdd.RDD
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.{CheckpointState, RDD, RDDCheckpointData}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}

@@ -96,6 +101,26 @@ class MyRDD(
  override def toString: String = "DAGSchedulerSuiteRDD " + id
}

/** A Partition that wraps another RDD's partition. */
class WrappedPartition(val partition: Partition) extends Partition {
  def index: Int = partition.index
}

/**
 * An RDD whose partitions are WrappedPartitions. The compute method casts the split
 * to WrappedPartition; this cast is what the test below exercises.
 */
class WrappedRDD(parent: RDD[Int]) extends RDD[Int](parent) {
  protected def getPartitions: Array[Partition] = {
    parent.partitions.map(p => new WrappedPartition(p))
  }

  def compute(split: Partition, context: TaskContext): Iterator[Int] = {
    parent.compute(split.asInstanceOf[WrappedPartition].partition, context)
  }
}

class DAGSchedulerSuiteDummyException extends Exception

class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {

@@ -2399,6 +2424,115 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
    }
  }

  /**
   * This test simulates the scenario of two concurrent jobs using the same RDD, which is
   * marked for checkpointing:
   * Job one has already finished its Spark job and starts the doCheckpoint process;
   * Job two is submitted, and submitMissingTasks is called.
   * In submitMissingTasks, if task serialization happens before doCheckpoint is done, while
   * the partition calculation from stage.rdd.partitions happens after doCheckpoint is done,
   * we may get a ClassCastException when executing the task, because some RDDs cast their
   * Partition type in compute().
   *
   * This test case demonstrates that task serialization and the partition calculation in
   * submitMissingTasks should see the same RDD checkpoint status.
   */
| test("SPARK-23053: avoid ClassCastException in concurrent execution with checkpoint") { | ||
|
||
| // set checkpointDir. | ||
| val tempDir = Utils.createTempDir() | ||
| val checkpointDir = File.createTempFile("temp", "", tempDir) | ||
| checkpointDir.delete() | ||
|
||
| sc.setCheckpointDir(checkpointDir.toString) | ||
|
|
||
| // Semaphores to control the process sequence for the two threads below. | ||
| val semaphore1 = new Semaphore(0) | ||
| val semaphore2 = new Semaphore(0) | ||
|
|
||
| val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4)) | ||
| rdd.checkpoint() | ||
|
|
||
| val checkpointRunnable = new Runnable { | ||
| override def run() = { | ||
| // Simply simulate what RDD.doCheckpoint() do here. | ||
|
||
| rdd.doCheckpointCalled = true | ||
| val checkpointData = rdd.checkpointData.get | ||
| RDDCheckpointData.synchronized { | ||
| if (checkpointData.cpState == CheckpointState.Initialized) { | ||
| checkpointData.cpState = CheckpointState.CheckpointingInProgress | ||
| } | ||
| } | ||
|
|
||
| val newRDD = checkpointData.doCheckpoint() | ||
|
|
||
| // Release semaphore1 after job triggered in checkpoint finished, so that taskBinary | ||
| // serialization can start. | ||
| semaphore1.release() | ||
| // Wait until taskBinary serialization finished in submitMissingTasksThread. | ||
| semaphore2.acquire() | ||
|
||
|
|
||
| // Update our state and truncate the RDD lineage. | ||
| RDDCheckpointData.synchronized { | ||
| checkpointData.cpRDD = Some(newRDD) | ||
| checkpointData.cpState = CheckpointState.Checkpointed | ||
| rdd.markCheckpointed() | ||
| } | ||
| semaphore1.release() | ||
|
||
| } | ||
| } | ||
|
|
||
| val submitMissingTasksRunnable = new Runnable { | ||
| override def run() = { | ||
| // Simply simulate the process of submitMissingTasks. | ||
| // Wait until doCheckpoint job running finished, but checkpoint status not changed. | ||
| semaphore1.acquire() | ||
|
|
||
| val ser = SparkEnv.get.closureSerializer.newInstance() | ||
|
|
||
| // Simply simulate task serialization while submitMissingTasks. | ||
| // Task serialized with rdd checkpoint not finished. | ||
| val cleanedFunc = sc.clean(Utils.getIteratorSize _) | ||
| val func = (ctx: TaskContext, it: Iterator[Int]) => cleanedFunc(it) | ||
| val taskBinaryBytes = JavaUtils.bufferToArray( | ||
| ser.serialize((rdd, func): AnyRef)) | ||
| // Because partition calculate is in a synchronized block, so in the fixed code | ||
| // partition is calculated here. | ||
| val correctPart = rdd.partitions(0) | ||
|
|
||
| // Release semaphore2 so changing checkpoint status to Checkpointed will be done in | ||
| // checkpointThread. | ||
| semaphore2.release() | ||
| // Wait until checkpoint status changed to Checkpointed in checkpointThread. | ||
| semaphore1.acquire() | ||
|
|
||
| // Part calculated with rdd checkpoint already finished. | ||
|
||
| val errPart = rdd.partitions(0) | ||
|
|
||
| // TaskBinary will be deserialized when run task in executor. | ||
| val (taskRdd, taskFunc) = ser.deserialize[(RDD[Int], (TaskContext, Iterator[Int]) => Unit)]( | ||
| ByteBuffer.wrap(taskBinaryBytes), Thread.currentThread.getContextClassLoader) | ||
|
|
||
| val taskContext = mock(classOf[TaskContext]) | ||
| doNothing().when(taskContext).killTaskIfInterrupted() | ||
|
|
||
| // ClassCastException is expected with errPart. | ||
|
||
| intercept[ClassCastException] { | ||
|
||
| // Triggered when runTask in executor. | ||
| taskRdd.iterator(errPart, taskContext) | ||
| } | ||
|
|
||
| // Execute successfully with correctPart. | ||
| taskRdd.iterator(correctPart, taskContext) | ||
| } | ||
| } | ||
|
|
||
| new Thread(checkpointRunnable).start() | ||
| val submitMissingTasksThread = new Thread(submitMissingTasksRunnable) | ||
| submitMissingTasksThread.start() | ||
| submitMissingTasksThread.join() | ||
|
|
||
| Utils.deleteRecursively(tempDir) | ||
|
||
| } | ||

  /**
   * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
   * Note that this checks only the host and not the executor ID.
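For context on why the test above expects a ClassCastException: once the RDD's checkpoint completes, rdd.partitions returns the checkpoint RDD's partitions instead of WrappedPartitions, so a task binary serialized against the pre-checkpoint RDD fails the cast in WrappedRDD.compute. A minimal sketch of that failure mode, assuming the suite's sc, mock and intercept helpers are in scope (illustrative only, not part of the patch):

```scala
// Illustrative sketch, not part of the patch: handing a plain Partition to
// WrappedRDD.compute reproduces the cast failure that the race condition exposes.
val plainPartition = new Partition { override def index: Int = 0 }
val wrapped = new WrappedRDD(sc.makeRDD(1 to 100, 4))
intercept[ClassCastException] {
  wrapped.compute(plainPartition, mock(classOf[TaskContext]))
}
```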
I'd reword this a bit:
taskBinaryBytes and partitions are both affected by the checkpoint status. We need this synchronization in case another concurrent job is checkpointing this RDD, so we get a consistent view of both variables.
Thanks for the advice.
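To make the reviewer's point concrete, here is a rough sketch of how submitMissingTasks can take the consistent view described above. It is an illustration of the idea written against names that exist inside DAGScheduler.submitMissingTasks (stage, closureSerializer, ShuffleMapStage, ResultStage), not a copy of the actual patch:

```scala
// Sketch of the idea only: taskBinaryBytes and partitions are both derived from
// stage.rdd, so compute them under the lock used by RDD checkpointing to get a
// consistent view of the checkpoint state.
var taskBinaryBytes: Array[Byte] = null
var partitions: Array[Partition] = null
RDDCheckpointData.synchronized {
  taskBinaryBytes = stage match {
    case stage: ShuffleMapStage =>
      JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
    case stage: ResultStage =>
      JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
  }
  partitions = stage.rdd.partitions
}
```

Serializing the task binary and reading stage.rdd.partitions under the same lock means a concurrent doCheckpoint can only complete either before both reads or after both, never between them.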