7 changes: 6 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -66,7 +66,7 @@ object TaskContext {
* An empty task context that does not represent an actual task. This is only used in tests.
*/
private[spark] def empty(): TaskContextImpl = {
new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
}
}

@@ -150,6 +150,11 @@ abstract class TaskContext extends Serializable {
*/
def stageId(): Int

/**
* An ID that is unique to the stage attempt that this task belongs to.
*/
def stageAttemptId(): Int
Contributor:
I think we should call it stageAttemptNumber to be consistent with taskAttemptNumber. Also, let's follow the comment style of attemptNumber.

Contributor Author:
Yeah, if we were defining stageAttemptId from scratch, I would go for stageAttemptNumber. However, stageAttemptId is already used elsewhere in the codebase, like in Task.scala. I think it's more important to be consistent.

However, I could update the comment to reflect the attempt-number semantics if you wish.

Contributor:
My concern is that internally we use stageAttemptId, and internally we call TaskContext.taskAttemptId a task ID. End users, however, don't know the internal code; they are more familiar with TaskContext. I think the naming should be consistent with the public TaskContext API, not with internal code.

Contributor Author:
I have no objection to either 'id' or 'number'; they are both reasonable.

I am on a train now. If there is no other input, I can rename it to 'stageAttemptNumber' since you insist.
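For reference, a minimal sketch of what the renamed accessor could look like if its doc comment mirrors the wording of attemptNumber; the exact name and phrasing are assumptions until the rename actually lands:

```scala
/**
 * How many times the stage that this task belongs to has been attempted. The first stage
 * attempt will be assigned stageAttemptNumber = 0, and subsequent attempts will have
 * increasing attempt numbers.
 */
def stageAttemptNumber(): Int
```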


/**
* The ID of the RDD partition that is computed by this task.
*/
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -42,6 +42,7 @@ import org.apache.spark.util._
*/
private[spark] class TaskContextImpl(
val stageId: Int,
val stageAttemptId: Int,
Member:
nit: add override. Since you are touching this file, could you also add override to stageId and partitionId?

Contributor Author:
Will do.

Would you tell me the difference or rationale?

Contributor:
It's kind of a code-style standard: add override if it is an override.
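To make the convention concrete, here is a toy sketch; Context and ContextImpl are stand-ins, not the real Spark classes. In Scala, a constructor val can implement an abstract method, and marking it override documents that relationship at the definition site:

```scala
// Toy example of the requested style: every constructor val that implements
// an abstract member is explicitly marked `override`.
abstract class Context {
  def stageId(): Int
  def stageAttemptId(): Int
}

class ContextImpl(
    override val stageId: Int,
    override val stageAttemptId: Int)
  extends Context

object OverrideDemo {
  def main(args: Array[String]): Unit = {
    // Viewed through the abstract type, the vals satisfy the def signatures.
    val ctx: Context = new ContextImpl(stageId = 3, stageAttemptId = 1)
    println(ctx.stageAttemptId()) // prints 1
  }
}
```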

Contributor Author:
OK then.

val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -79,6 +79,7 @@ private[spark] abstract class Task[T](
SparkEnv.get.blockManager.registerTask(taskAttemptId)
context = new TaskContextImpl(
stageId,
stageAttemptId,
partitionId,
taskAttemptId,
attemptNumber,
@@ -38,6 +38,7 @@ public static void test() {
tc.attemptNumber();
tc.partitionId();
tc.stageId();
tc.stageAttemptId();
tc.taskAttemptId();
}

@@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) {
context.isCompleted();
context.isInterrupted();
context.stageId();
context.stageAttemptId();
context.partitionId();
context.addTaskCompletionListener(this);
}
6 changes: 3 additions & 3 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext {

// first attempt -- it's successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
val data1 = (1 to 10).map { x => x -> x}

// second attempt -- also successful. We'll write out different data,
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
val data2 = (11 to 20).map { x => x -> x}

// interleave writes of both attempts -- we want to test that both attempts can occur
@@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext {
}

val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)

@@ -29,6 +29,7 @@ object MemoryTestingUtils {
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
new TaskContextImpl(
stageId = 0,
stageAttemptId = 0,
partitionId = 0,
taskAttemptId = 0,
attemptNumber = 0,
@@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._

class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
}

test("TaskContext.stageAttemptId getter") {
sc = new SparkContext("local[1,2]", "test")

// Check that stage attempt IDs are 0 for the initial stage
val stageAttemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ =>
Seq(TaskContext.get().stageAttemptId()).iterator
}.collect()
assert(stageAttemptIds.toSet === Set(0))

// Check stage attempt IDs for a stage that is resubmitted when tasks fail with FetchFailedException
val stageAttemptIdsWithFailedStage =
sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ =>
Contributor:
You don't need repartition here; just use sc.parallelize(Seq(1, 2, 3, 4), 1).mapPartitions {...}.

val stageAttemptId = TaskContext.get().stageAttemptId()
if (stageAttemptId < 2) {
// Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception
// will only trigger task resubmission in the same stage.
throw new FetchFailedException(null, 0, 0, 0, "Fake")
Contributor:
Emmm... wouldn't just throwing an Exception be enough here?

Contributor Author:
Related to the repartition part: I use FetchFailedException to explicitly trigger a stage resubmission. Otherwise, IIRC, the task would just be resubmitted within the same stage.

Contributor:
oh, right~

Contributor:
Please add a comment explaining that FetchFailedException will trigger a new stage attempt, while a common Exception will only trigger a task retry.

Contributor Author:
Will do.

}
Seq(stageAttemptId).iterator
}.collect()

assert(stageAttemptIdsWithFailedStage.toSet === Set(2))
}

test("accumulators are updated on exception failures") {
// This means use 1 core and 4 max task failures
sc = new SparkContext("local[1,4]", "test")
@@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
// accumulator updates from it.
val taskMetrics = TaskMetrics.empty
val task = new Task[Int](0, 0, 0) {
context = new TaskContextImpl(0, 0, 0L, 0,
context = new TaskContextImpl(0, 0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
new Properties,
SparkEnv.get.metricsSystem,
@@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
// accumulator updates from it.
val taskMetrics = TaskMetrics.registered
val task = new Task[Int](0, 0, 0) {
context = new TaskContextImpl(0, 0, 0L, 0,
context = new TaskContextImpl(0, 0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
new Properties,
SparkEnv.get.metricsSystem,
@@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
try {
TaskContext.setTaskContext(
new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null))
block
} finally {
TaskContext.unset()
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -36,6 +36,9 @@ object MimaExcludes {

// Exclude rules for 2.3.x
lazy val v23excludes = v22excludes ++ Seq(
// [SPARK-22897] Expose stageAttemptId in TaskContext
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptId"),

// SPARK-22789: Map-only continuous processing execution
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"),
@@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite

TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
stageAttemptId = 0,
partitionId = 0,
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
@@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
stageAttemptId = 0,
partitionId = 0,
taskAttemptId = 98456,
attemptNumber = 0,
@@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
(i, converter(Row(i)))
}
val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null)
val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)

val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
taskContext,
@@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkContext {
val conf = new SparkConf()
sc = new SparkContext("local[2, 4]", "test", conf)
val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0)
TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null))
TaskContext.setTaskContext(
new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null))
}

override def afterAll(): Unit = TaskContext.unset()