diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 9208a527d29c3..01902f1470a09 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -351,6 +351,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param maxCharsPerColumn: defines the maximum number of characters allowed for any given
value being read. If None is set, it uses the default value,
``1000000``.
+        :param maxLogRecordsPerPartition: defines the maximum number of malformed records whose
+                                          contents are logged per partition; malformed records
+                                          beyond this limit are counted but not logged.
+                                          If None is set, it uses the default value, ``1``.
:param mode: allows a mode for dealing with corrupt records during parsing. If None is
set, it uses the default value, ``PERMISSIVE``.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 88fa5cd21d58f..ec5e43805c78b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -394,6 +394,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* a record can have.
*
 * `maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed
* for any given value being read.
+ * `maxLogRecordsPerPartition` (default `1`): defines the maximum number of malformed records
+ * whose contents are logged per partition; further malformed records are counted but not logged.
* `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.
*
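
For illustration, here is how the new option might be exercised from the reader API. The option
name and `DROPMALFORMED` mode come from this patch and the existing CSV options; the path, schema,
and threshold below are made-up values:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object MaxLogRecordsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val schema = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", StringType)))

    // Drop malformed lines, but log the contents of at most 5 of them per
    // partition; a per-partition total is still logged once parsing finishes.
    val df = spark.read
      .schema(schema)
      .option("mode", "DROPMALFORMED")
      .option("maxLogRecordsPerPartition", "5")  // hypothetical threshold
      .csv("/tmp/people.csv")                    // hypothetical path

    df.show()
    spark.stop()
  }
}
```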
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 4d36b760568cc..d10696f70e5ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -28,6 +28,8 @@ import org.apache.hadoop.mapreduce._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -118,9 +120,20 @@ class CSVFileFormat extends FileFormat with DataSourceRegister {
CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
- val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
- val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
- tokenizedIterator.flatMap(parser(_).toSeq)
+ val unsafeRowIterator = {
+ val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+ CSVRelation.parseCsvInIterator(tokenizedIterator, dataSchema, requiredSchema.fieldNames,
+ csvOptions)
+ }
+
+ // Appends partition values
+ val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ val joinedRow = new JoinedRow()
+ val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+ unsafeRowIterator.map { dataRow =>
+ appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+ }
}
}
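
The appended-partition-values block follows the pattern other `FileFormat` readers use: reuse one
`JoinedRow` per partition and run a generated unsafe projection over the data row plus the
partition values. A minimal, self-contained sketch of that pattern (schema and values are
invented, and `UnsafeProjection.create` stands in for `GenerateUnsafeProjection.generate` for
brevity):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object JoinedRowSketch {
  def main(args: Array[String]): Unit = {
    // Data columns first, partition columns last, matching `fullOutput` above.
    val fullSchema = StructType(Seq(
      StructField("value", StringType),
      StructField("part", IntegerType)))

    val dataRow = InternalRow(UTF8String.fromString("a"))  // parsed from the file
    val partitionValues = InternalRow(7)                   // constant per file

    // One reusable JoinedRow and one projection per partition.
    val joinedRow = new JoinedRow()
    val appendPartitionColumns = UnsafeProjection.create(fullSchema)

    val result = appendPartitionColumns(joinedRow(dataRow, partitionValues))
    println(s"${result.getString(0)}, ${result.getInt(1)}")  // a, 7
  }
}
```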
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 9f4ce8358b045..082c2b03f309e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -113,6 +113,8 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
val escapeQuotes = getBool("escapeQuotes", true)
+ val maxLogRecordsPerPartition = getInt("maxLogRecordsPerPartition", 1)
+
val inputBufferSize = 128
val isCommentSet = this.comment != '\u0000'
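
`getInt` is an existing private helper in CSVOptions, not added by this patch; presumably it reads
the named option and falls back to the default, along the lines of this sketch (a hedged
reconstruction, not the actual helper):

```scala
object OptionParsing {
  // Assumed shape of the `getInt` helper: fall back to `default` when the key
  // is absent; fail fast when the value does not parse as an integer.
  def getInt(parameters: Map[String, String], paramName: String, default: Int): Int =
    parameters.get(paramName) match {
      case None => default
      case Some(value) =>
        try value.toInt
        catch {
          case _: NumberFormatException =>
            throw new RuntimeException(s"$paramName should be an integer. Found $value")
        }
    }

  def main(args: Array[String]): Unit = {
    val opts = Map("maxLogRecordsPerPartition" -> "5")
    println(getInt(opts, "maxLogRecordsPerPartition", 1))       // 5
    println(getInt(Map.empty, "maxLogRecordsPerPartition", 1))  // 1
  }
}
```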
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index d72c8b9ac2e7c..3099124a7dfa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.types._
+import org.apache.spark.util.CompletionIterator
object CSVRelation extends Logging {
@@ -50,10 +51,11 @@ object CSVRelation extends Logging {
}
}
-  def csvParser(
+  private def csvParser(
schema: StructType,
requiredColumns: Array[String],
- params: CSVOptions): Array[String] => Option[InternalRow] = {
+ params: CSVOptions,
+ malformedLinesInfo: MalformedLinesInfo): Array[String] => Option[InternalRow] = {
val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
val safeRequiredFields = if (params.dropMalformed) {
@@ -74,12 +76,16 @@ object CSVRelation extends Logging {
(tokens: Array[String]) => {
if (params.dropMalformed && schemaFields.length != tokens.length) {
- logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+ val line = tokens.mkString(params.delimiter.toString)
+ malformedLinesInfo.add(line)
None
} else if (params.failFast && schemaFields.length != tokens.length) {
throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
s"${tokens.mkString(params.delimiter.toString)}")
} else {
+ if (schemaFields.length != tokens.length) {
+ malformedLinesInfo.add(tokens.mkString(params.delimiter.toString))
+ }
val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.length) {
tokens ++ new Array[String](schemaFields.length - tokens.length)
} else if (params.permissive && schemaFields.length < tokens.length) {
@@ -109,21 +115,44 @@ object CSVRelation extends Logging {
Some(row)
} catch {
case NonFatal(e) if params.dropMalformed =>
- logWarning("Parse exception. " +
- s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+ val line = tokens.mkString(params.delimiter.toString)
+ malformedLinesInfo.add(line)
None
}
}
}
}
- def parseCsv(
+ private[csv] def parseCsvInRdd(
tokenizedRDD: RDD[Array[String]],
schema: StructType,
requiredColumns: Array[String],
options: CSVOptions): RDD[InternalRow] = {
- val parser = csvParser(schema, requiredColumns, options)
- tokenizedRDD.flatMap(parser(_).toSeq)
+ val malformedLinesInfo = new MalformedLinesInfo(options.maxLogRecordsPerPartition)
+ val parser = csvParser(schema, requiredColumns, options, malformedLinesInfo)
+ tokenizedRDD.mapPartitions { iter =>
+ val rows = iter.flatMap(parser(_).toSeq)
+ CompletionIterator[InternalRow, Iterator[InternalRow]](rows, {
+ if (malformedLinesInfo.malformedLineNum > 0) {
+ logWarning(s"$malformedLinesInfo")
+ }
+ })
+ }
+ }
+
+ private[csv] def parseCsvInIterator(
+ tokenizedIterator: Iterator[Array[String]],
+ schema: StructType,
+ requiredColumns: Array[String],
+ options: CSVOptions): Iterator[InternalRow] = {
+ val malformedLinesInfo = new MalformedLinesInfo(options.maxLogRecordsPerPartition)
+ val parser = csvParser(schema, requiredColumns, options, malformedLinesInfo)
+ val rows = tokenizedIterator.flatMap(parser(_).toSeq)
+ CompletionIterator[InternalRow, Iterator[InternalRow]](rows, {
+ if (malformedLinesInfo.malformedLineNum > 0) {
+ logWarning(s"$malformedLinesInfo")
+ }
+ })
}
// Skips the header line of each file if the `header` option is set to true.
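
Both parse paths defer the summary warning to a `CompletionIterator`, which runs a callback once
the wrapped iterator is exhausted, i.e. after the whole partition has been consumed.
`org.apache.spark.util.CompletionIterator` is `private[spark]`, so here is a minimal stand-in
illustrating the mechanism:

```scala
// Minimal stand-in for CompletionIterator: delegate to `sub`, and run
// `completion` exactly once, after the last element has been served.
class OnCompletion[A](sub: Iterator[A], completion: () => Unit) extends Iterator[A] {
  private var completed = false
  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) {
      completed = true
      completion()
    }
    more
  }
  override def next(): A = sub.next()
}

object CompletionDemo {
  def main(args: Array[String]): Unit = {
    var malformed = 0
    // Parse lazily, counting unparseable tokens, as csvParser does above.
    val parsed = Iterator("1", "x", "2", "y").flatMap { t =>
      try Some(t.toInt)
      catch { case _: NumberFormatException => malformed += 1; None }
    }
    val rows = new OnCompletion(parsed, () => println(s"# of total malformed lines: $malformed"))
    println(rows.toList)  // the summary prints first, then List(1, 2)
  }
}
```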
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvUtils.scala
new file mode 100644
index 0000000000000..09c9d428be7c8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvUtils.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.internal.Logging
+
+/**
+ * Counts malformed lines during CSV parsing, logging the contents of the first
+ * `maxStoreMalformed` of them; lines beyond that limit are counted but not logged.
+ */
+private[csv] class MalformedLinesInfo(maxStoreMalformed: Int) extends Serializable with Logging {
+
+  // Total number of malformed lines seen so far, including the ones not logged.
+  var malformedLineNum = 0
+
+  def add(line: String): Unit = {
+    if (malformedLineNum < maxStoreMalformed) {
+      logWarning(s"Parse exception. Dropping malformed line: $line")
+    }
+    malformedLineNum += 1
+  }
+
+  override def toString: String = s"# of total malformed lines: $malformedLineNum"
+}
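
To make the cutoff concrete: with a limit of 2, only the first two `add` calls emit the per-line
warning, while all four are counted. A sketch, not part of the patch (it would have to live in
this package, since the class is `private[csv]`):

```scala
package org.apache.spark.sql.execution.datasources.csv

object MalformedLinesInfoDemo {
  def main(args: Array[String]): Unit = {
    val info = new MalformedLinesInfo(2)
    // The first two calls log "Parse exception. Dropping malformed line: ...";
    // the remaining calls only increment the counter.
    Seq("aaa, bbb, ccc", "ddd, eee", "fff, ggg, hhh, iii", "jjj").foreach(info.add)
    println(info)  // # of total malformed lines: 4
  }
}
```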
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CsvUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CsvUtilsSuite.scala
new file mode 100644
index 0000000000000..5410230bc0b33
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CsvUtilsSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.SparkFunSuite
+
+class CsvUtilsSuite extends SparkFunSuite {
+
+ test("count malformed lines") {
+ val malformedLinesInfo = new MalformedLinesInfo(3)
+ malformedLinesInfo.add("aaa, bbb, ccc")
+ malformedLinesInfo.add("ddd, eee")
+ malformedLinesInfo.add("fff, ggg, hhh, iii")
+ malformedLinesInfo.add("jjj")
+    assert(malformedLinesInfo.toString === "# of total malformed lines: 4")
+ }
+}