@@ -64,6 +64,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
case directResult: DirectTaskResult[_] =>
if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
// kill the task so that it will not become a zombie task
scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
Member:

How about directly calling taskSetManager.handleFailedTask here?
If canFetchMoreResults returns false, taskSetManager.isZombie has already been set to true, so scheduler.handleFailedTask is effectively the same as taskSetManager.handleFailedTask, and it would make the unit test easier to write.
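
For illustration only, a rough Scala sketch of the alternative the reviewer describes at this call site (hypothetical, not the change that was merged; it assumes the same TaskKilled reason string used above):

// Sketch: fail the task directly on the TaskSetManager. When canFetchMoreResults(...)
// returns false, the task set has already been marked as zombie, so this would be
// equivalent to routing the failure through the scheduler.
if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
  taskSetManager.handleFailedTask(tid, TaskState.KILLED,
    TaskKilled("Tasks result size has exceeded maxResultSize"))
  return
}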

Contributor (Author):

Calling scheduler.handleFailedTask is to stay consistent with the other cases in this function.
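
For context, a hedged, approximate sketch of the existing pattern elsewhere in this method that the author is matching, where a lost result block is likewise reported through the scheduler:

// Approximate existing pattern in enqueueSuccessfulTask: failures are reported via
// scheduler.handleFailedTask rather than by calling the TaskSetManager directly.
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (serializedTaskResult.isEmpty) {
  // The result block is gone (e.g. the executor exited or the block was evicted).
  scheduler.handleFailedTask(taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
  return
}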

Member:

Better to leave a comment here to explain why we handle the oversized task as a killed task.

Contributor (Author):

Updated, thanks.

"Tasks result size has exceeded maxResultSize"))
return
}
// deserialize "value" without holding any lock so that it won't block other threads.
@@ -75,6 +78,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
if (!taskSetManager.canFetchMoreResults(size)) {
// dropped by executor if size is larger than maxResultSize
sparkEnv.blockManager.master.removeBlock(blockId)
// kill the task so that it will not become a zombie task
scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
"Tasks result size has exceeded maxResultSize"))
return
}
logDebug("Fetching indirect task result for TID %s".format(tid))
@@ -33,6 +33,7 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.internal.config.Network.RPC_MESSAGE_MAX_SIZE
import org.apache.spark.storage.TaskResultBlockId
@@ -78,6 +79,16 @@ private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: Task
}
}

private class DummyTaskSchedulerImpl(sc: SparkContext)
extends TaskSchedulerImpl(sc, 1, true) {
override def handleFailedTask(
taskSetManager: TaskSetManager,
tid: Long,
taskState: TaskState,
reason: TaskFailedReason): Unit = {
// do nothing
}
}

/**
* A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors
@@ -130,6 +141,31 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
"Expect result to be removed from the block manager.")
}

test("handling total size of results larger than maxResultSize") {
sc = new SparkContext("local", "test", conf)
val scheduler = new DummyTaskSchedulerImpl(sc)
val spyScheduler = spy(scheduler)
val resultGetter = new TaskResultGetter(sc.env, spyScheduler)
scheduler.taskResultGetter = resultGetter
val myTsm = new TaskSetManager(spyScheduler, FakeTask.createTaskSet(2), 1) {
// always returns false
override def canFetchMoreResults(size: Long): Boolean = false
}
val indirectTaskResult = IndirectTaskResult(TaskResultBlockId(0), 0)
val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array())
val ser = sc.env.closureSerializer.newInstance()
val serializedIndirect = ser.serialize(indirectTaskResult)
val serializedDirect = ser.serialize(directTaskResult)
resultGetter.enqueueSuccessfulTask(myTsm, 0, serializedDirect)
resultGetter.enqueueSuccessfulTask(myTsm, 1, serializedIndirect)
eventually(timeout(1.second)) {
verify(spyScheduler, times(1)).handleFailedTask(
myTsm, 0, TaskState.KILLED, TaskKilled("Tasks result size has exceeded maxResultSize"))
verify(spyScheduler, times(1)).handleFailedTask(
myTsm, 1, TaskState.KILLED, TaskKilled("Tasks result size has exceeded maxResultSize"))
}
}

test("task retried if result missing from block manager") {
// Set the maximum number of task failures to > 0, so that the task set isn't aborted
// after the result is missing.