FileFormatDataWriter.scala
@@ -18,17 +18,20 @@ package org.apache.spark.sql.execution.datasources

import scala.collection.mutable

import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.TaskOutputFileAlreadyExistException
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec}
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
import org.apache.spark.sql.internal.SQLConf
@@ -76,14 +79,28 @@ abstract class FileFormatDataWriter(
releaseCurrentWriter()
}

private def enrichWriteError[T](path: => String)(f: => T): T = try {
f
} catch {
case e: FetchFailedException =>
throw e
case f: FileAlreadyExistsException if SQLConf.get.fastFailFileFormatOutput =>
// If any output file to write already exists, it does not make sense to re-run this task.
// We throw the exception and let Executor throw ExceptionFailure to abort the job.
throw new TaskOutputFileAlreadyExistException(f)
case t: Throwable => throw QueryExecutionErrors.taskFailedWhileWritingRowsError(path, t)
}

/** Writes a record. */
def write(record: InternalRow): Unit

final def writeWithMetrics(record: InternalRow, count: Long): Unit = {
if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
CustomMetrics.updateMetrics(currentMetricsValues.toImmutableArraySeq, customMetrics)
}
write(record)
enrichWriteError(Option(currentWriter).map(_.path()).getOrElse(description.path)) {
write(record)
}
}

/** Write an iterator of records. */
@@ -102,7 +119,7 @@ abstract class FileFormatDataWriter(
* to the driver and used to update the catalog. Other information will be sent back to the
* driver too and used to e.g. update the metrics in UI.
*/
override def commit(): WriteTaskResult = {
final override def commit(): WriteTaskResult = enrichWriteError(description.path) {
releaseResources()
val (taskCommitMessage, taskCommitTime) = Utils.timeTakenMs {
committer.commitTask(taskAttemptContext)
@@ -113,15 +130,15 @@
WriteTaskResult(taskCommitMessage, summary)
}

def abort(): Unit = {
final def abort(): Unit = enrichWriteError(description.path) {
try {
releaseResources()
} finally {
committer.abortTask(taskAttemptContext)
}
}

override def close(): Unit = {}
final override def close(): Unit = {}
}

/** FileFormatWriteTask for empty partitions */
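The net effect of the changes in this file: the error enrichment that previously lived in FileFormatWriter's task loop (rethrow FetchFailedException as-is, fast-fail on FileAlreadyExistsException, wrap everything else in a TASK_WRITE_FAILED error carrying the offending path) now happens inside FileFormatDataWriter itself, so writeWithMetrics can report the path of the file currently being written (currentWriter.path(), falling back to description.path) rather than only the task-level path. A minimal standalone sketch of the pattern, using illustrative names rather than Spark's classes:

```scala
// Standalone sketch of the error-enrichment pattern above; the names
// TaskWriteFailedException, PassThroughException and ErrorEnrichment are
// illustrative stand-ins, not Spark classes.
final case class TaskWriteFailedException(path: String, cause: Throwable)
  extends RuntimeException(s"Task failed while writing rows to $path", cause)

// Stand-in for errors that must propagate unchanged (FetchFailedException and
// the fast-fail FileAlreadyExistsException case in the real code).
final class PassThroughException(msg: String) extends RuntimeException(msg)

object ErrorEnrichment {
  // `path` is by-name, so it is only evaluated when a failure actually occurs.
  def enrich[T](path: => String)(block: => T): T = {
    try {
      block
    } catch {
      case e: PassThroughException => throw e
      case t: Throwable => throw TaskWriteFailedException(path, t)
    }
  }
}

object EnrichDemo extends App {
  try {
    ErrorEnrichment.enrich("/tmp/output/part-00000-demo.csv") {
      throw new IllegalStateException("disk full") // simulate a failing write
    }
  } catch {
    case e: TaskWriteFailedException =>
      println(s"${e.getMessage} (cause: ${e.getCause.getMessage})")
  }
}
```

Taking the path by name fits the call site: currentWriter may be unset between files (hence the Option(currentWriter) guard in writeWithMetrics), so the concrete path is only resolved once an error actually needs to be reported.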
FileFormatWriter.scala
@@ -20,15 +20,14 @@ package org.apache.spark.sql.execution.datasources
import java.util.{Date, UUID}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -37,11 +36,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.{NextIterator, SerializableConfiguration, Utils}
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.ArrayImplicits._


@@ -400,31 +397,17 @@ object FileFormatWriter extends Logging {
}
}

try {
val queryFailureCapturedIterator = new QueryFailureCapturedIterator(iterator)
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Execute the task to write rows out and commit the task.
dataWriter.writeWithIterator(queryFailureCapturedIterator)
dataWriter.commit()
})(catchBlock = {
// If there is an error, abort the task
dataWriter.abort()
logError(s"Job $jobId aborted.")
}, finallyBlock = {
dataWriter.close()
})
} catch {
case e: QueryFailureDuringWrite =>
throw e.queryFailure
case e: FetchFailedException =>
throw e
case f: FileAlreadyExistsException if SQLConf.get.fastFailFileFormatOutput =>
// If any output file to write already exists, it does not make sense to re-run this task.
// We throw the exception and let Executor throw ExceptionFailure to abort the job.
throw new TaskOutputFileAlreadyExistException(f)
case t: Throwable =>
throw QueryExecutionErrors.taskFailedWhileWritingRowsError(description.path, t)
}
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Execute the task to write rows out and commit the task.
dataWriter.writeWithIterator(iterator)
dataWriter.commit()
})(catchBlock = {
// If there is an error, abort the task
dataWriter.abort()
logError(s"Job $jobId aborted.")
}, finallyBlock = {
dataWriter.close()
})
}

/**
@@ -455,25 +438,3 @@
}
}
}

// An exception wrapper to indicate that the error was thrown when executing the query, not writing
// the data
private class QueryFailureDuringWrite(val queryFailure: Throwable) extends Throwable

// An iterator wrapper to rethrow any error from the given iterator with `QueryFailureDuringWrite`.
private class QueryFailureCapturedIterator(data: Iterator[InternalRow])
extends NextIterator[InternalRow] {

override protected def getNext(): InternalRow = try {
if (data.hasNext) {
data.next()
} else {
finished = true
null
}
} catch {
case t: Throwable => throw new QueryFailureDuringWrite(t)
}

override protected def close(): Unit = {}
}
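With enrichment handled inside the data writer, the task loop here no longer needs the QueryFailureCapturedIterator / QueryFailureDuringWrite indirection to tell query-side failures apart from write-side ones: the body reduces to write-then-commit inside Utils.tryWithSafeFinallyAndFailureCallbacks, aborting the task on failure and closing the writer in the finally block. A toy sketch of that write/commit/abort/close discipline, assuming a simplified writer interface (ToyDataWriter and WriteTask are made-up names):

```scala
// Sketch of the write -> commit, abort-on-failure, always-close sequence that
// the simplified block above follows; this is not Spark's DataWriter API.
trait ToyDataWriter {
  def write(row: String): Unit
  def commit(): Unit
  def abort(): Unit
  def close(): Unit
}

object WriteTask {
  def run(rows: Iterator[String], writer: ToyDataWriter): Unit = {
    try {
      rows.foreach(writer.write) // execute the task to write rows out
      writer.commit()            // ... and commit the task
    } catch {
      case t: Throwable =>
        // If there is an error, abort the task without masking the original failure.
        try {
          writer.abort()
        } catch {
          case suppressed: Throwable => t.addSuppressed(suppressed)
        }
        throw t
    } finally {
      writer.close() // always release resources
    }
  }
}
```

In the real code the catch/finally handling is supplied by Utils.tryWithSafeFinallyAndFailureCallbacks rather than hand-written try blocks.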
CSVSuite.scala
@@ -1250,10 +1250,10 @@ abstract class CSVSuite
val ex = intercept[SparkException] {
exp.write.format("csv").option("timestampNTZFormat", pattern).save(path.getAbsolutePath)
}
checkError(
checkErrorMatchPVals(
exception = ex,
errorClass = "TASK_WRITE_FAILED",
parameters = Map("path" -> actualPath))
parameters = Map("path" -> s"$actualPath.*"))
val msg = ex.getCause.getMessage
assert(
msg.contains("Unsupported field: OffsetSeconds") ||
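The test-side change (repeated in JsonSuite and XmlSuite below) swaps checkError for checkErrorMatchPVals because the TASK_WRITE_FAILED path parameter now points at the concrete file being written inside the output directory, so the expected value becomes the regex s"$actualPath.*" instead of the directory string itself. A standalone sketch of the difference in matching semantics (ParamMatching is an illustrative name; the real helpers are Spark test utilities):

```scala
// Toy illustration of exact vs. regex parameter matching; this only mimics the
// semantics of checkError / checkErrorMatchPVals, it is not Spark's test code.
object ParamMatching {
  def matchesExactly(actual: Map[String, String], expected: Map[String, String]): Boolean =
    actual == expected

  // checkErrorMatchPVals-style: expected values are treated as regular expressions.
  def matchesPatterns(actual: Map[String, String], expected: Map[String, String]): Boolean =
    expected.forall { case (key, pattern) => actual.get(key).exists(_.matches(pattern)) }

  def main(args: Array[String]): Unit = {
    val dir = "file:/tmp/csv-out"
    // The enriched error reports a concrete part file, not just the directory.
    val actual = Map("path" -> s"$dir/part-00000-1234.csv")
    println(matchesExactly(actual, Map("path" -> dir)))        // false
    println(matchesPatterns(actual, Map("path" -> s"$dir.*"))) // true
  }
}
```

The XmlSuite hunk additionally strips a trailing "/" from the URL form of the directory before appending ".*", which is why the commons-lang3 StringUtils.removeEnd call and its import can be dropped.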
JsonSuite.scala
@@ -3043,10 +3043,10 @@ abstract class JsonSuite
val err = intercept[SparkException] {
exp.write.option("timestampNTZFormat", pattern).json(path.getAbsolutePath)
}
checkError(
checkErrorMatchPVals(
exception = err,
errorClass = "TASK_WRITE_FAILED",
parameters = Map("path" -> actualPath))
parameters = Map("path" -> s"$actualPath.*"))

val msg = err.getCause.getMessage
assert(
XmlSuite.scala
@@ -30,7 +30,6 @@ import scala.collection.mutable
import scala.io.Source
import scala.jdk.CollectionConverters._

import org.apache.commons.lang3.StringUtils
import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataInputStream
@@ -2451,10 +2450,10 @@ class XmlSuite
exp.write.option("timestampNTZFormat", pattern)
.option("rowTag", "ROW").xml(path.getAbsolutePath)
}
checkError(
checkErrorMatchPVals(
exception = err,
errorClass = "TASK_WRITE_FAILED",
parameters = Map("path" -> actualPath))
parameters = Map("path" -> s"$actualPath.*"))
val msg = err.getCause.getMessage
assert(
msg.contains("Unsupported field: OffsetSeconds") ||
@@ -2948,11 +2947,11 @@ class XmlSuite
.mode(SaveMode.Overwrite)
.xml(path)
}
val actualPath = Path.of(dir.getAbsolutePath).toUri.toURL.toString
checkError(
val actualPath = Path.of(dir.getAbsolutePath).toUri.toURL.toString.stripSuffix("/")
checkErrorMatchPVals(
exception = e,
errorClass = "TASK_WRITE_FAILED",
parameters = Map("path" -> StringUtils.removeEnd(actualPath, "/")))
parameters = Map("path" -> s"$actualPath.*"))
assert(e.getCause.isInstanceOf[XMLStreamException])
assert(e.getCause.getMessage.contains(errorMsg))
}