python/pyspark/sql/readwriter.py (3 additions, 0 deletions)
@@ -351,6 +351,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
    :param maxCharsPerColumn: defines the maximum number of characters allowed for any given
                              value being read. If None is set, it uses the default value,
                              ``1000000``.
+   :param maxLogRecordsPerPartition: defines the maximum number of malformed records logged
+                                     per partition; further malformed records are counted but
+                                     not logged. If None is set, it uses the default value, ``1``.
    :param mode: allows a mode for dealing with corrupt records during parsing. If None is
                 set, it uses the default value, ``PERMISSIVE``.

DataFrameReader.scala
@@ -394,6 +394,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* a record can have.</li>
* <li>`maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed
* for any given value being read.</li>
+   * <li>`maxLogRecordsPerPartition` (default `1`): defines the maximum number of malformed
+   * records logged per partition; further malformed records are counted but not logged.</li>
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.</li>
* <ul>
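For orientation, a usage sketch of the new reader option. The session, path, and limit value here are illustrative, not part of this patch:

// Usage sketch (assumes a SparkSession named `spark`; the path is hypothetical).
// Under DROPMALFORMED, at most 10 malformed lines per partition are logged
// individually; the per-partition total is still logged once parsing finishes.
val df = spark.read
  .option("mode", "DROPMALFORMED")
  .option("maxLogRecordsPerPartition", "10")
  .csv("/tmp/example.csv")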
CSVFileFormat.scala
@@ -28,6 +28,8 @@ import org.apache.hadoop.mapreduce._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -118,9 +120,20 @@ class CSVFileFormat extends FileFormat with DataSourceRegister {

CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)

-      val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
-      val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
-      tokenizedIterator.flatMap(parser(_).toSeq)
+      val unsafeRowIterator = {
+        val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+        CSVRelation.parseCsvInIterator(tokenizedIterator, dataSchema, requiredSchema.fieldNames,
+          csvOptions)
+      }

+      // Appends partition values
+      val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes

[Contributor review comment on the line above: is this change related to the issue you are fixing?]

+      val joinedRow = new JoinedRow()
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+      unsafeRowIterator.map { dataRow =>
+        appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+      }
}
}
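The partition-append pattern above relies on JoinedRow, which presents two InternalRows as a single row without copying, while the generated unsafe projection flattens the pair into one UnsafeRow. A minimal sketch with made-up rows (not code from this patch; these are internal Spark APIs):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.unsafe.types.UTF8String

// Two data columns plus one partition column (hypothetical values).
val dataRow = InternalRow(1, UTF8String.fromString("a"))
val partitionRow = InternalRow(UTF8String.fromString("2016"))
val joined = new JoinedRow(dataRow, partitionRow)
assert(joined.numFields == 3) // fields 0-1 from dataRow, field 2 from partitionRow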

CSVOptions.scala
@@ -113,6 +113,8 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str

val escapeQuotes = getBool("escapeQuotes", true)

+  val maxLogRecordsPerPartition = getInt("maxLogRecordsPerPartition", 1)

val inputBufferSize = 128

val isCommentSet = this.comment != '\u0000'
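A sketch of how the option is resolved (CSVOptions is internal, so this is illustrative rather than public API; constructed directly here only for demonstration):

// getInt falls back to the default (1) when the key is absent.
val opts = new CSVOptions(Map("maxLogRecordsPerPartition" -> "5"))
assert(opts.maxLogRecordsPerPartition == 5)
assert(new CSVOptions(Map.empty[String, String]).maxLogRecordsPerPartition == 1)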
CSVRelation.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.types._
+import org.apache.spark.util.CompletionIterator

object CSVRelation extends Logging {

@@ -50,10 +51,11 @@ object CSVRelation extends Logging {
}
}

-  def csvParser(
+  private def csvParser(
      schema: StructType,
      requiredColumns: Array[String],
-      params: CSVOptions): Array[String] => Option[InternalRow] = {
+      params: CSVOptions,
+      malformedLinesInfo: MalformedLinesInfo): Array[String] => Option[InternalRow] = {
val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
val safeRequiredFields = if (params.dropMalformed) {
Expand All @@ -74,12 +76,16 @@ object CSVRelation extends Logging {

    (tokens: Array[String]) => {
      if (params.dropMalformed && schemaFields.length != tokens.length) {
-        logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+        val line = tokens.mkString(params.delimiter.toString)
+        malformedLinesInfo.add(line)
        None
      } else if (params.failFast && schemaFields.length != tokens.length) {
        throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
          s"${tokens.mkString(params.delimiter.toString)}")
      } else {
+        if (schemaFields.length != tokens.length) {
+          malformedLinesInfo.add(tokens.mkString(params.delimiter.toString))
+        }
        val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.length) {
          tokens ++ new Array[String](schemaFields.length - tokens.length)
        } else if (params.permissive && schemaFields.length < tokens.length) {
@@ -109,21 +115,44 @@ object CSVRelation extends Logging {
          Some(row)
        } catch {
          case NonFatal(e) if params.dropMalformed =>
-            logWarning("Parse exception. " +
-              s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+            val line = tokens.mkString(params.delimiter.toString)
+            malformedLinesInfo.add(line)
            None
        }
}
}
}

-  def parseCsv(
+  private[csv] def parseCsvInRdd(
      tokenizedRDD: RDD[Array[String]],
      schema: StructType,
      requiredColumns: Array[String],
      options: CSVOptions): RDD[InternalRow] = {
-    val parser = csvParser(schema, requiredColumns, options)
-    tokenizedRDD.flatMap(parser(_).toSeq)
+    val malformedLinesInfo = new MalformedLinesInfo(options.maxLogRecordsPerPartition)
+    val parser = csvParser(schema, requiredColumns, options, malformedLinesInfo)
+    tokenizedRDD.mapPartitions { iter =>
+      val rows = iter.flatMap(parser(_).toSeq)
+      CompletionIterator[InternalRow, Iterator[InternalRow]](rows, {
+        if (malformedLinesInfo.malformedLineNum > 0) {
+          logWarning(s"$malformedLinesInfo")
+        }
+      })
+    }
  }
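CompletionIterator runs its completion block once the wrapped iterator is exhausted, which is what defers the per-partition summary above until the whole partition has been parsed. A generic sketch of the pattern with toy values:

import org.apache.spark.util.CompletionIterator

val base = Iterator(1, 2, 3)
val wrapped = CompletionIterator[Int, Iterator[Int]](base, {
  println("partition consumed") // stands in for the logWarning above
})
wrapped.foreach(println) // prints 1, 2, 3, then "partition consumed"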

+  private[csv] def parseCsvInIterator(
+      tokenizedIterator: Iterator[Array[String]],
+      schema: StructType,
+      requiredColumns: Array[String],
+      options: CSVOptions): Iterator[InternalRow] = {
+    val malformedLinesInfo = new MalformedLinesInfo(options.maxLogRecordsPerPartition)
+    val parser = csvParser(schema, requiredColumns, options, malformedLinesInfo)
+    val rows = tokenizedIterator.flatMap(parser(_).toSeq)
+    CompletionIterator[InternalRow, Iterator[InternalRow]](rows, {
+      if (malformedLinesInfo.malformedLineNum > 0) {
+        logWarning(s"$malformedLinesInfo")
+      }
+    })
+  }
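A behavior sketch for parseCsvInIterator; it is callable only from within the csv package, since the helper is private[csv], and the schema and tokens below are made up:

import org.apache.spark.sql.types.{StringType, StructField, StructType}

val schema = StructType(Seq(StructField("a", StringType), StructField("b", StringType)))
val options = new CSVOptions(Map("mode" -> "DROPMALFORMED"))
val tokens = Iterator(Array("x", "y"), Array("x")) // second record is malformed
val rows = CSVRelation.parseCsvInIterator(tokens, schema, schema.fieldNames, options)
assert(rows.size == 1) // the malformed record is dropped; the summary logs on completion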

// Skips the header line of each file if the `header` option is set to true.
New file (defines MalformedLinesInfo):
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.csv

import org.apache.spark.internal.Logging

/**
 * Logs and counts malformed lines during CSV parsing. At most `maxStoreMalformed` lines
 * are logged individually; any further malformed lines only increment the counter.
 */
private[csv] class MalformedLinesInfo(maxStoreMalformed: Int) extends Serializable with Logging {

  var malformedLineNum = 0

  def add(line: String): Unit = {
    // Log the offending line only while under the limit; always count it.
    if (malformedLineNum < maxStoreMalformed) {
      logWarning(s"Parse exception. Dropping malformed line: $line")
    }
    malformedLineNum += 1
  }

  override def toString: String = s"# of total malformed lines: $malformedLineNum"
}
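A quick behavior sketch with the default limit of 1 (toy input): the first malformed line is logged, later ones are only counted:

val info = new MalformedLinesInfo(1)
info.add("a,b") // logged: "Parse exception. Dropping malformed line: a,b"
info.add("c")   // over the limit: counted but not logged
assert(info.malformedLineNum == 2)
assert(info.toString == "# of total malformed lines: 2")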
New file (defines CsvUtilsSuite):
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.csv

import org.apache.spark.SparkFunSuite

class CsvUtilsSuite extends SparkFunSuite {

test("count malformed lines") {
val malformedLinesInfo = new MalformedLinesInfo(3)
malformedLinesInfo.add("aaa, bbb, ccc")
malformedLinesInfo.add("ddd, eee")
malformedLinesInfo.add("fff, ggg, hhh, iii")
malformedLinesInfo.add("jjj")
assert(s"${malformedLinesInfo}" ===
s"""
|# of total malformed lines: 4
""".stripMargin.trim)
}
}