1818package org .apache .spark .scheduler
1919
2020import java .nio .ByteBuffer
21- import java .util .HashSet
2221
2322import scala .collection .mutable .HashMap
23+ import scala .collection .mutable .Set
2424import scala .concurrent .duration ._
2525
2626import org .mockito .Matchers .{anyInt , anyObject , anyString , eq => meq }
@@ -40,7 +40,7 @@ class FakeSchedulerBackend extends SchedulerBackend {
4040 def reviveOffers () {}
4141 def defaultParallelism (): Int = 1
4242 def maxNumConcurrentTasks (): Int = 0
43- val killedTaskIds : HashSet [Long ] = new HashSet [Long ]()
43+ val killedTaskIds : Set [Long ] = Set [Long ]()
4444 override def killTask (
4545 taskId : Long ,
4646 executorId : String ,
@@ -1328,22 +1328,30 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
13281328 tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState .FAILED , TaskKilled (" test" ))
13291329 assert(tsm.isZombie)
13301330 }
1331+
13311332 test(" SPARK-25250 On successful completion of a task attempt on a partition id, kill other" +
13321333 " running task attempts on that same partition" ) {
13331334 val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
1335+
13341336 val firstAttempt = FakeTask .createTaskSet(10 , stageAttemptId = 0 )
13351337 taskScheduler.submitTasks(firstAttempt)
1338+
13361339 val offersFirstAttempt = (0 until 10 ).map{ idx => WorkerOffer (s " exec- $idx" , s " host- $idx" , 1 ) }
13371340 taskScheduler.resourceOffers(offersFirstAttempt)
1341+
13381342 val tsm0 = taskScheduler.taskSetManagerForAttempt(0 , 0 ).get
13391343 val matchingTaskInfoFirstAttempt = tsm0.taskAttempts(0 ).head
13401344 tsm0.handleFailedTask(matchingTaskInfoFirstAttempt.taskId, TaskState .FAILED ,
13411345 FetchFailed (null , 0 , 0 , 0 , " fetch failed" ))
1346+
13421347 val secondAttempt = FakeTask .createTaskSet(10 , stageAttemptId = 1 )
13431348 taskScheduler.submitTasks(secondAttempt)
1349+
13441350 val offersSecondAttempt = (0 until 10 ).map{ idx => WorkerOffer (s " exec- $idx" , s " host- $idx" , 1 ) }
13451351 taskScheduler.resourceOffers(offersSecondAttempt)
1352+
13461353 taskScheduler.markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(2 , 0 )
1354+
13471355 val tsm1 = taskScheduler.taskSetManagerForAttempt(0 , 1 ).get
13481356 val indexInTsm = tsm1.partitionToIndex(2 )
13491357 val matchingTaskInfoSecondAttempt = tsm1.taskAttempts.flatten.filter(_.index == indexInTsm).head
0 commit comments