Skip to content

Commit 38700ea

Browse files
committed
[SPARK-10381] Fix mixup of taskAttemptNumber & attemptId in OutputCommitCoordinator
When speculative execution is enabled, consider a scenario where the authorized committer of a particular output partition fails during the OutputCommitter.commitTask() call. In this case, the OutputCommitCoordinator is supposed to release that committer's exclusive lock on committing once that task fails. However, due to a unit mismatch (we used task attempt number in one place and task attempt id in another) the lock will not be released, causing Spark to go into an infinite retry loop. This bug was masked by the fact that the OutputCommitCoordinator does not have enough end-to-end tests (the current tests use many mocks). Other factors contributing to this bug are the fact that we have many similarly-named identifiers that have different semantics but the same data types (e.g. attemptNumber and taskAttemptId, with inconsistent variable naming which makes them difficult to distinguish). This patch adds a regression test and fixes this bug by always using task attempt numbers throughout this code. Author: Josh Rosen <joshrosen@databricks.com> Closes apache#8544 from JoshRosen/SPARK-10381.
1 parent 99ecfa5 commit 38700ea

File tree

17 files changed

+174
-69
lines changed

17 files changed

+174
-69
lines changed

core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ class SparkHadoopWriter(jobConf: JobConf)
104104
}
105105

106106
def commit() {
107-
SparkHadoopMapRedUtil.commitTask(
108-
getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID)
107+
SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID)
109108
}
110109

111110
def commitJob() {

core/src/main/scala/org/apache/spark/TaskEndReason.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,12 @@ case object TaskKilled extends TaskFailedReason {
193193
* Task requested the driver to commit, but was denied.
194194
*/
195195
@DeveloperApi
196-
case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason {
196+
case class TaskCommitDenied(
197+
jobID: Int,
198+
partitionID: Int,
199+
attemptNumber: Int) extends TaskFailedReason {
197200
override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
198-
s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
201+
s" for job: $jobID, partition: $partitionID, attemptNumber: $attemptNumber"
199202
/**
200203
* If a task failed because its attempt to commit was denied, do not count this failure
201204
* towards failing the stage. This is intended to prevent spurious stage failures in cases

core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ private[spark] class CommitDeniedException(
2626
msg: String,
2727
jobID: Int,
2828
splitID: Int,
29-
attemptID: Int)
29+
attemptNumber: Int)
3030
extends Exception(msg) {
3131

32-
def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID)
32+
def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber)
3333
}

core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ object SparkHadoopMapRedUtil extends Logging {
9191
committer: MapReduceOutputCommitter,
9292
mrTaskContext: MapReduceTaskAttemptContext,
9393
jobId: Int,
94-
splitId: Int,
95-
attemptId: Int): Unit = {
94+
splitId: Int): Unit = {
9695

9796
val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext)
9897

@@ -122,7 +121,8 @@ object SparkHadoopMapRedUtil extends Logging {
122121

123122
if (shouldCoordinateWithDriver) {
124123
val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
125-
val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId)
124+
val taskAttemptNumber = TaskContext.get().attemptNumber()
125+
val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber)
126126

127127
if (canCommit) {
128128
performCommit()
@@ -132,7 +132,7 @@ object SparkHadoopMapRedUtil extends Logging {
132132
logInfo(message)
133133
// We need to abort the task so that the driver can reschedule new attempts, if necessary
134134
committer.abortTask(mrTaskContext)
135-
throw new CommitDeniedException(message, jobId, splitId, attemptId)
135+
throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber)
136136
}
137137
} else {
138138
// Speculation is disabled or a user has chosen to manually bypass the commit coordination
@@ -143,16 +143,4 @@ object SparkHadoopMapRedUtil extends Logging {
143143
logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID")
144144
}
145145
}
146-
147-
def commitTask(
148-
committer: MapReduceOutputCommitter,
149-
mrTaskContext: MapReduceTaskAttemptContext,
150-
sparkTaskContext: TaskContext): Unit = {
151-
commitTask(
152-
committer,
153-
mrTaskContext,
154-
sparkTaskContext.stageId(),
155-
sparkTaskContext.partitionId(),
156-
sparkTaskContext.attemptNumber())
157-
}
158146
}

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,8 +1128,11 @@ class DAGScheduler(
11281128
val stageId = task.stageId
11291129
val taskType = Utils.getFormattedClassName(task)
11301130

1131-
outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
1132-
event.taskInfo.attempt, event.reason)
1131+
outputCommitCoordinator.taskCompleted(
1132+
stageId,
1133+
task.partitionId,
1134+
event.taskInfo.attemptNumber, // this is a task attempt number
1135+
event.reason)
11331136

11341137
// The success case is dealt with separately below, since we need to compute accumulator
11351138
// updates before posting.

core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint
2525
private sealed trait OutputCommitCoordinationMessage extends Serializable
2626

2727
private case object StopCoordinator extends OutputCommitCoordinationMessage
28-
private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long)
28+
private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int)
2929

3030
/**
3131
* Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
@@ -44,8 +44,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
4444
var coordinatorRef: Option[RpcEndpointRef] = None
4545

4646
private type StageId = Int
47-
private type PartitionId = Long
48-
private type TaskAttemptId = Long
47+
private type PartitionId = Int
48+
private type TaskAttemptNumber = Int
4949

5050
/**
5151
* Map from active stages's id => partition id => task attempt with exclusive lock on committing
@@ -57,7 +57,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
5757
* Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
5858
*/
5959
private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
60-
private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]]
60+
private type CommittersByStageMap =
61+
mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]]
6162

