@@ -47,6 +47,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   private type PartitionId = Int
   private type TaskAttemptNumber = Int
 
+  private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1
+
   /**
    * Map from active stage's id => partition id => task attempt with exclusive lock on committing
    * output for that partition.
@@ -56,9 +58,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
    *
    * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
    */
-  private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
-  private type CommittersByStageMap =
-    mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]]
+  private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]()
 
   /**
    * Returns whether the OutputCommitCoordinator's internal data structures are all empty.
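This hunk replaces the per-stage `mutable.Map[PartitionId, TaskAttemptNumber]` with a flat `Array[TaskAttemptNumber]` indexed by partition id, using the new `NO_AUTHORIZED_COMMITTER` sentinel (-1) where the map used absence of a key. A minimal sketch of the before/after representations (names other than the sentinel are hypothetical):

```scala
import scala.collection.mutable

object SentinelArraySketch {
  type PartitionId = Int
  type TaskAttemptNumber = Int
  val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1

  def main(args: Array[String]): Unit = {
    // Before: presence of a key marks the partition as claimed.
    val byMap = mutable.Map[PartitionId, TaskAttemptNumber](3 -> 0)
    println(byMap.get(3))  // Some(0): partition 3 is claimed by attempt 0

    // After: one preallocated slot per partition; -1 means "unclaimed".
    val byArray = Array.fill[TaskAttemptNumber](8)(NO_AUTHORIZED_COMMITTER)
    byArray(3) = 0
    println(byArray(3) != NO_AUTHORIZED_COMMITTER)  // true: partition 3 is claimed
  }
}
```

The array trades a little up-front memory (one `Int` per partition) for constant-time indexed access with no hashing or boxing on the commit path, which is presumably the motivation for the change.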
@@ -95,9 +95,21 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
     }
   }
 
-  // Called by DAGScheduler
-  private[scheduler] def stageStart(stage: StageId): Unit = synchronized {
-    authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]()
+  /**
+   * Called by the DAGScheduler when a stage starts.
+   *
+   * @param stage the stage id.
+   * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
+   *                       the maximum possible value of `context.partitionId`).
+   */
+  private[scheduler] def stageStart(
+      stage: StageId,
+      maxPartitionId: Int): Unit = {
+    val arr = new Array[TaskAttemptNumber](maxPartitionId + 1)
+    java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER)
+    synchronized {
+      authorizedCommittersByStage(stage) = arr
+    }
   }
 
   // Called by DAGScheduler
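The rewritten `stageStart` builds and fills the array before taking the coordinator's lock, so the `synchronized` block shrinks to a single map write. The same shape in isolation (a sketch; class and field names are hypothetical):

```scala
import scala.collection.mutable

class StageRegistrySketch {
  private val NO_AUTHORIZED_COMMITTER: Int = -1
  private val committersByStage = mutable.Map[Int, Array[Int]]()

  def stageStart(stage: Int, maxPartitionId: Int): Unit = {
    // Allocate and initialize outside the lock: O(numPartitions) work...
    val arr = new Array[Int](maxPartitionId + 1)
    java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER)
    // ...so the critical section is just one map insertion.
    synchronized {
      committersByStage(stage) = arr
    }
  }
}
```

Note the signature change: callers must now pass `maxPartitionId`, since the array has to be sized up front, whereas the old per-stage map grew lazily.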
@@ -122,10 +134,10 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
         logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " +
           s"attempt: $attemptNumber")
       case otherReason =>
-        if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) {
+        if (authorizedCommitters(partition) == attemptNumber) {
           logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
             s"partition=$partition) failed; clearing lock")
-          authorizedCommitters.remove(partition)
+          authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
         }
     }
   }
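With the array representation, releasing a failed attempt's commit lock becomes a direct equality check and an in-place reset, instead of a `Map` lookup plus `remove`. Extracted as a standalone sketch (function name is hypothetical):

```scala
object ClearLockSketch {
  val NO_AUTHORIZED_COMMITTER: Int = -1

  // Release the partition's commit lock only if this attempt holds it.
  def clearIfHeldBy(authorizedCommitters: Array[Int], partition: Int, attemptNumber: Int): Unit = {
    if (authorizedCommitters(partition) == attemptNumber) {
      authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
    }
  }

  def main(args: Array[String]): Unit = {
    val committers = Array(NO_AUTHORIZED_COMMITTER, 0, 2)
    clearIfHeldBy(committers, 1, 5)  // attempt 5 never held partition 1: no-op
    clearIfHeldBy(committers, 2, 2)  // attempt 2 held partition 2: lock released
    println(committers.mkString(", "))  // -1, 0, -1
  }
}
```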
@@ -145,16 +157,16 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
       attemptNumber: TaskAttemptNumber): Boolean = synchronized {
     authorizedCommittersByStage.get(stage) match {
       case Some(authorizedCommitters) =>
-        authorizedCommitters.get(partition) match {
-          case Some(existingCommitter) =>
-            logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
-              s"partition=$partition; existingCommitter = $existingCommitter")
-            false
-          case None =>
+        authorizedCommitters(partition) match {
+          case NO_AUTHORIZED_COMMITTER =>
             logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " +
               s"partition=$partition")
             authorizedCommitters(partition) = attemptNumber
             true
+          case existingCommitter =>
+            logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
+              s"partition=$partition; existingCommitter = $existingCommitter")
+            false
         }
       case None =>
         logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of " +
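Taken together, the match in `handleAskPermissionToCommit` preserves first-committer-wins semantics: the first attempt to ask flips the slot from the sentinel to its attempt number, and every later attempt sees a non-sentinel value and is denied. A self-contained sketch of that behavior (names other than the sentinel are hypothetical):

```scala
import scala.collection.mutable

class CommitCoordinatorSketch {
  private val NO_AUTHORIZED_COMMITTER: Int = -1
  private val committersByStage = mutable.Map[Int, Array[Int]]()

  def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized {
    committersByStage(stage) = Array.fill(maxPartitionId + 1)(NO_AUTHORIZED_COMMITTER)
  }

  def canCommit(stage: Int, partition: Int, attemptNumber: Int): Boolean = synchronized {
    committersByStage.get(stage) match {
      case Some(committers) if committers(partition) == NO_AUTHORIZED_COMMITTER =>
        committers(partition) = attemptNumber  // first asker takes the lock
        true
      case _ =>
        false  // lock already held, or the stage has finished
    }
  }
}

object CommitCoordinatorSketchDemo {
  def main(args: Array[String]): Unit = {
    val coord = new CommitCoordinatorSketch
    coord.stageStart(stage = 0, maxPartitionId = 3)
    println(coord.canCommit(stage = 0, partition = 2, attemptNumber = 0))  // true
    println(coord.canCommit(stage = 0, partition = 2, attemptNumber = 1))  // false
  }
}
```

One Scala detail worth noting: `case NO_AUTHORIZED_COMMITTER =>` in the diff works as a constant pattern, rather than binding a fresh variable, only because the identifier starts with an uppercase letter; a lowercase name in that position would match everything.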