10 changes: 8 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -53,6 +53,9 @@ private[spark] class TaskContextImpl(
// Whether the task has completed.
@volatile private var completed: Boolean = false

// Whether the task has failed.
@volatile private var failed: Boolean = false

override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -63,10 +66,13 @@ private[spark] class TaskContextImpl(
this
}

/** Marks the task as completed and triggers the failure listeners. */
/** Marks the task as failed and triggers the failure listeners. */
private[spark] def markTaskFailed(error: Throwable): Unit = {
// failure callbacks should only be called once
if (failed) return
failed = true
val errorMsgs = new ArrayBuffer[String](2)
// Process complete callbacks in the reverse order of registration
// Process failure callbacks in the reverse order of registration
onFailureCallbacks.reverse.foreach { listener =>
try {
listener.onTaskFailure(this, error)
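To make the new hook concrete, here is a minimal sketch of task code driving it. TaskContext.get() and addTaskFailureListener are the real public entry points (the new tests below use the same form); the object, method, and `out` parameter are hypothetical.

```scala
import java.io.Closeable

import org.apache.spark.TaskContext

object FailureListenerSketch {
  // Hypothetical write helper running inside a task; `out` stands in for a record writer.
  def writeWithCleanup(records: Iterator[String], out: Closeable): Unit = {
    // Listeners registered here are invoked by markTaskFailed(error), in reverse
    // order of registration and at most once per task.
    TaskContext.get().addTaskFailureListener { (_: TaskContext, _: Throwable) =>
      out.close() // clean up the half-written output before the normal close path runs
    }
    records.foreach(_ => ()) // write each record to `out` here; writes may throw
  }
}
```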
@@ -1101,7 +1101,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
require(writer != null, "Unable to obtain RecordWriter")
var recordsWritten = 0L
Utils.tryWithSafeFinally {
Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
@@ -1190,7 +1190,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.open()
var recordsWritten = 0L

Utils.tryWithSafeFinally {
Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
39 changes: 38 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1241,7 +1241,6 @@ private[spark] object Utils extends Logging {
* exception from the original `out.write` call.
*/
def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
// It would be nice to find a method on Try that did this
var originalThrowable: Throwable = null
try {
block
@@ -1267,6 +1266,44 @@
}
}

/**
* Execute a block of code and call the failure callbacks before the finally block if any
* exception occurs. If an exception also happens in the finally block, do not suppress the
* original exception.
*
* This is primarily an issue with `finally { out.close() }` blocks, where
* close needs to be called to clean up `out`, but if an exception happened
* in `out.write`, it's likely `out` may be corrupted and `out.close` will
* fail as well. This would then suppress the original/likely more meaningful
* exception from the original `out.write` call.
*/
def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
var originalThrowable: Throwable = null
try {
block
} catch {
case t: Throwable =>
// Purposefully not using NonFatal, because we don't want our finallyBlock
// to suppress even fatal exceptions.
originalThrowable = t
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
throw originalThrowable
} finally {
try {
finallyBlock
} catch {
case t: Throwable =>
if (originalThrowable != null) {
originalThrowable.addSuppressed(t)
logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
throw originalThrowable
} else {
throw t
}
}
}
}

/** Default filtering function for finding call sites using `getCallSite`. */
private def sparkInternalExclusionFunction(className: String): Boolean = {
// A regular expression to match classes of the internal Spark API's
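A quick usage sketch for the new helper follows. It is simplified from the PairRDDFunctions call sites above; the package, object, and java.io.Writer are illustrative, and the block has to run inside a task because the helper looks up TaskContext.get().

```scala
package org.apache.spark.example // Utils is private[spark], so the caller must live under it

import org.apache.spark.util.Utils

object SafeWriteSketch {
  def writeAll(iter: Iterator[(String, String)], writer: java.io.Writer): Unit = {
    Utils.tryWithSafeFinallyAndFailureCallbacks {
      while (iter.hasNext) {
        val pair = iter.next()
        writer.write(pair._1 + "\t" + pair._2 + "\n") // may throw and leave `writer` corrupted
      }
    } {
      // Runs last; any failure listeners on the TaskContext have already fired.
      writer.close()
    }
  }
}
```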
@@ -17,6 +17,8 @@

package org.apache.spark.rdd

import java.io.IOException

import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.Random

@@ -29,7 +31,8 @@ import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext}
import org.apache.hadoop.util.Progressable

import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite}
import org.apache.spark._
import org.apache.spark.Partitioner
import org.apache.spark.util.Utils

class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
@@ -533,6 +536,38 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
}

test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") {
val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)

FakeWriterWithCallback.calledBy = ""
FakeWriterWithCallback.exception = null
val e = intercept[SparkException] {
pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored")
}
assert(e.getMessage contains "failed to write")

assert(FakeWriterWithCallback.calledBy === "write,callback,close")
assert(FakeWriterWithCallback.exception != null, "exception should be captured")
assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
}

test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") {
val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
val conf = new JobConf()

FakeWriterWithCallback.calledBy = ""
FakeWriterWithCallback.exception = null
val e = intercept[SparkException] {
pairs.saveAsHadoopFile(
"ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf)
}
assert(e.getMessage contains "failed to write")

assert(FakeWriterWithCallback.calledBy === "write,callback,close")
assert(FakeWriterWithCallback.exception != null, "exception should be captured")
assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
}

test("lookup") {
val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7)))

@@ -776,6 +811,60 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
}
}

object FakeWriterWithCallback {
var calledBy: String = ""
var exception: Throwable = _

def onFailure(ctx: TaskContext, e: Throwable): Unit = {
calledBy += "callback,"
exception = e
}
}

class FakeWriterWithCallback extends FakeWriter {

override def close(p1: Reporter): Unit = {
FakeWriterWithCallback.calledBy += "close"
}

override def write(p1: Integer, p2: Integer): Unit = {
FakeWriterWithCallback.calledBy += "write,"
TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
FakeWriterWithCallback.onFailure(t, e)
}
throw new IOException("failed to write")
}
}

class FakeFormatWithCallback() extends FakeOutputFormat {
override def getRecordWriter(
ignored: FileSystem,
job: JobConf, name: String,
progress: Progressable): RecordWriter[Integer, Integer] = {
new FakeWriterWithCallback()
}
}

class NewFakeWriterWithCallback extends NewFakeWriter {
override def close(p1: NewTaskAttempContext): Unit = {
FakeWriterWithCallback.calledBy += "close"
}

override def write(p1: Integer, p2: Integer): Unit = {
FakeWriterWithCallback.calledBy += "write,"
TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
FakeWriterWithCallback.onFailure(t, e)
}
throw new IOException("failed to write")
}
}

class NewFakeFormatWithCallback() extends NewFakeFormat {
override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
new NewFakeWriterWithCallback()
}
}

class ConfigTestFormat() extends NewFakeFormat() with Configurable {

var setConfCalled = false
@@ -247,11 +247,9 @@ private[sql] class DefaultWriterContainer(
executorSideSetup(taskContext)
val configuration = taskAttemptContext.getConfiguration
configuration.set("spark.sql.sources.output.path", outputPath)
val writer = newOutputWriter(getWorkPath)
var writer = newOutputWriter(getWorkPath)
writer.initConverter(dataSchema)

var writerClosed = false

// If anything below fails, we should abort the task.
try {
while (iterator.hasNext) {
@@ -263,16 +261,17 @@
} catch {
case cause: Throwable =>
logError("Aborting task.", cause)
// call failure callbacks first, so we could have a chance to cleanup the writer.
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
[review thread on the markTaskFailed(cause) line above]
Contributor: maybe more clear if we move this into abortTask?
Contributor: actually never mind - more clear here
Contributor (Author): We can't access cause in abortTask.

abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}

def commitTask(): Unit = {
try {
assert(writer != null, "OutputWriter instance should have been initialized")
if (!writerClosed) {
if (writer != null) {
writer.close()
writerClosed = true
writer = null
}
super.commitTask()
} catch {
Expand All @@ -285,9 +284,8 @@ private[sql] class DefaultWriterContainer(

def abortTask(): Unit = {
try {
if (!writerClosed) {
if (writer != null) {
writer.close()
writerClosed = true
}
} finally {
super.abortTask()
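Condensed into one place, and with hypothetical stand-ins for the container's own write loop and abortTask(), the ordering this change establishes in both writer containers is roughly:

```scala
package org.apache.spark.sql.example // TaskContextImpl is private[spark]

import org.apache.spark.{SparkException, TaskContext, TaskContextImpl}

object WriteTaskOrderingSketch {
  // Hypothetical stand-ins; only the catch-block ordering matters here.
  private def writeRows(): Unit = ()
  private def abortTask(): Unit = ()

  def runTask(): Unit = {
    try {
      writeRows()
    } catch {
      case cause: Throwable =>
        // 1. Fire failure listeners while `cause` is still in hand so the writer
        //    can be cleaned up; abortTask() itself has no access to `cause`.
        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
        // 2. Then abort the Hadoop task attempt and rethrow.
        abortTask()
        throw new SparkException("Task failed while writing rows.", cause)
    }
  }
}
```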
@@ -393,57 +391,62 @@ private[sql] class DynamicPartitionWriterContainer(
val getPartitionString =
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)

// If anything below fails, we should abort the task.
try {
// Sorts the data before write, so that we only need one writer at the same time.
// TODO: inject a local sort operator in planning.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)

while (iterator.hasNext) {
val currentRow = iterator.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
// Sorts the data before write, so that we only need one writer at the same time.
// TODO: inject a local sort operator in planning.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)

while (iterator.hasNext) {
val currentRow = iterator.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
logInfo(s"Sorting complete. Writing out partition files one at a time.")

logInfo(s"Sorting complete. Writing out partition files one at a time.")
val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}
val sortedIterator = sorter.sortedIterator()

val sortedIterator = sorter.sortedIterator()
// If anything below fails, we should abort the task.
var currentWriter: OutputWriter = null
try {
var currentKey: UnsafeRow = null
var currentWriter: OutputWriter = null
try {
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
}
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")

currentWriter = newOutputWriter(currentKey, getPartitionString)
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")

currentWriter.writeInternal(sortedIterator.getValue)
currentWriter = newOutputWriter(currentKey, getPartitionString)
}
} finally {
if (currentWriter != null) { currentWriter.close() }
currentWriter.writeInternal(sortedIterator.getValue)
}
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}

commitTask()
} catch {
case cause: Throwable =>
logError("Aborting task.", cause)
// call failure callbacks first, so we could have a chance to cleanup the writer.
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
if (currentWriter != null) {
currentWriter.close()
}
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}