Commit 9f0deec

Merge branch 'master' into SPARK-23930
2 parents: 9d65570 + 4d5de4d

97 files changed: +2203 -444 lines


common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java

Lines changed: 7 additions & 5 deletions
@@ -17,10 +17,12 @@
 
 package org.apache.spark.unsafe.types;
 
-import org.apache.spark.unsafe.Platform;
-
 import java.util.Arrays;
 
+import com.google.common.primitives.Ints;
+
+import org.apache.spark.unsafe.Platform;
+
 public final class ByteArray {
 
   public static final byte[] EMPTY_BYTE = new byte[0];
@@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) {
 
   public static byte[] concat(byte[]... inputs) {
     // Compute the total length of the result
-    int totalLength = 0;
+    long totalLength = 0;
     for (int i = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
-        totalLength += inputs[i].length;
+        totalLength += (long)inputs[i].length;
       } else {
         return null;
       }
     }
 
     // Allocate a new byte array, and copy the inputs one by one into it
-    final byte[] result = new byte[totalLength];
+    final byte[] result = new byte[Ints.checkedCast(totalLength)];
     int offset = 0;
     for (int i = 0; i < inputs.length; i++) {
       int len = inputs[i].length;
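
Both concat hunks guard the same failure: summing byte lengths into an int silently wraps once the total passes Int.MaxValue, so the array allocation gets a garbage size. A standalone Scala sketch of that arithmetic (illustrative sizes, not Spark code; the checkedCast helper mirrors what Guava's Ints.checkedCast does):

// Sketch of the overflow the hunk above guards against.
object ConcatLengthSketch {
  def main(args: Array[String]): Unit = {
    val lengths = Seq(1500000000, 1500000000)            // two ~1.5 GB inputs (sizes only)

    val intTotal  = lengths.foldLeft(0)(_ + _)            // wraps to a negative number
    val longTotal = lengths.foldLeft(0L)(_ + _.toLong)    // correct: 3000000000

    println(s"int sum:  $intTotal")                       // -1294967296
    println(s"long sum: $longTotal")

    // Range-check before narrowing, instead of allocating a corrupt array.
    def checkedCast(value: Long): Int = {
      require(value.isValidInt, s"Out of int range: $value")
      value.toInt
    }
    try checkedCast(longTotal)
    catch { case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}") }
  }
}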

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 4 additions & 4 deletions
@@ -29,8 +29,8 @@
 import com.esotericsoftware.kryo.KryoSerializable;
 import com.esotericsoftware.kryo.io.Input;
 import com.esotericsoftware.kryo.io.Output;
-
 import com.google.common.primitives.Ints;
+
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
@@ -877,17 +877,17 @@ public UTF8String lpad(int len, UTF8String pad) {
    */
  public static UTF8String concat(UTF8String... inputs) {
    // Compute the total length of the result.
-   int totalLength = 0;
+   long totalLength = 0;
    for (int i = 0; i < inputs.length; i++) {
      if (inputs[i] != null) {
-       totalLength += inputs[i].numBytes;
+       totalLength += (long)inputs[i].numBytes;
      } else {
        return null;
      }
    }
 
    // Allocate a new byte array, and copy the inputs one by one into it.
-   final byte[] result = new byte[totalLength];
+   final byte[] result = new byte[Ints.checkedCast(totalLength)];
    int offset = 0;
    for (int i = 0; i < inputs.length; i++) {
      int len = inputs[i].numBytes;
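
UTF8String.concat gets the identical treatment. For reference, Guava's Ints.checkedCast rejects an out-of-range total instead of silently truncating it the way a plain narrowing cast would; a minimal sketch (made-up value, assumes Guava on the classpath, which Spark already bundles):

// What Ints.checkedCast buys over a silent narrowing cast.
import com.google.common.primitives.Ints

object CheckedCastSketch {
  def main(args: Array[String]): Unit = {
    val total = 3000000000L                 // > Int.MaxValue
    println(total.toInt)                    // -1294967296: silent, corrupt length
    try Ints.checkedCast(total)             // fails fast instead
    catch { case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}") }
  }
}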

core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala

Lines changed: 6 additions & 2 deletions
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
 import com.google.common.io.Files
 
 import org.apache.spark.{SecurityManager, SparkConf}
-import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
+import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
 import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.RpcEndpointRef
@@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner(
   private def fetchAndRunExecutor() {
     try {
       // Launch the process
-      val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf),
+      val subsOpts = appDesc.command.javaOpts.map {
+        Utils.substituteAppNExecIds(_, appId, execId.toString)
+      }
+      val subsCommand = appDesc.command.copy(javaOpts = subsOpts)
+      val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf),
         memory, sparkHome.getAbsolutePath, substituteVariables)
       val command = builder.command()
       val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"")

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 14 additions & 0 deletions
@@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
+  /**
+   * Marks the task has completed in all TaskSetManagers for the given stage.
+   *
+   * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
+   * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
+   * do not also submit those same tasks. That also means that a task completion from an earlier
+   * attempt can lead to the entire stage getting marked as successful.
+   */
+  private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = {
+    taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
+      tsm.markPartitionCompleted(partitionId)
+    }
+  }
+
 }
 

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 19 additions & 1 deletion
@@ -73,6 +73,8 @@ private[spark] class TaskSetManager(
   val ser = env.closureSerializer.newInstance()
 
   val tasks = taskSet.tasks
+  private[scheduler] val partitionToIndex = tasks.zipWithIndex
+    .map { case (t, idx) => t.partitionId -> idx }.toMap
   val numTasks = tasks.length
   val copiesRunning = new Array[Int](numTasks)
 
@@ -153,7 +155,7 @@ private[spark] class TaskSetManager(
   private[scheduler] val speculatableTasks = new HashSet[Int]
 
   // Task index, start and finish time for each task attempt (indexed by task ID)
-  private val taskInfos = new HashMap[Long, TaskInfo]
+  private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
 
   // Use a MedianHeap to record durations of successful tasks so we know when to launch
   // speculative tasks. This is only used when speculation is enabled, to avoid the overhead
@@ -754,6 +756,9 @@ private[spark] class TaskSetManager(
       logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
         " because task " + index + " has already completed successfully")
     }
+    // There may be multiple tasksets for this stage -- we let all of them know that the partition
+    // was completed. This may result in some of the tasksets getting completed.
+    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId)
     // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
     // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
     // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -764,6 +769,19 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }
 
+  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+    partitionToIndex.get(partitionId).foreach { index =>
+      if (!successful(index)) {
+        tasksSuccessful += 1
+        successful(index) = true
+        if (tasksSuccessful == numTasks) {
+          isZombie = true
+        }
+        maybeFinishTaskSet()
+      }
+    }
+  }
+
   /**
    * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
   * DAG Scheduler.
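
Together with the TaskSchedulerImpl hunk above, the effect is that a late success reported by a zombie attempt is fanned out to every TaskSetManager of the stage, so a healthy attempt stops waiting on a partition someone else already finished. A compressed, self-contained model of that fan-out (toy stand-in classes, not Spark's API):

// Toy model: one completion updates every attempt of the stage.
import scala.collection.mutable

class ToyTaskSetManager(val stageAttempt: Int, numPartitions: Int) {
  private val successful = mutable.Set.empty[Int]
  var isZombie = false

  def markPartitionCompleted(partitionId: Int): Unit = {
    successful += partitionId
    if (successful.size == numPartitions) isZombie = true   // nothing left to run
  }
}

class ToyScheduler(numPartitions: Int) {
  private val attemptsByStage = mutable.Map.empty[Int, mutable.Buffer[ToyTaskSetManager]]

  def submit(stageId: Int, attempt: Int): ToyTaskSetManager = {
    val tsm = new ToyTaskSetManager(attempt, numPartitions)
    attemptsByStage.getOrElseUpdate(stageId, mutable.Buffer.empty) += tsm
    tsm
  }

  // Counterpart of markPartitionCompletedInAllTaskSets: fan the completion out.
  def partitionCompleted(stageId: Int, partitionId: Int): Unit =
    attemptsByStage.getOrElse(stageId, Nil).foreach(_.markPartitionCompleted(partitionId))
}

object ZombieCompletionSketch extends App {
  val sched = new ToyScheduler(numPartitions = 3)
  val zombie = sched.submit(stageId = 0, attempt = 0)   // failed attempt, tasks still running
  val active = sched.submit(stageId = 0, attempt = 1)

  (0 until 3).foreach(p => sched.partitionCompleted(0, p))  // late successes from the zombie
  assert(active.isZombie)   // the non-zombie attempt no longer needs to run anything
}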

core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala

Lines changed: 4 additions & 2 deletions
@@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T](
     param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
   private[spark] var _value = initialValue // Current value on driver
 
-  override def isZero: Boolean = _value == param.zero(initialValue)
+  @transient private lazy val _zero = param.zero(initialValue)
+
+  override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
 
   override def copy(): LegacyAccumulatorWrapper[R, T] = {
     val acc = new LegacyAccumulatorWrapper(initialValue, param)
@@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T](
   }
 
   override def reset(): Unit = {
-    _value = param.zero(initialValue)
+    _value = _zero
   }
 
   override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
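
One reading of why the cached _zero and the reference comparison matter (my interpretation, not spelled out in the diff): the wrapped element type may not define a usable equals, so comparing against a freshly computed zero object can never be true, whereas holding a single zero instance, assigning it in reset(), and comparing identities is well defined. A toy illustration with a hypothetical Counter class:

// Why identity against one cached zero beats == against a fresh zero object.
class Counter(val n: Long)   // deliberately no equals/hashCode

object LegacyZeroSketch extends App {
  def zero(): Counter = new Counter(0)

  // old-style check: a reset value compared to a *new* zero instance -> never true
  val resetValueOld = zero()
  println(resetValueOld == zero())   // false: falls back to reference equality

  // new-style check: reset assigns the single cached zero, so identity holds
  lazy val cachedZero: Counter = zero()
  val resetValueNew = cachedZero
  println(resetValueNew.asInstanceOf[AnyRef].eq(cachedZero.asInstanceOf[AnyRef]))   // true
}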

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 15 additions & 0 deletions
@@ -2689,6 +2689,21 @@ private[spark] object Utils extends Logging {
 
     s"k8s://$resolvedURL"
   }
+
+  /**
+   * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id
+   * and {{APP_ID}} occurrences with the App Id.
+   */
+  def substituteAppNExecIds(opt: String, appId: String, execId: String): String = {
+    opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId)
+  }
+
+  /**
+   * Replaces all the {{APP_ID}} occurrences with the App Id.
+   */
+  def substituteAppId(opt: String, appId: String): String = {
+    opt.replace("{{APP_ID}}", appId)
+  }
 }
 
 private[util] object CallerContext extends Logging {
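
A quick usage sketch of the new helper that ExecutorRunner now calls on each javaOpt; the substitution logic is copied from the hunk above, while the option string and ids are made up for illustration:

// Placeholder substitution as ExecutorRunner applies it to javaOpts.
object JavaOptsSubstitutionSketch extends App {
  def substituteAppNExecIds(opt: String, appId: String, execId: String): String =
    opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId)

  val rawOpt = "-Dspark.log.dir=/tmp/{{APP_ID}}/{{EXECUTOR_ID}}"   // hypothetical option
  println(substituteAppNExecIds(rawOpt, appId = "app-20180501", execId = "3"))
  // -Dspark.log.dir=/tmp/app-20180501/3
}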

core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala

Lines changed: 9 additions & 4 deletions
@@ -63,10 +63,15 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
    */
   def writeFully(channel: WritableByteChannel): Unit = {
     for (bytes <- getChunks()) {
-      while (bytes.remaining() > 0) {
-        val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
-        bytes.limit(bytes.position() + ioSize)
-        channel.write(bytes)
+      val curChunkLimit = bytes.limit()
+      while (bytes.hasRemaining) {
+        try {
+          val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
+          bytes.limit(bytes.position() + ioSize)
+          channel.write(bytes)
+        } finally {
+          bytes.limit(curChunkLimit)
+        }
       }
     }
   }
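
The same limit/restore pattern in plain NIO, outside Spark, shows why the finally block matters: if the limit is narrowed for one bounded write and never restored, everything past that first slice of the chunk never reaches the channel. A self-contained sketch with a made-up chunk size:

// Bounded writes with a temporarily narrowed limit, always restored afterwards.
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.channels.{Channels, WritableByteChannel}

object WriteFullySketch extends App {
  val chunkSize = 4                                     // stand-in for bufferWriteChunkSize
  val data = ByteBuffer.wrap((1 to 10).map(_.toByte).toArray)
  val sink = new ByteArrayOutputStream()
  val channel: WritableByteChannel = Channels.newChannel(sink)

  val originalLimit = data.limit()
  while (data.hasRemaining) {
    try {
      val ioSize = math.min(data.remaining(), chunkSize)
      data.limit(data.position() + ioSize)              // expose only this slice to write()
      channel.write(data)
    } finally {
      data.limit(originalLimit)                         // restore, or the tail would be lost
    }
  }
  assert(sink.toByteArray.length == 10)                 // every byte reached the channel
}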

core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala

Lines changed: 15 additions & 2 deletions
@@ -21,11 +21,12 @@ import java.nio.ByteBuffer
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+import org.apache.spark.internal.config
 import org.apache.spark.network.util.ByteArrayWritableChannel
 import org.apache.spark.util.io.ChunkedByteBuffer
 
-class ChunkedByteBufferSuite extends SparkFunSuite {
+class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext {
 
   test("no chunks") {
     val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer])
@@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
     assert(chunkedByteBuffer.getChunks().head.position() === 0)
   }
 
+  test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") {
+    try {
+      sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L)
+      val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024)))
+      val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt)
+      chunkedByteBuffer.writeFully(byteArrayWritableChannel)
+      assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size)
+    } finally {
+      sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE)
+    }
+  }
+
   test("toArray()") {
     val empty = ByteBuffer.wrap(Array.empty[Byte])
     val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte))

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 104 additions & 0 deletions
@@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       taskScheduler.initialize(new FakeSchedulerBackend)
     }
   }
+
+  test("Completions in zombie tasksets update status of non-zombie taskset") {
+    val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
+    val valueSer = SparkEnv.get.serializer.newInstance()
+
+    def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = {
+      val indexInTsm = tsm.partitionToIndex(partition)
+      val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head
+      val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+      tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result)
+    }
+
+    // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt,
+    // two times, so we have three active task sets for one stage. (For this to really happen,
+    // you'd need the previous stage to also get restarted, and then succeed, in between each
+    // attempt, but that happens outside what we're mocking here.)
+    val zombieAttempts = (0 until 2).map { stageAttempt =>
+      val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt)
+      taskScheduler.submitTasks(attempt)
+      val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get
+      val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+      taskScheduler.resourceOffers(offers)
+      assert(tsm.runningTasks === 10)
+      // fail attempt
+      tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED,
+        FetchFailed(null, 0, 0, 0, "fetch failed"))
+      // the attempt is a zombie, but the tasks are still running (this could be true even if
+      // we actively killed those tasks, as killing is best-effort)
+      assert(tsm.isZombie)
+      assert(tsm.runningTasks === 9)
+      tsm
+    }
+
+    // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for
+    // the stage, but this time with insufficient resources so not all tasks are active.
+
+    val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2)
+    taskScheduler.submitTasks(finalAttempt)
+    val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get
+    val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+    val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task =>
+      finalAttempt.tasks(task.index).partitionId
+    }.toSet
+    assert(finalTsm.runningTasks === 5)
+    assert(!finalTsm.isZombie)
+
+    // We simulate late completions from our zombie tasksets, corresponding to all the pending
+    // partitions in our final attempt. This means we're only waiting on the tasks we've already
+    // launched.
+    val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions)
+    finalAttemptPendingPartitions.foreach { partition =>
+      completeTaskSuccessfully(zombieAttempts(0), partition)
+    }
+
+    // If there is another resource offer, we shouldn't run anything. Though our final attempt
+    // used to have pending tasks, now those tasks have been completed by zombie attempts. The
+    // remaining tasks to compute are already active in the non-zombie attempt.
+    assert(
+      taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty)
+
+    val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted
+
+    // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be
+    // marked as zombie.
+    // for each of the remaining tasks, find the tasksets with an active copy of the task, and
+    // finish the task.
+    remainingTasks.foreach { partition =>
+      val tsm = if (partition == 0) {
+        // we failed this task on both zombie attempts, this one is only present in the latest
+        // taskset
+        finalTsm
+      } else {
+        // should be active in every taskset. We choose a zombie taskset just to make sure that
+        // we transition the active taskset correctly even if the final completion comes
+        // from a zombie.
+        zombieAttempts(partition % 2)
+      }
+      completeTaskSuccessfully(tsm, partition)
+    }
+
+    assert(finalTsm.isZombie)
+
+    // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet
+    verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject())
+
+    // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything
+    // else succeeds, to make sure we get the right updates to the blacklist in all cases.
+    (zombieAttempts ++ Seq(finalTsm)).foreach { tsm =>
+      val stageAttempt = tsm.taskSet.stageAttemptId
+      tsm.runningTasksSet.foreach { index =>
+        if (stageAttempt == 1) {
+          tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost)
+        } else {
+          val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+          tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result)
+        }
+      }
+
+      // we update the blacklist for the stage attempts with all successful tasks. Even though
+      // some tasksets had failures, we still consider them all successful from a blacklisting
+      // perspective, as the failures weren't from a problem w/ the tasks themselves.
+      verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject())
+    }
+  }
 }
