From 703bcd6fc2df484cbdeb6192371191bbc31813fa Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 3 Apr 2024 17:22:39 +0800
Subject: [PATCH] more accurate file path in TASK_WRITE_FAILED error

---
 .../datasources/FileFormatDataWriter.scala    | 27 ++++++--
 .../datasources/FileFormatWriter.scala        | 65 ++++---------------
 .../execution/datasources/csv/CSVSuite.scala  |  4 +-
 .../datasources/json/JsonSuite.scala          |  4 +-
 .../execution/datasources/xml/XmlSuite.scala  | 11 ++--
 5 files changed, 44 insertions(+), 67 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index b9e8475e4859..1dbb6ce26f69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -18,17 +18,20 @@ package org.apache.spark.sql.execution.datasources
 
 import scala.collection.mutable
 
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 
+import org.apache.spark.TaskOutputFileAlreadyExistException
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec}
 import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -76,6 +79,18 @@ abstract class FileFormatDataWriter(
     releaseCurrentWriter()
   }
 
+  private def enrichWriteError[T](path: => String)(f: => T): T = try {
+    f
+  } catch {
+    case e: FetchFailedException =>
+      throw e
+    case f: FileAlreadyExistsException if SQLConf.get.fastFailFileFormatOutput =>
+      // If any output file to write already exists, it does not make sense to re-run this task.
+      // We throw the exception and let Executor throw ExceptionFailure to abort the job.
+      throw new TaskOutputFileAlreadyExistException(f)
+    case t: Throwable => throw QueryExecutionErrors.taskFailedWhileWritingRowsError(path, t)
+  }
+
   /** Writes a record. */
   def write(record: InternalRow): Unit
 
@@ -83,7 +98,9 @@ abstract class FileFormatDataWriter(
     if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
       CustomMetrics.updateMetrics(currentMetricsValues.toImmutableArraySeq, customMetrics)
     }
-    write(record)
+    enrichWriteError(Option(currentWriter).map(_.path()).getOrElse(description.path)) {
+      write(record)
+    }
   }
 
   /** Write an iterator of records. */
@@ -102,7 +119,7 @@ abstract class FileFormatDataWriter(
    * to the driver and used to update the catalog. Other information will be sent back to the
    * driver too and used to e.g. update the metrics in UI.
    */
-  override def commit(): WriteTaskResult = {
+  final override def commit(): WriteTaskResult = enrichWriteError(description.path) {
     releaseResources()
     val (taskCommitMessage, taskCommitTime) = Utils.timeTakenMs {
       committer.commitTask(taskAttemptContext)
@@ -113,7 +130,7 @@ abstract class FileFormatDataWriter(
     WriteTaskResult(taskCommitMessage, summary)
   }
 
-  def abort(): Unit = {
+  final def abort(): Unit = enrichWriteError(description.path) {
     try {
       releaseResources()
     } finally {
@@ -121,7 +138,7 @@ abstract class FileFormatDataWriter(
     }
   }
 
-  override def close(): Unit = {}
+  final override def close(): Unit = {}
 }
 
 /** FileFormatWriteTask for empty partitions */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 1df63aa14b4b..3bfa3413f679 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.{Date, UUID}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -28,7 +28,6 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
-import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -37,11 +36,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.connector.write.WriterCommitMessage
-import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.{NextIterator, SerializableConfiguration, Utils}
+import org.apache.spark.util.{SerializableConfiguration, Utils}
 import org.apache.spark.util.ArrayImplicits._
 
 
@@ -400,31 +397,17 @@ object FileFormatWriter extends Logging {
         }
       }
 
-    try {
-      val queryFailureCapturedIterator = new QueryFailureCapturedIterator(iterator)
-      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
-        // Execute the task to write rows out and commit the task.
-        dataWriter.writeWithIterator(queryFailureCapturedIterator)
-        dataWriter.commit()
-      })(catchBlock = {
-        // If there is an error, abort the task
-        dataWriter.abort()
-        logError(s"Job $jobId aborted.")
-      }, finallyBlock = {
-        dataWriter.close()
-      })
-    } catch {
-      case e: QueryFailureDuringWrite =>
-        throw e.queryFailure
-      case e: FetchFailedException =>
-        throw e
-      case f: FileAlreadyExistsException if SQLConf.get.fastFailFileFormatOutput =>
-        // If any output file to write already exists, it does not make sense to re-run this task.
-        // We throw the exception and let Executor throw ExceptionFailure to abort the job.
-        throw new TaskOutputFileAlreadyExistException(f)
-      case t: Throwable =>
-        throw QueryExecutionErrors.taskFailedWhileWritingRowsError(description.path, t)
-    }
+    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
+      // Execute the task to write rows out and commit the task.
+      dataWriter.writeWithIterator(iterator)
+      dataWriter.commit()
+    })(catchBlock = {
+      // If there is an error, abort the task
+      dataWriter.abort()
+      logError(s"Job $jobId aborted.")
+    }, finallyBlock = {
+      dataWriter.close()
+    })
   }
 
   /**
@@ -455,25 +438,3 @@ object FileFormatWriter extends Logging {
     }
   }
 }
-
-// A exception wrapper to indicate that the error was thrown when executing the query, not writing
-// the data
-private class QueryFailureDuringWrite(val queryFailure: Throwable) extends Throwable
-
-// An iterator wrapper to rethrow any error from the given iterator with `QueryFailureDuringWrite`.
-private class QueryFailureCapturedIterator(data: Iterator[InternalRow])
-  extends NextIterator[InternalRow] {
-
-  override protected def getNext(): InternalRow = try {
-    if (data.hasNext) {
-      data.next()
-    } else {
-      finished = true
-      null
-    }
-  } catch {
-    case t: Throwable => throw new QueryFailureDuringWrite(t)
-  }
-
-  override protected def close(): Unit = {}
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 964d1ec85e15..22ea133ee19a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1250,10 +1250,10 @@ abstract class CSVSuite
         val ex = intercept[SparkException] {
           exp.write.format("csv").option("timestampNTZFormat", pattern).save(path.getAbsolutePath)
         }
-        checkError(
+        checkErrorMatchPVals(
           exception = ex,
           errorClass = "TASK_WRITE_FAILED",
-          parameters = Map("path" -> actualPath))
+          parameters = Map("path" -> s"$actualPath.*"))
         val msg = ex.getCause.getMessage
         assert(
           msg.contains("Unsupported field: OffsetSeconds") ||
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 9af7511ca913..5c96df98dd23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -3043,10 +3043,10 @@ abstract class JsonSuite
         val err = intercept[SparkException] {
           exp.write.option("timestampNTZFormat", pattern).json(path.getAbsolutePath)
         }
-        checkError(
+        checkErrorMatchPVals(
           exception = err,
           errorClass = "TASK_WRITE_FAILED",
-          parameters = Map("path" -> actualPath))
+          parameters = Map("path" -> s"$actualPath.*"))
 
         val msg = err.getCause.getMessage
         assert(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 9dedd5795370..ddb49657144d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -30,7 +30,6 @@ import scala.collection.mutable
 import scala.io.Source
 import scala.jdk.CollectionConverters._
 
-import org.apache.commons.lang3.StringUtils
 import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FSDataInputStream
@@ -2451,10 +2450,10 @@ class XmlSuite
           exp.write.option("timestampNTZFormat", pattern)
             .option("rowTag", "ROW").xml(path.getAbsolutePath)
         }
-        checkError(
+        checkErrorMatchPVals(
           exception = err,
           errorClass = "TASK_WRITE_FAILED",
-          parameters = Map("path" -> actualPath))
+          parameters = Map("path" -> s"$actualPath.*"))
         val msg = err.getCause.getMessage
         assert(
           msg.contains("Unsupported field: OffsetSeconds") ||
@@ -2948,11 +2947,11 @@ class XmlSuite
         .mode(SaveMode.Overwrite)
         .xml(path)
     }
-    val actualPath = Path.of(dir.getAbsolutePath).toUri.toURL.toString
-    checkError(
+    val actualPath = Path.of(dir.getAbsolutePath).toUri.toURL.toString.stripSuffix("/")
+    checkErrorMatchPVals(
       exception = e,
       errorClass = "TASK_WRITE_FAILED",
-      parameters = Map("path" -> StringUtils.removeEnd(actualPath, "/")))
+      parameters = Map("path" -> s"$actualPath.*"))
     assert(e.getCause.isInstanceOf[XMLStreamException])
     assert(e.getCause.getMessage.contains(errorMsg))
   }
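
Notes on the change: the heart of this patch is the `enrichWriteError` helper added to `FileFormatDataWriter`. Row writes, task commit, and task abort are each wrapped so that a failure is rethrown as a `TASK_WRITE_FAILED` error carrying the path of the file currently being written (`currentWriter.path()`), falling back to the job-level output path (`description.path`) only when no writer is open. Previously the wrapping happened once around the whole write task in `FileFormatWriter.executeTask`, which could only report the job-level path and needed the `QueryFailureCapturedIterator` wrapper to avoid blaming the write for errors raised by the upstream query; with the wrapping moved to the individual write call, errors thrown while pulling rows from the input iterator fall outside it naturally, so both helper classes can be deleted. The sketch below illustrates the same wrapping pattern outside Spark, with hypothetical stand-in exception types in place of `TaskOutputFileAlreadyExistException`, `FetchFailedException`, and `QueryExecutionErrors.taskFailedWhileWritingRowsError`:

```scala
import java.nio.file.FileAlreadyExistsException

// Hypothetical stand-ins for Spark's error types, for illustration only.
class TaskWriteFailedException(path: String, cause: Throwable)
  extends RuntimeException(s"Task failed while writing rows to $path.", cause)
class OutputFileAlreadyExistsException(cause: Throwable) extends RuntimeException(cause)

object EnrichWriteErrorSketch {
  // `path` is by-name, so the *current* file path is computed only if a
  // failure actually happens; `f` is the write/commit/abort action to wrap.
  def enrichWriteError[T](path: => String)(f: => T): T =
    try f
    catch {
      // Re-running the task cannot help if the output file already exists,
      // so surface this as a distinct, fail-fast error.
      case e: FileAlreadyExistsException => throw new OutputFileAlreadyExistsException(e)
      // Anything else is wrapped with the most specific path available.
      case t: Throwable => throw new TaskWriteFailedException(path, t)
    }

  def main(args: Array[String]): Unit = {
    val currentFile: Option[String] = Some("/tmp/out/part-00000.csv")
    val jobOutputDir = "/tmp/out"
    try {
      enrichWriteError(currentFile.getOrElse(jobOutputDir)) {
        sys.error("disk full") // simulate a failure mid-write
      }
    } catch {
      // Prints the per-file path, not just the output directory.
      case e: TaskWriteFailedException => println(e.getMessage)
    }
  }
}
```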
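On the test side, `checkError` (exact parameter equality) becomes `checkErrorMatchPVals` (parameter values treated as patterns), because the `path` parameter of `TASK_WRITE_FAILED` now points at a per-task file under the output directory whose name the test cannot predict, so the suites assert only the directory prefix via `s"$actualPath.*"`. The `XmlSuite` hunk also swaps `StringUtils.removeEnd(actualPath, "/")` for the equivalent built-in `stripSuffix("/")`, which is what lets the now-unused `org.apache.commons.lang3.StringUtils` import go. A minimal sketch of that prefix-pattern assertion, using plain `String.matches` in place of the test-framework helper (the part-file name below is made up):

```scala
object PathParamMatchSketch {
  def main(args: Array[String]): Unit = {
    // Directory-level path known to the test, normalized like the XmlSuite hunk does.
    val actualPath = "file:/tmp/spark-output/".stripSuffix("/")
    // Path now reported by TASK_WRITE_FAILED: a concrete part file under that directory.
    val reported = "file:/tmp/spark-output/part-00000-c000.csv"
    // Equivalent of checkErrorMatchPVals: the parameter must match a pattern,
    // not equal a literal string.
    assert(reported.matches(s"$actualPath.*"))
    assert(reported != actualPath) // an exact-match checkError would now fail
  }
}
```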