6263
/**
6364
* Returns whether the OutputCommitCoordinator's internal data structures are all empty.
@@ -75,14 +76,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
7576
*
7677
* @param stage the stage number
7778
* @param partition the partition number
78-
* @param attempt a unique identifier for this task attempt
79+
* @param attemptNumber how many times this task has been attempted
80+
* (see [[TaskContext.attemptNumber()]])
7981
* @return true if this task is authorized to commit, false otherwise
8082
*/
8183
def canCommit(
8284
stage: StageId,
8385
partition: PartitionId,
84-
attempt: TaskAttemptId): Boolean = {
85-
val msg = AskPermissionToCommitOutput(stage, partition, attempt)
86+
attemptNumber: TaskAttemptNumber): Boolean = {
87+
val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber)
8688
coordinatorRef match {
8789
case Some(endpointRef) =>
8890
endpointRef.askWithRetry[Boolean](msg)
@@ -95,7 +97,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
9597

9698
// Called by DAGScheduler
9799
private[scheduler] def stageStart(stage: StageId): Unit = synchronized {
98-
authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]()
100+
authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]()
99101
}
100102

101103
// Called by DAGScheduler
@@ -107,7 +109,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
107109
private[scheduler] def taskCompleted(
108110
stage: StageId,
109111
partition: PartitionId,
110-
attempt: TaskAttemptId,
112+
attemptNumber: TaskAttemptNumber,
111113
reason: TaskEndReason): Unit = synchronized {
112114
val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, {
113115
logDebug(s"Ignoring task completion for completed stage")
@@ -117,12 +119,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
117119
case Success =>
118120
// The task output has been committed successfully
119121
case denied: TaskCommitDenied =>
120-
logInfo(
121-
s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt")
122+
logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " +
123+
s"attempt: $attemptNumber")
122124
case otherReason =>
123-
if (authorizedCommitters.get(partition).exists(_ == attempt)) {
124-
logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" +
125-
s" clearing lock")
125+
if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) {
126+
logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
127+
s"partition=$partition) failed; clearing lock")
126128
authorizedCommitters.remove(partition)
127129
}
128130
}
@@ -140,21 +142,23 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
140142
private[scheduler] def handleAskPermissionToCommit(
141143
stage: StageId,
142144
partition: PartitionId,
143-
attempt: TaskAttemptId): Boolean = synchronized {
145+
attemptNumber: TaskAttemptNumber): Boolean = synchronized {
144146
authorizedCommittersByStage.get(stage) match {
145147
case Some(authorizedCommitters) =>
146148
authorizedCommitters.get(partition) match {
147149
case Some(existingCommitter) =>
148-
logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " +
149-
s"existingCommitter = $existingCommitter")
150+
logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
151+
s"partition=$partition; existingCommitter = $existingCommitter")
150152
false
151153
case None =>
152-
logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition")
153-
authorizedCommitters(partition) = attempt
154+
logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " +
155+
s"partition=$partition")
156+
authorizedCommitters(partition) = attemptNumber
154157
true
155158
}
156159
case None =>
157-
logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit")
160+
logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" +
161+
s"partition $partition to commit")
158162
false
159163
}
160164
}
@@ -174,9 +178,9 @@ private[spark] object OutputCommitCoordinator {
174178
}
175179

176180
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
177-
case AskPermissionToCommitOutput(stage, partition, taskAttempt) =>
181+
case AskPermissionToCommitOutput(stage, partition, attemptNumber) =>
178182
context.reply(
179-
outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt))
183+
outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber))
180184
}
181185
}
182186
}

core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi
2929
class TaskInfo(
3030
val taskId: Long,
3131
val index: Int,
32-
val attempt: Int,
32+
val attemptNumber: Int,
3333
val launchTime: Long,
3434
val executorId: String,
3535
val host: String,
@@ -95,7 +95,10 @@ class TaskInfo(
9595
}
9696
}
9797

98-
def id: String = s"$index.$attempt"
98+
@deprecated("Use attemptNumber", "1.6.0")
99+
def attempt: Int = attemptNumber
100+
101+
def id: String = s"$index.$attemptNumber"
99102

100103
def duration: Long = {
101104
if (!finished) {

core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ private[v1] object AllStagesResource {
127127
new TaskData(
128128
taskId = uiData.taskInfo.taskId,
129129
index = uiData.taskInfo.index,
130-
attempt = uiData.taskInfo.attempt,
130+
attempt = uiData.taskInfo.attemptNumber,
131131
launchTime = new Date(uiData.taskInfo.launchTime),
132132
executorId = uiData.taskInfo.executorId,
133133
host = uiData.taskInfo.host,

core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
621621
serializationTimeProportionPos + serializationTimeProportion
622622

623623
val index = taskInfo.index
624-
val attempt = taskInfo.attempt
624+
val attempt = taskInfo.attemptNumber
625625

626626
val svgTag =
627627
if (totalExecutionTime == 0) {
@@ -967,7 +967,7 @@ private[ui] class TaskDataSource(
967967
new TaskTableRowData(
968968
info.index,
969969
info.taskId,
970-
info.attempt,
970+
info.attemptNumber,
971971
info.speculative,
972972
info.status,
973973
info.taskLocality.toString,

core/src/main/scala/org/apache/spark/util/JsonProtocol.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ private[spark] object JsonProtocol {
266266
def taskInfoToJson(taskInfo: TaskInfo): JValue = {
267267
("Task ID" -> taskInfo.taskId) ~
268268
("Index" -> taskInfo.index) ~
269-
("Attempt" -> taskInfo.attempt) ~
269+
("Attempt" -> taskInfo.attemptNumber) ~
270270
("Launch Time" -> taskInfo.launchTime) ~
271271
("Executor ID" -> taskInfo.executorId) ~
272272
("Host" -> taskInfo.host) ~

0 commit comments

Comments
 (0)