4 changes: 4 additions & 0 deletions python/pyspark/sql/readwriter.py
@@ -392,6 +392,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param maxCharsPerColumn: defines the maximum number of characters allowed for any given
value being read. If None is set, it uses the default value,
``1000000``.
:param maxMalformedLogPerPartition: sets the maximum number of malformed rows Spark will
log for each partition. Malformed records beyond this
number will be ignored. If None is set, it
uses the default value, ``10``.
:param mode: allows a mode for dealing with corrupt records during parsing. If None is
             set, it uses the default value, ``PERMISSIVE``.

Review discussion:

Member: doesn't this maxMalformedLogPerPartition need to be added to L412, self._set_csv_opts?
Contributor Author: Actually this is right!
Member: oh, my bad...

DataFrameReader.scala
@@ -382,6 +382,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* a record can have.</li>
* <li>`maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed
* for any given value being read.</li>
* <li>`maxMalformedLogPerPartition` (default `10`): sets the maximum number of malformed rows
* Spark will log for each partition. Malformed records beyond this number will be ignored.</li>
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.</li>
* <ul>
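
As a usage sketch (not part of this patch), the new option pairs naturally with DROPMALFORMED
mode. The snippet below assumes an existing SparkSession named `spark` and a hypothetical
input path; the option names and defaults are the ones documented above.

// Drop rows that do not match the schema, but cap how many of them are logged
// per partition so a badly corrupted file does not flood the executor logs.
val df = spark.read
  .option("header", "true")
  .option("mode", "DROPMALFORMED")
  .option("maxMalformedLogPerPartition", "5")  // hypothetical value; the default is 10
  .csv("/path/to/data.csv")                    // hypothetical path

df.show()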
CSVFileFormat.scala
@@ -120,7 +120,14 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {

val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
tokenizedIterator.flatMap(parser(_).toSeq)
var numMalformedRecords = 0
tokenizedIterator.flatMap { recordTokens =>
val row = parser(recordTokens, numMalformedRecords)
if (row.isEmpty) {
numMalformedRecords += 1
}
row
}
}
}
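
The change above keeps a mutable per-partition counter in the reader and lets the parser
decide, based on that count, whether to keep logging. Below is a simplified, self-contained
sketch of the same capped-logging pattern in plain Scala with illustrative names (not Spark
internals; the patch itself splits the counting and the cap check between CSVFileFormat and
CSVRelation).

// Simplified illustration of the capped per-partition warning pattern used above.
// `parse` stands in for the CSV parser and returns None for a malformed record.
def processPartition(
    records: Iterator[String],
    parse: String => Option[String],
    maxMalformedLog: Int = 10): Iterator[String] = {
  var numMalformed = 0
  records.flatMap { record =>
    val row = parse(record)
    if (row.isEmpty) {
      if (numMalformed < maxMalformedLog) {
        println(s"Dropping malformed record: $record")  // stands in for logWarning
      }
      if (numMalformed == maxMalformedLog - 1) {
        println(s"More than $maxMalformedLog malformed records found; " +
          "further malformed records will not be logged.")
      }
      numMalformed += 1
    }
    row  // an Option flattens to zero or one output rows
  }
}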

CSVOptions.scala
@@ -113,6 +113,8 @@ class CSVOptions(@transient private val parameters: Map[String, Str

val escapeQuotes = getBool("escapeQuotes", true)

val maxMalformedLogPerPartition = getInt("maxMalformedLogPerPartition", 10)

val inputBufferSize = 128

val isCommentSet = this.comment != '\u0000'
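
getInt here is one of the small helpers CSVOptions uses to read an option with a default
(getBool above is another). The actual helper is outside this hunk; the sketch below is
only an assumption about its shape.

// Assumed sketch of the getInt helper (the real implementation is not shown in this
// hunk): look the key up in the options map, fall back to the default, and fail
// loudly on a non-numeric value.
private def getInt(paramName: String, default: Int): Int =
  parameters.get(paramName) match {
    case None => default
    case Some(value) =>
      try value.toInt
      catch {
        case _: NumberFormatException =>
          throw new RuntimeException(s"$paramName should be an integer. Found $value")
      }
  }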
CSVRelation.scala
@@ -50,10 +50,19 @@ object CSVRelation extends Logging {
}
}

/**
* Returns a function that parses a single CSV record (in the form of an array of strings in which
* each element represents a column) and turns it into either one resulting row or no row (if
* the record is malformed).
*
* The 2nd argument in the returned function represents the total number of malformed rows
* observed so far.
*/
// This is pretty convoluted and we should probably rewrite the entire CSV parsing soon.
def csvParser(
schema: StructType,
requiredColumns: Array[String],
params: CSVOptions): Array[String] => Option[InternalRow] = {
params: CSVOptions): (Array[String], Int) => Option[InternalRow] = {
val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
val safeRequiredFields = if (params.dropMalformed) {
@@ -72,9 +81,16 @@
val requiredSize = requiredFields.length
val row = new GenericMutableRow(requiredSize)

(tokens: Array[String]) => {
(tokens: Array[String], numMalformedRows) => {
if (params.dropMalformed && schemaFields.length != tokens.length) {
logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
if (numMalformedRows < params.maxMalformedLogPerPartition) {
logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
}
if (numMalformedRows == params.maxMalformedLogPerPartition - 1) {
logWarning(
s"More than ${params.maxMalformedLogPerPartition} malformed records have been " +
"found on this partition. Malformed records from now on will not be logged.")
}
None
} else if (params.failFast && schemaFields.length != tokens.length) {
throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
@@ -109,23 +125,21 @@
Some(row)
} catch {
case NonFatal(e) if params.dropMalformed =>
logWarning("Parse exception. " +
s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
if (numMalformedRows < params.maxMalformedLogPerPartition) {
logWarning("Parse exception. " +
s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
}
if (numMalformedRows == params.maxMalformedLogPerPartition - 1) {
logWarning(
s"More than ${params.maxMalformedLogPerPartition} malformed records have been " +
"found on this partition. Malformed records from now on will not be logged.")
}
None
}
}
}
}

def parseCsv(
tokenizedRDD: RDD[Array[String]],
schema: StructType,
requiredColumns: Array[String],
options: CSVOptions): RDD[InternalRow] = {
val parser = csvParser(schema, requiredColumns, options)
tokenizedRDD.flatMap(parser(_).toSeq)
}
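
The single-argument call in parseCsv above (parser(_).toSeq) no longer matches the new
(Array[String], Int) => Option[InternalRow] shape, so that helper cannot stay as-is. Purely
as a sketch (not part of the patch, assuming the same imports as CSVRelation.scala, and
using an illustrative name), an RDD-based caller would thread the per-partition count the
same way the CSVFileFormat change does:

// Sketch only: an RDD-based equivalent of the per-partition counting done in
// CSVFileFormat. The parser is built inside each partition and fed the running
// count of malformed records seen so far.
def parseCsvSketch(
    tokenizedRDD: RDD[Array[String]],
    schema: StructType,
    requiredColumns: Array[String],
    options: CSVOptions): RDD[InternalRow] = {
  tokenizedRDD.mapPartitions { iter =>
    val parser = CSVRelation.csvParser(schema, requiredColumns, options)
    var numMalformedRecords = 0
    iter.flatMap { tokens =>
      val row = parser(tokens, numMalformedRecords)
      if (row.isEmpty) {
        numMalformedRecords += 1
      }
      row
    }
  }
}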

// Skips the header line of each file if the `header` option is set to true.
def dropHeaderLine(
file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = {