From adfde77125eb31b262a2f010851beef2b872e1e8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 16 Mar 2017 21:52:59 +0800 Subject: [PATCH 1/6] unify bad record handling in CSV and JSON --- .../expressions/jsonExpressions.scala | 4 +- .../spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../sql/catalyst/json/JacksonParser.scala | 121 +-------- .../sql/catalyst/util/FailureSafeParser.scala | 250 ++++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 21 +- .../datasources/csv/CSVDataSource.scala | 17 +- .../datasources/csv/CSVFileFormat.scala | 8 +- .../datasources/csv/CSVOptions.scala | 2 +- .../datasources/csv/UnivocityParser.scala | 195 +++++--------- .../datasources/json/JsonDataSource.scala | 29 +- .../datasources/json/JsonFileFormat.scala | 12 +- .../execution/datasources/csv/CSVSuite.scala | 2 +- .../datasources/json/JsonSuite.scala | 8 +- 13 files changed, 389 insertions(+), 282 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 18b5f2f7ed2e..e572ed8a6912 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -555,7 +555,7 @@ case class JsonToStruct( CreateJacksonParser.utf8String, identity[UTF8String])) } catch { - case _: SparkSQLJsonProcessingException => null + case _: BadRecordException => null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f222ec602c9..355c26afa6f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -65,7 +65,7 @@ private[sql] class JSONOptions( val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 9b80c0fc87c9..a97ce1d3413d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -32,17 
+32,14 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) - /** * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( schema: StructType, - options: JSONOptions) extends Logging { + val options: JSONOptions) extends Logging { import JacksonUtils._ - import ParseModes._ import com.fasterxml.jackson.core.JsonToken._ // A `ValueConverter` is responsible for converting a value from `JsonParser` @@ -55,107 +52,7 @@ class JacksonParser( private val factory = new JsonFactory() options.setJacksonOptions(factory) - private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) - - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - @transient - private[this] var isWarningPrinted: Boolean = false - - @transient - private def printWarningForMalformedRecord(record: () => UTF8String): Unit = { - def sampleRecord: String = { - if (options.wholeFile) { - "" - } else { - s"Sample record: ${record()}\n" - } - } - - def footer: String = { - s"""Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}. - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin - } - - if (options.permissive) { - logWarning( - s"""Found at least one malformed record. The JSON reader will replace - |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. - |To find out which corrupted records have been replaced with null, please use the - |default inferred schema instead of providing a custom schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } else if (options.dropMalformed) { - logWarning( - s"""Found at least one malformed record. The JSON reader will drop - |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which - |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE - |mode and use the default inferred schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } - } - - @transient - private def printWarningIfWholeFile(): Unit = { - if (options.wholeFile && corruptFieldIndex.isDefined) { - logWarning( - s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result - |in very large allocations or OutOfMemoryExceptions being raised. - | - """.stripMargin) - } - } - - /** - * This function deals with the cases it fails to parse. This function will be called - * when exceptions are caught during converting. This functions also deals with `mode` option. 
- */ - private def failedRecord(record: () => UTF8String): Seq[InternalRow] = { - corruptFieldIndex match { - case _ if options.failFast => - if (options.wholeFile) { - throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode") - } else { - throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}") - } - - case _ if options.dropMalformed => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - Nil - - case None => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - emptyRow - - case Some(corruptIndex) => - if (!isWarningPrinted) { - printWarningIfWholeFile() - isWarningPrinted = true - } - val row = new GenericInternalRow(schema.length) - row.update(corruptIndex, record()) - Seq(row) - } - } + private val emptyRow = new GenericInternalRow(schema.length) /** * Create a converter which converts the JSON documents held by the `JsonParser` @@ -239,7 +136,7 @@ class JacksonParser( lowerCaseValue.equals("-inf")) { value.toFloat } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") + throw new RuntimeException(s"Cannot parse $value as FloatType.") } } @@ -259,7 +156,7 @@ class JacksonParser( lowerCaseValue.equals("-inf")) { value.toDouble } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") + throw new RuntimeException(s"Cannot parse $value as DoubleType.") } } @@ -391,9 +288,9 @@ class JacksonParser( case token => // We cannot parse this token based on the given data type. So, we throw a - // SparkSQLJsonProcessingException and this exception will be caught by + // SparkSQLRuntimeException and this exception will be caught by // `parse` method. - throw new SparkSQLJsonProcessingException( + throw new RuntimeException( s"Failed to parse a value for data type $dataType (current token: $token).") } @@ -466,14 +363,14 @@ class JacksonParser( parser.nextToken() match { case null => Nil case _ => rootConverter.apply(parser) match { - case null => throw new SparkSQLJsonProcessingException("Root converter returned null") + case null => throw new RuntimeException("Root converter returned null") case rows => rows } } } } catch { - case _: JsonProcessingException | _: SparkSQLJsonProcessingException => - failedRecord(() => recordLiteral(record)) + case e @ (_: RuntimeException | _: JsonProcessingException) => + throw BadRecordException(() => recordLiteral(record), () => emptyRow, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala new file mode 100644 index 000000000000..be4708dd713d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{DataType, Decimal, StringType} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +class FailureSafeParser[IN]( + func: IN => Seq[InternalRow], + mode: String, + corruptFieldIndex: Option[Int]) { + + private val toResultRow: (InternalRow, () => UTF8String) => InternalRow = { + if (corruptFieldIndex.isDefined) { + val resultRow = new RowWithBadRecord(null, corruptFieldIndex.get, null) + (row, badRecord) => { + resultRow.row = row + resultRow.record = badRecord() + resultRow + } + } else { + (row, badRecord) => row + } + } + + def parse(input: IN): Seq[InternalRow] = { + try { + func(input).map(toResultRow(_, () => null)) + } catch { + case e: BadRecordException if ParseModes.isPermissiveMode(mode) => + Seq(toResultRow(e.partialResult(), e.record)) + case _: BadRecordException if ParseModes.isDropMalformedMode(mode) => + Nil + // If the parse mode is FAIL FAST, do not catch the exception. + } + } +} + +case class BadRecordException( + record: () => UTF8String, + partialResult: () => InternalRow, + cause: Throwable) extends Exception(cause) + +class RowWithBadRecord(var row: InternalRow, index: Int, var record: UTF8String) + extends InternalRow { + + override def numFields: Int = row.numFields + 1 + + override def setNullAt(ordinal: Int): Unit = { + if (ordinal < index) { + row.setNullAt(ordinal) + } else if (ordinal == index) { + record = null + } else { + row.setNullAt(ordinal - 1) + } + } + + override def update(i: Int, value: Any): Unit = { + throw new UnsupportedOperationException("update") + } + + override def copy(): InternalRow = new RowWithBadRecord(row.copy(), index, record) + + override def anyNull: Boolean = row.anyNull || record == null + + override def isNullAt(ordinal: Int): Boolean = { + if (ordinal < index) { + row.isNullAt(ordinal) + } else if (ordinal == index) { + record == null + } else { + row.isNullAt(ordinal - 1) + } + } + + private def fail() = { + throw new IllegalAccessError("This is a string field.") + } + + override def getBoolean(ordinal: Int): Boolean = { + if (ordinal < index) { + row.getBoolean(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getBoolean(ordinal - 1) + } + } + + override def getByte(ordinal: Int): Byte = { + if (ordinal < index) { + row.getByte(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getByte(ordinal - 1) + } + } + + override def getShort(ordinal: Int): Short = { + if (ordinal < index) { + row.getShort(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getShort(ordinal - 1) + } + } + + override def getInt(ordinal: Int): Int = { + if (ordinal < index) { + row.getInt(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getInt(ordinal - 1) + } + } + + override def getLong(ordinal: Int): Long = { + if (ordinal < index) { + row.getLong(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getLong(ordinal - 1) + } + } + + override def getFloat(ordinal: Int): Float = { + if (ordinal < index) 
{ + row.getFloat(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getFloat(ordinal - 1) + } + } + + override def getDouble(ordinal: Int): Double = { + if (ordinal < index) { + row.getDouble(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getDouble(ordinal - 1) + } + } + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = { + if (ordinal < index) { + row.getDecimal(ordinal, precision, scale) + } else if (ordinal == index) { + fail() + } else { + row.getDecimal(ordinal - 1, precision, scale) + } + } + + override def getUTF8String(ordinal: Int): UTF8String = { + if (ordinal < index) { + row.getUTF8String(ordinal) + } else if (ordinal == index) { + record + } else { + row.getUTF8String(ordinal - 1) + } + } + + override def getBinary(ordinal: Int): Array[Byte] = { + if (ordinal < index) { + row.getBinary(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getBinary(ordinal - 1) + } + } + + override def getInterval(ordinal: Int): CalendarInterval = { + if (ordinal < index) { + row.getInterval(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getInterval(ordinal - 1) + } + } + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + if (ordinal < index) { + row.getStruct(ordinal, numFields) + } else if (ordinal == index) { + fail() + } else { + row.getStruct(ordinal - 1, numFields) + } + } + + override def getArray(ordinal: Int): ArrayData = { + if (ordinal < index) { + row.getArray(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getArray(ordinal - 1) + } + } + + override def getMap(ordinal: Int): MapData = { + if (ordinal < index) { + row.getMap(ordinal) + } else if (ordinal == index) { + fail() + } else { + row.getMap(ordinal - 1) + } + } + + override def get(ordinal: Int, dataType: DataType): AnyRef = { + if (ordinal < index) { + row.get(ordinal, dataType) + } else if (ordinal == index) { + if (dataType == StringType) { + record + } else { + fail() + } + } else { + row.get(ordinal - 1, dataType) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 88fbfb4c92a0..3ba7101f8ae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -27,6 +27,7 @@ import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.csv._ @@ -382,11 +383,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val dataSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + val corruptFieldIndex = schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord) val createParser = CreateJacksonParser.string _ val parsed = jsonDataset.rdd.mapPartitions { iter => - val parser = new JacksonParser(schema, parsedOptions) - iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) + val rawParser = new JacksonParser(dataSchema, parsedOptions) + val parser = new 
FailureSafeParser[String]( + input => rawParser.parse(input, createParser, UTF8String.fromString), + parsedOptions.parseMode, + corruptFieldIndex) + iter.flatMap(parser.parse) } Dataset.ofRows( @@ -435,14 +442,20 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val dataSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + val corruptFieldIndex = schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) val parsed = linesWithoutHeader.mapPartitions { iter => - val parser = new UnivocityParser(schema, parsedOptions) - iter.flatMap(line => parser.parse(line)) + val rawParser = new UnivocityParser(dataSchema, parsedOptions) + val parser = new FailureSafeParser[String]( + input => Seq(rawParser.parse(input)), + parsedOptions.parseMode, + corruptFieldIndex) + iter.flatMap(parser.parse) } Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 35ff924f27ce..2576ab714911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -49,7 +49,7 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] + corruptFieldIndex: Option[Int]): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. 
@@ -115,17 +115,17 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] = { + corruptFieldIndex: Option[Int]): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) linesReader.map { line => - new String(line.getBytes, 0, line.getLength, parsedOptions.charset) + new String(line.getBytes, 0, line.getLength, parser.options.charset) } } - val shouldDropHeader = parsedOptions.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser) + val shouldDropHeader = parser.options.headerFlag && file.start == 0 + UnivocityParser.parseIterator(lines, shouldDropHeader, parser, corruptFieldIndex) } override def infer( @@ -192,11 +192,12 @@ object WholeFileCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] = { + corruptFieldIndex: Option[Int]): Iterator[InternalRow] = { UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), - parsedOptions.headerFlag, - parser) + parser.options.headerFlag, + parser, + corruptFieldIndex) } override def infer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 29c41455279e..38197d348168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -102,6 +102,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val corruptFieldIndex = requiredSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord) // Check a field requirement for corrupt records here to throw an exception in a driver side dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => val f = dataSchema(corruptFieldIndex) @@ -113,8 +114,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value - val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions) + val parser = new UnivocityParser( + StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + parsedOptions) + CSVDataSource(parsedOptions).readFile(conf, file, parser, corruptFieldIndex) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 2632e87971d6..f6c6b6f56cd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -82,7 +82,7 @@ class CSVOptions( val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - private val 
parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val charset = parameters.getOrElse("encoding", parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index e42ea3fa391f..7fea2ed68e63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -30,14 +30,14 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( schema: StructType, requiredSchema: StructType, - private val options: CSVOptions) extends Logging { + val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") @@ -46,39 +46,28 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) - private val tokenizer = new CsvParser(options.asParserSettings) - private var numMalformedRecords = 0 - private val row = new GenericInternalRow(requiredSchema.length) - // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field - // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method. - private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd + private val emptyRow = new GenericInternalRow(requiredSchema.length) - // This parser loads an `tokenIndexArr`-th position value in input tokens, - // then put the value in `row(rowIndexArr)`. + // Retrieve the raw record string. + private def getCurrentInput(): UTF8String = { + UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd) + } + + // This parser first picks some tokens from the input tokens, according to the required schema, + // then parse these tokens and put the values in a row, with the order specified by the required + // schema. // // For example, let's say there is CSV data as below: // // a,b,c // 1,2,A // - // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true` - // by user and the user selects "c", "b", "_unparsed" and "a" fields. 
In this case, we need - // to map those values below: - // - // required schema - ["c", "b", "_unparsed", "a"] - // CSV data schema - ["a", "b", "c"] - // required CSV data schema - ["c", "b", "a"] + // So the CSV data schema is: ["a", "b", "c"] + // And let's say the required schema is: ["c", "b"] // // with the input tokens, // @@ -86,45 +75,12 @@ class UnivocityParser( // // Each input token is placed in each output row's position by mapping these. In this case, // - // output row - ["A", 2, null, 1] - // - // In more details, - // - `valueConverters`, input tokens - CSV data schema - // `valueConverters` keeps the positions of input token indices (by its index) to each - // value's converter (by its value) in an order of CSV data schema. In this case, - // [string->int, string->int, string->string]. - // - // - `tokenIndexArr`, input tokens - required CSV data schema - // `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered - // fields given the required CSV data schema (by its value). In this case, [2, 1, 0]. - // - // - `rowIndexArr`, input tokens - required schema - // `rowIndexArr` keeps the positions of input token indices (by its index) to reordered - // field indices given the required schema (by its value). In this case, [0, 1, 3]. + // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = - dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means - // the fields that we should try to convert. - private val reorderedFields = if (options.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) - } else { - requiredSchema - } + schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray private val tokenIndexArr: Array[Int] = { - reorderedFields - .filter(_.name != options.columnNameOfCorruptRecord) - .map(f => dataSchema.indexOf(f)).toArray - } - - private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) { - val corrFieldIndex = corruptFieldIndex.get - reorderedFields.indices.filter(_ != corrFieldIndex).toArray - } else { - reorderedFields.indices.toArray + requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -205,7 +161,7 @@ class UnivocityParser( } case _: StringType => (d: String) => - nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) + nullSafeDatum(d, name, nullable, options)(UTF8String.fromString) case udt: UserDefinedType[_] => (datum: String) => makeConverter(name, udt.sqlType, nullable, options) @@ -233,81 +189,39 @@ class UnivocityParser( * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) - - private def convert(tokens: Array[String]): Option[InternalRow] = { - convertWithParseMode(tokens) { tokens => - var i: Int = 0 - while (i < tokenIndexArr.length) { - // It anyway needs to try to parse since it decides if this row is malformed - // or not after trying to cast in `DROPMALFORMED` mode even if the casted - // value is not stored in the row. 
- val from = tokenIndexArr(i) - val to = rowIndexArr(i) - val value = valueConverters(from).apply(tokens(from)) - if (i < requiredSchema.length) { - row(to) = value - } - i += 1 - } - row - } - } - - private def convertWithParseMode( - tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { - if (options.dropMalformed && dataSchema.length != tokens.length) { - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. Malformed records from now on will not be logged.") + def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + + private def convert(tokens: Array[String]): InternalRow = { + if (tokens.length != schema.length) { + // If the number of tokens doesn't match the schema, we should treat it as a malformed record. + // However, we still have chance to parse some of the tokens, by adding extra null tokens in + // the tail if the number is smaller, or by dropping extra tokens if the number is larger. + val checkedTokens = if (schema.length > tokens.length) { + tokens ++ new Array[String](schema.length - tokens.length) + } else { + tokens.take(schema.length) } - numMalformedRecords += 1 - None - } else if (options.failFast && dataSchema.length != tokens.length) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: " + - s"${tokens.mkString(options.delimiter.toString)}") - } else { - // If a length of parsed tokens is not equal to expected one, it makes the length the same - // with the expected. If the length is shorter, it adds extra tokens in the tail. - // If longer, it drops extra tokens. - // - // TODO: Revisit this; if a length of tokens does not match an expected length in the schema, - // we probably need to treat it as a malformed record. - // See an URL below for related discussions: - // https://github.com/apache/spark/pull/16928#discussion_r102657214 - val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) { - if (dataSchema.length > tokens.length) { - tokens ++ new Array[String](dataSchema.length - tokens.length) - } else { - tokens.take(dataSchema.length) + def getPartialResult(): InternalRow = { + try { + convert(checkedTokens) + } catch { + case _: BadRecordException => emptyRow } - } else { - tokens } - + throw BadRecordException( + () => getCurrentInput(), + getPartialResult, + new RuntimeException("Malformed CSV record")) + } else { try { - Some(convert(checkedTokens)) + for (i <- requiredSchema.indices) { + val from = tokenIndexArr(i) + row(i) = valueConverters(from).apply(tokens(from)) + } + row } catch { - case NonFatal(e) if options.permissive => - val row = new GenericInternalRow(requiredSchema.length) - corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) - Some(row) - case NonFatal(e) if options.dropMalformed => - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. 
Malformed records from now on will not be logged.") - } - numMalformedRecords += 1 - None + case NonFatal(e) => + throw BadRecordException(() => getCurrentInput(), () => emptyRow, e) } } } @@ -331,10 +245,15 @@ private[csv] object UnivocityParser { def parseStream( inputStream: InputStream, shouldDropHeader: Boolean, - parser: UnivocityParser): Iterator[InternalRow] = { + parser: UnivocityParser, + corruptFieldIndex: Option[Int]): Iterator[InternalRow] = { val tokenizer = parser.tokenizer + val safeParser = new FailureSafeParser[Array[String]]( + input => Seq(parser.convert(input)), + parser.options.parseMode, + corruptFieldIndex) convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => - parser.convert(tokens) + safeParser.parse(tokens) }.flatten } @@ -368,7 +287,8 @@ private[csv] object UnivocityParser { def parseIterator( lines: Iterator[String], shouldDropHeader: Boolean, - parser: UnivocityParser): Iterator[InternalRow] = { + parser: UnivocityParser, + corruptFieldIndex: Option[Int]): Iterator[InternalRow] = { val options = parser.options val linesWithoutHeader = if (shouldDropHeader) { @@ -381,6 +301,11 @@ private[csv] object UnivocityParser { val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) - filteredLines.flatMap(line => parser.parse(line)) + + val safeParser = new FailureSafeParser[String]( + input => Seq(parser.parse(input)), + parser.options.parseMode, + corruptFieldIndex) + filteredLines.flatMap(safeParser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 84f026620d90..55bb33adf46b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.io.InputStream + import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration @@ -31,6 +33,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile} import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -49,7 +52,8 @@ abstract class JsonDataSource extends Serializable { def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] + parser: JacksonParser, + corruptFieldIndex: Option[Int]): Iterator[InternalRow] final def inferSchema( sparkSession: SparkSession, @@ -127,10 +131,15 @@ object TextInputJsonDataSource extends JsonDataSource { override def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] = { + parser: JacksonParser, + corruptFieldIndex: Option[Int]): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - 
linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String)) + val safeParser = new FailureSafeParser[Text]( + input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + parser.options.parseMode, + corruptFieldIndex) + linesReader.flatMap(safeParser.parse) } private def textToUTF8String(value: Text): UTF8String = { @@ -180,7 +189,8 @@ object WholeFileJsonDataSource extends JsonDataSource { override def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] = { + parser: JacksonParser, + corruptFieldIndex: Option[Int]): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { Utils.tryWithResource { CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) @@ -189,9 +199,12 @@ object WholeFileJsonDataSource extends JsonDataSource { } } - parser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), - CreateJacksonParser.inputStream, - partitionedFileString).toIterator + val safeParser = new FailureSafeParser[InputStream]( + input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + parser.options.parseMode, + corruptFieldIndex) + + safeParser.parse( + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)).toIterator } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index a9dd91eba6f7..a99f85cc8c53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -102,9 +102,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val corruptFieldIndex = requiredSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord) + val actualSchema = + StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) // Check a field requirement for corrupt records here to throw an exception in a driver side - dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = dataSchema(corruptFieldIndex) + corruptFieldIndex.foreach { i => + val f = dataSchema(i) if (f.dataType != StringType || !f.nullable) { throw new AnalysisException( "The field for corrupt records must be string type and nullable") @@ -112,11 +115,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } (file: PartitionedFile) => { - val parser = new JacksonParser(requiredSchema, parsedOptions) + val parser = new JacksonParser(actualSchema, parsedOptions) JsonDataSource(parsedOptions).readFile( broadcastedHadoopConf.value.value, file, - parser) + parser, + corruptFieldIndex) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 95dfdf5b298e..598babfe0e7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -293,7 +293,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)).collect() } - 
assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed CSV record")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9b0efcbdaf5c..8c10cef56806 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(corruptRecords) .collect() } - assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) + assert(exceptionOne.getMessage.contains("BadRecordException")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1052,7 +1052,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(corruptRecords) .collect() } - assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode: {")) + assert(exceptionTwo.getMessage.contains("BadRecordException")) } test("Corrupt records: DROPMALFORMED mode") { @@ -1929,7 +1929,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exceptionOne.getMessage.contains("BadRecordException")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1939,7 +1939,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exceptionTwo.getMessage.contains("BadRecordException")) } } From 6326c9dd530a59dbe771acfa1a1860a8e8b13e08 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 17 Mar 2017 13:52:53 +0900 Subject: [PATCH 2/6] Fix from_json test in R to check NA not the exception --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index f7081cb1d4e5..dc3fe6f6c95d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1354,9 +1354,8 @@ test_that("column functions", { # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) schema2 <- structType(structField("date", "date")) - expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))), - error = function(e) { stop(e) }), - paste0(".*(java.lang.NumberFormatException: For input string:).*")) + s <- collect(select(df, from_json(df$col, schema2))) + expect_equal(s[[1]][[1]], NA) s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy"))) expect_is(s[[1]][[1]]$date, "Date") expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21") From b5aee0e0afc21f6006455d614891234f50835457 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 Mar 2017 18:32:34 +0800 Subject: [PATCH 3/6] improve --- .../sql/catalyst/json/JacksonParser.scala | 7 +- .../sql/catalyst/util/FailureSafeParser.scala | 228 ++---------------- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../datasources/csv/CSVDataSource.scala | 10 +- .../datasources/csv/CSVFileFormat.scala | 3 +- .../datasources/csv/UnivocityParser.scala | 20 +- .../datasources/json/JsonDataSource.scala | 14 +- 
.../datasources/json/JsonFileFormat.scala | 7 +- 8 files changed, 56 insertions(+), 241 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index a97ce1d3413d..fdb7d88d5bd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -52,8 +52,6 @@ class JacksonParser( private val factory = new JsonFactory() options.setJacksonOptions(factory) - private val emptyRow = new GenericInternalRow(schema.length) - /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. This is a wrapper for the method @@ -288,8 +286,7 @@ class JacksonParser( case token => // We cannot parse this token based on the given data type. So, we throw a - // SparkSQLRuntimeException and this exception will be caught by - // `parse` method. + // RuntimeException and this exception will be caught by `parse` method. throw new RuntimeException( s"Failed to parse a value for data type $dataType (current token: $token).") } @@ -370,7 +367,7 @@ class JacksonParser( } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => - throw BadRecordException(() => recordLiteral(record), () => emptyRow, e) + throw BadRecordException(() => recordLiteral(record), () => None, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index be4708dd713d..2cd320636221 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -18,35 +18,45 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, Decimal, StringType} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String class FailureSafeParser[IN]( func: IN => Seq[InternalRow], mode: String, - corruptFieldIndex: Option[Int]) { + schema: StructType, + columnNameOfCorruptRecord: String) { - private val toResultRow: (InternalRow, () => UTF8String) => InternalRow = { + private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) + private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) + private val resultRow = new GenericInternalRow(schema.length) + + private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { if (corruptFieldIndex.isDefined) { - val resultRow = new RowWithBadRecord(null, corruptFieldIndex.get, null) (row, badRecord) => { - resultRow.row = row - resultRow.record = badRecord() + for ((f, i) <- actualSchema.zipWithIndex) { + resultRow(schema.fieldIndex(f.name)) = row.map(_.get(i, f.dataType)).orNull + } + resultRow(corruptFieldIndex.get) = badRecord() resultRow } } else { - (row, badRecord) => row + (row, badRecord) => row.getOrElse { + for (i <- schema.indices) resultRow.setNullAt(i) + resultRow + } } } - def parse(input: IN): Seq[InternalRow] = { + def parse(input: IN): Iterator[InternalRow] = { try { - 
func(input).map(toResultRow(_, () => null)) + func(input).toIterator.map(row => toResultRow(Some(row), () => null)) } catch { case e: BadRecordException if ParseModes.isPermissiveMode(mode) => - Seq(toResultRow(e.partialResult(), e.record)) + Iterator(toResultRow(e.partialResult(), e.record)) case _: BadRecordException if ParseModes.isDropMalformedMode(mode) => - Nil + Iterator.empty // If the parse mode is FAIL FAST, do not catch the exception. } } @@ -54,197 +64,5 @@ class FailureSafeParser[IN]( case class BadRecordException( record: () => UTF8String, - partialResult: () => InternalRow, + partialResult: () => Option[InternalRow], cause: Throwable) extends Exception(cause) - -class RowWithBadRecord(var row: InternalRow, index: Int, var record: UTF8String) - extends InternalRow { - - override def numFields: Int = row.numFields + 1 - - override def setNullAt(ordinal: Int): Unit = { - if (ordinal < index) { - row.setNullAt(ordinal) - } else if (ordinal == index) { - record = null - } else { - row.setNullAt(ordinal - 1) - } - } - - override def update(i: Int, value: Any): Unit = { - throw new UnsupportedOperationException("update") - } - - override def copy(): InternalRow = new RowWithBadRecord(row.copy(), index, record) - - override def anyNull: Boolean = row.anyNull || record == null - - override def isNullAt(ordinal: Int): Boolean = { - if (ordinal < index) { - row.isNullAt(ordinal) - } else if (ordinal == index) { - record == null - } else { - row.isNullAt(ordinal - 1) - } - } - - private def fail() = { - throw new IllegalAccessError("This is a string field.") - } - - override def getBoolean(ordinal: Int): Boolean = { - if (ordinal < index) { - row.getBoolean(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getBoolean(ordinal - 1) - } - } - - override def getByte(ordinal: Int): Byte = { - if (ordinal < index) { - row.getByte(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getByte(ordinal - 1) - } - } - - override def getShort(ordinal: Int): Short = { - if (ordinal < index) { - row.getShort(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getShort(ordinal - 1) - } - } - - override def getInt(ordinal: Int): Int = { - if (ordinal < index) { - row.getInt(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getInt(ordinal - 1) - } - } - - override def getLong(ordinal: Int): Long = { - if (ordinal < index) { - row.getLong(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getLong(ordinal - 1) - } - } - - override def getFloat(ordinal: Int): Float = { - if (ordinal < index) { - row.getFloat(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getFloat(ordinal - 1) - } - } - - override def getDouble(ordinal: Int): Double = { - if (ordinal < index) { - row.getDouble(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getDouble(ordinal - 1) - } - } - - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = { - if (ordinal < index) { - row.getDecimal(ordinal, precision, scale) - } else if (ordinal == index) { - fail() - } else { - row.getDecimal(ordinal - 1, precision, scale) - } - } - - override def getUTF8String(ordinal: Int): UTF8String = { - if (ordinal < index) { - row.getUTF8String(ordinal) - } else if (ordinal == index) { - record - } else { - row.getUTF8String(ordinal - 1) - } - } - - override def getBinary(ordinal: Int): Array[Byte] = { - if (ordinal < index) { - row.getBinary(ordinal) - } else if (ordinal == index) { - fail() - } else { - 
row.getBinary(ordinal - 1) - } - } - - override def getInterval(ordinal: Int): CalendarInterval = { - if (ordinal < index) { - row.getInterval(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getInterval(ordinal - 1) - } - } - - override def getStruct(ordinal: Int, numFields: Int): InternalRow = { - if (ordinal < index) { - row.getStruct(ordinal, numFields) - } else if (ordinal == index) { - fail() - } else { - row.getStruct(ordinal - 1, numFields) - } - } - - override def getArray(ordinal: Int): ArrayData = { - if (ordinal < index) { - row.getArray(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getArray(ordinal - 1) - } - } - - override def getMap(ordinal: Int): MapData = { - if (ordinal < index) { - row.getMap(ordinal) - } else if (ordinal == index) { - fail() - } else { - row.getMap(ordinal - 1) - } - } - - override def get(ordinal: Int, dataType: DataType): AnyRef = { - if (ordinal < index) { - row.get(ordinal, dataType) - } else if (ordinal == index) { - if (dataType == StringType) { - record - } else { - fail() - } - } else { - row.get(ordinal - 1, dataType) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3ba7101f8ae7..148683d0857b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -384,7 +384,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val dataSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) - val corruptFieldIndex = schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord) val createParser = CreateJacksonParser.string _ val parsed = jsonDataset.rdd.mapPartitions { iter => @@ -392,7 +391,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val parser = new FailureSafeParser[String]( input => rawParser.parse(input, createParser, UTF8String.fromString), parsedOptions.parseMode, - corruptFieldIndex) + schema, + parsedOptions.columnNameOfCorruptRecord) iter.flatMap(parser.parse) } @@ -443,7 +443,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val dataSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) - val corruptFieldIndex = schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) @@ -454,7 +453,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val parser = new FailureSafeParser[String]( input => Seq(rawParser.parse(input)), parsedOptions.parseMode, - corruptFieldIndex) + schema, + parsedOptions.columnNameOfCorruptRecord) iter.flatMap(parser.parse) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 2576ab714911..63af18ec5b8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ 
-49,7 +49,7 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - corruptFieldIndex: Option[Int]): Iterator[InternalRow] + schema: StructType): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -115,7 +115,7 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - corruptFieldIndex: Option[Int]): Iterator[InternalRow] = { + schema: StructType): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) @@ -125,7 +125,7 @@ object TextInputCSVDataSource extends CSVDataSource { } val shouldDropHeader = parser.options.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser, corruptFieldIndex) + UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) } override def infer( @@ -192,12 +192,12 @@ object WholeFileCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - corruptFieldIndex: Option[Int]): Iterator[InternalRow] = { + schema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), parser.options.headerFlag, parser, - corruptFieldIndex) + schema) } override def infer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 38197d348168..eef43c7629c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -102,7 +102,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val corruptFieldIndex = requiredSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord) // Check a field requirement for corrupt records here to throw an exception in a driver side dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => val f = dataSchema(corruptFieldIndex) @@ -118,7 +117,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, corruptFieldIndex) + CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 7fea2ed68e63..5658a379dc1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -50,8 +50,6 @@ class UnivocityParser( private val row = new GenericInternalRow(requiredSchema.length) - private val emptyRow = new GenericInternalRow(requiredSchema.length) - // Retrieve the raw record string. 
   private def getCurrentInput(): UTF8String = {
     UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd)
@@ -201,11 +199,11 @@ class UnivocityParser(
       } else {
         tokens.take(schema.length)
       }
-      def getPartialResult(): InternalRow = {
+      def getPartialResult(): Option[InternalRow] = {
        try {
-          convert(checkedTokens)
+          Some(convert(checkedTokens))
        } catch {
-          case _: BadRecordException => emptyRow
+          case _: BadRecordException => None
        }
      }
      throw BadRecordException(
@@ -221,7 +219,7 @@ class UnivocityParser(
        row
      } catch {
        case NonFatal(e) =>
-          throw BadRecordException(() => getCurrentInput(), () => emptyRow, e)
+          throw BadRecordException(() => getCurrentInput(), () => None, e)
      }
    }
  }
@@ -246,12 +244,13 @@ private[csv] object UnivocityParser {
       inputStream: InputStream,
       shouldDropHeader: Boolean,
       parser: UnivocityParser,
-      corruptFieldIndex: Option[Int]): Iterator[InternalRow] = {
+      schema: StructType): Iterator[InternalRow] = {
     val tokenizer = parser.tokenizer
     val safeParser = new FailureSafeParser[Array[String]](
       input => Seq(parser.convert(input)),
       parser.options.parseMode,
-      corruptFieldIndex)
+      schema,
+      parser.options.columnNameOfCorruptRecord)
     convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
       safeParser.parse(tokens)
     }.flatten
@@ -288,7 +287,7 @@ private[csv] object UnivocityParser {
       lines: Iterator[String],
       shouldDropHeader: Boolean,
       parser: UnivocityParser,
-      corruptFieldIndex: Option[Int]): Iterator[InternalRow] = {
+      schema: StructType): Iterator[InternalRow] = {
     val options = parser.options
 
     val linesWithoutHeader = if (shouldDropHeader) {
@@ -305,7 +304,8 @@ private[csv] object UnivocityParser {
     val safeParser = new FailureSafeParser[String](
       input => Seq(parser.parse(input)),
       parser.options.parseMode,
-      corruptFieldIndex)
+      schema,
+      parser.options.columnNameOfCorruptRecord)
     filteredLines.flatMap(safeParser.parse)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 55bb33adf46b..51e952c12202 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -53,7 +53,7 @@ abstract class JsonDataSource extends Serializable {
       conf: Configuration,
       file: PartitionedFile,
       parser: JacksonParser,
-      corruptFieldIndex: Option[Int]): Iterator[InternalRow]
+      schema: StructType): Iterator[InternalRow]
 
   final def inferSchema(
       sparkSession: SparkSession,
@@ -132,13 +132,14 @@ object TextInputJsonDataSource extends JsonDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: JacksonParser,
-      corruptFieldIndex: Option[Int]): Iterator[InternalRow] = {
+      schema: StructType): Iterator[InternalRow] = {
     val linesReader = new HadoopFileLinesReader(file, conf)
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
     val safeParser = new FailureSafeParser[Text](
       input => parser.parse(input, CreateJacksonParser.text, textToUTF8String),
       parser.options.parseMode,
-      corruptFieldIndex)
+      schema,
+      parser.options.columnNameOfCorruptRecord)
     linesReader.flatMap(safeParser.parse)
   }
 
@@ -190,7 +191,7 @@ object WholeFileJsonDataSource extends JsonDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: JacksonParser,
-      corruptFieldIndex: Option[Int]): Iterator[InternalRow] = {
+      schema: StructType): Iterator[InternalRow] = {
     def partitionedFileString(ignored: Any): UTF8String = {
       Utils.tryWithResource {
         CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)
@@ -202,9 +203,10 @@ object WholeFileJsonDataSource extends JsonDataSource {
     val safeParser = new FailureSafeParser[InputStream](
       input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString),
       parser.options.parseMode,
-      corruptFieldIndex)
+      schema,
+      parser.options.columnNameOfCorruptRecord)
 
     safeParser.parse(
-      CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)).toIterator
+      CodecStreams.createInputStreamWithCloseResource(conf, file.filePath))
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index a99f85cc8c53..53d62d88b04c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -102,12 +102,11 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
-    val corruptFieldIndex = requiredSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord)
     val actualSchema =
       StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     // Check a field requirement for corrupt records here to throw an exception in a driver side
-    corruptFieldIndex.foreach { i =>
-      val f = dataSchema(i)
+    dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = dataSchema(corruptFieldIndex)
       if (f.dataType != StringType || !f.nullable) {
         throw new AnalysisException(
           "The field for corrupt records must be string type and nullable")
@@ -120,7 +119,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
         broadcastedHadoopConf.value.value,
         file,
         parser,
-        corruptFieldIndex)
+        requiredSchema)
     }
   }
 
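Note: the hunks above only re-route the failure handling through FailureSafeParser; the user-facing options are unchanged. As a rough illustration only (not part of the patch series), the three parse modes are still selected through the existing DataFrameReader options, with the corrupt-record column declared as a nullable string field. The object name, input path, schema and app name below are hypothetical:

// Illustrative sketch: exercising the unified parse modes from the public API.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object ParseModeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("parse-mode-sketch").getOrCreate()

    // The corrupt-record column must be a nullable StringType field, otherwise the
    // driver-side check in CSVFileFormat/JsonFileFormat throws an AnalysisException.
    val schema = new StructType()
      .add("a", IntegerType)
      .add("_corrupt_record", StringType)

    val df = spark.read
      .schema(schema)
      .option("mode", "PERMISSIVE") // or "DROPMALFORMED" / "FAILFAST"
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .json("/tmp/records.json") // hypothetical input path

    df.show()
    spark.stop()
  }
}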
From aa6736f7c17decbc20438091c6029293cbc8c1fa Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 20 Mar 2017 13:25:27 +0800
Subject: [PATCH 4/6] address comments

---
 .../sql/catalyst/util/FailureSafeParser.scala | 24 ++++++++++++++-----
 .../datasources/json/JsonSuite.scala          |  8 +++----
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
index 2cd320636221..4b787023dbc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
@@ -31,21 +31,26 @@ class FailureSafeParser[IN](
   private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
   private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
   private val resultRow = new GenericInternalRow(schema.length)
+  private val nullResult = new GenericInternalRow(schema.length)
 
+  // This function takes 2 parameters: an optional partial result, and the bad record. If the given
+  // schema doesn't contain a field for corrupted records, we just return the partial result or a
+  // row with all fields null. If the given schema contains a field for corrupted records, we will
+  // set the bad record to this field, and set other fields according to the partial result or null.
   private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = {
     if (corruptFieldIndex.isDefined) {
       (row, badRecord) => {
-        for ((f, i) <- actualSchema.zipWithIndex) {
+        var i = 0
+        while (i < actualSchema.length) {
+          val f = actualSchema(i)
           resultRow(schema.fieldIndex(f.name)) = row.map(_.get(i, f.dataType)).orNull
+          i += 1
         }
         resultRow(corruptFieldIndex.get) = badRecord()
         resultRow
       }
     } else {
-      (row, badRecord) => row.getOrElse {
-        for (i <- schema.indices) resultRow.setNullAt(i)
-        resultRow
-      }
+      (row, badRecord) => row.getOrElse(nullResult)
     }
   }
 
@@ -57,11 +62,18 @@ class FailureSafeParser[IN](
         Iterator(toResultRow(e.partialResult(), e.record))
       case _: BadRecordException if ParseModes.isDropMalformedMode(mode) =>
         Iterator.empty
-      // If the parse mode is FAIL FAST, do not catch the exception.
+      case e: BadRecordException => throw e.cause
     }
   }
 }
 
+/**
+ * Exception thrown when the underlying parser meets a bad record and can't parse it.
+ * @param record a function to return the record that caused the parser to fail
+ * @param partialResult a function that returns an optional row, which is the partial result of
+ *                      parsing this bad record.
+ * @param cause the actual exception explaining why the record is bad and can't be parsed.
+ */
 case class BadRecordException(
     record: () => UTF8String,
     partialResult: () => Option[InternalRow],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 8c10cef56806..56fcf773f7dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         .json(corruptRecords)
         .collect()
     }
-    assert(exceptionOne.getMessage.contains("BadRecordException"))
+    assert(exceptionOne.getMessage.contains("JsonParseException"))
 
     val exceptionTwo = intercept[SparkException] {
       spark.read
@@ -1052,7 +1052,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         .json(corruptRecords)
         .collect()
     }
-    assert(exceptionTwo.getMessage.contains("BadRecordException"))
+    assert(exceptionTwo.getMessage.contains("JsonParseException"))
   }
 
   test("Corrupt records: DROPMALFORMED mode") {
@@ -1929,7 +1929,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
          .json(path)
          .collect()
      }
-      assert(exceptionOne.getMessage.contains("BadRecordException"))
+      assert(exceptionOne.getMessage.contains("Failed to parse a value"))
 
      val exceptionTwo = intercept[SparkException] {
        spark.read
@@ -1939,7 +1939,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
          .json(path)
          .collect()
      }
-      assert(exceptionTwo.getMessage.contains("BadRecordException"))
+      assert(exceptionTwo.getMessage.contains("Failed to parse a value"))
    }
  }
 
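For readers following the toResultRow and parse changes above, here is a deliberately simplified, self-contained model of the mode dispatch. It is not the class from the patch: it uses plain Scala types instead of InternalRow and UTF8String, and the names SimplifiedParser, Mode and BadRecord are invented for the sketch.

// Simplified model of FailureSafeParser's mode handling; illustrative only.
object SimplifiedParser {
  sealed trait Mode
  case object Permissive extends Mode
  case object DropMalformed extends Mode
  case object FailFast extends Mode

  // Stand-in for BadRecordException: carries the raw record and the root cause.
  final case class BadRecord(record: String, cause: Throwable) extends Exception(cause)

  def parse(input: String, rawParser: String => Seq[Int], mode: Mode): Iterator[Either[String, Int]] = {
    try {
      rawParser(input).iterator.map(Right(_))
    } catch {
      case e: BadRecord =>
        mode match {
          case Permissive    => Iterator(Left(e.record)) // keep the raw record in a "corrupt" slot
          case DropMalformed => Iterator.empty           // silently drop the bad record
          case FailFast      => throw e.cause            // surface the underlying parser error
        }
    }
  }
}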
From 20ac52f249477d17b0b07ff81bddeef444d5b546 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 20 Mar 2017 18:29:42 +0800
Subject: [PATCH 5/6] address comments

---
 .../apache/spark/sql/catalyst/util/FailureSafeParser.scala | 6 +++---
 .../sql/execution/datasources/csv/UnivocityParser.scala    | 4 +++-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
index 4b787023dbc8..5b222f25a29e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 
 class FailureSafeParser[IN](
-    func: IN => Seq[InternalRow],
+    rawParser: IN => Seq[InternalRow],
     mode: String,
     schema: StructType,
     columnNameOfCorruptRecord: String) {
@@ -50,13 +50,13 @@ class FailureSafeParser[IN](
         resultRow
       }
     } else {
-      (row, badRecord) => row.getOrElse(nullResult)
+      (row, _) => row.getOrElse(nullResult)
     }
   }
 
   def parse(input: IN): Iterator[InternalRow] = {
     try {
-      func(input).toIterator.map(row => toResultRow(Some(row), () => null))
+      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
     } catch {
       case e: BadRecordException if ParseModes.isPermissiveMode(mode) =>
         Iterator(toResultRow(e.partialResult(), e.record))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 5658a379dc1a..fc0ab08ff14f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -212,9 +212,11 @@ class UnivocityParser(
         new RuntimeException("Malformed CSV record"))
     } else {
       try {
-        for (i <- requiredSchema.indices) {
+        var i = 0
+        while (i < requiredSchema.length) {
           val from = tokenIndexArr(i)
           row(i) = valueConverters(from).apply(tokens(from))
+          i += 1
         }
         row
       } catch {

From adf7d333bf97de41843c2f6d4cece9415253b81f Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 21 Mar 2017 09:31:22 +0800
Subject: [PATCH 6/6] minor comments

---
 .../spark/sql/catalyst/util/FailureSafeParser.scala  |  4 ++--
 .../scala/org/apache/spark/sql/DataFrameReader.scala | 10 ++++++----
 .../execution/datasources/csv/UnivocityParser.scala  |  6 +++---
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
index 5b222f25a29e..e8da10d65ecb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
@@ -42,8 +42,8 @@ class FailureSafeParser[IN](
       (row, badRecord) => {
         var i = 0
         while (i < actualSchema.length) {
-          val f = actualSchema(i)
-          resultRow(schema.fieldIndex(f.name)) = row.map(_.get(i, f.dataType)).orNull
+          val from = actualSchema(i)
+          resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull
           i += 1
         }
         resultRow(corruptFieldIndex.get) = badRecord()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 148683d0857b..767a636d7073 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -383,11 +383,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
-    val dataSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
+    val actualSchema =
+      StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
     val createParser = CreateJacksonParser.string _
 
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
-      val rawParser = new JacksonParser(dataSchema, parsedOptions)
+      val rawParser = new JacksonParser(actualSchema, parsedOptions)
       val parser = new FailureSafeParser[String](
         input => rawParser.parse(input, createParser, UTF8String.fromString),
         parsedOptions.parseMode,
@@ -442,14 +443,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
-    val dataSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
+    val actualSchema =
+      StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
     val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
       filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
     }.getOrElse(filteredLines.rdd)
 
     val parsed = linesWithoutHeader.mapPartitions { iter =>
-      val rawParser = new UnivocityParser(dataSchema, parsedOptions)
+      val rawParser = new UnivocityParser(actualSchema, parsedOptions)
       val parser = new FailureSafeParser[String](
         input => Seq(rawParser.parse(input)),
         parsedOptions.parseMode,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index fc0ab08ff14f..263f77e11c4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -51,7 +51,7 @@ class UnivocityParser(
   private val row = new GenericInternalRow(requiredSchema.length)
 
   // Retrieve the raw record string.
-  private def getCurrentInput(): UTF8String = {
+  private def getCurrentInput: UTF8String = {
     UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd)
   }
 
@@ -207,7 +207,7 @@ class UnivocityParser(
         }
       }
       throw BadRecordException(
-        () => getCurrentInput(),
+        () => getCurrentInput,
         getPartialResult,
         new RuntimeException("Malformed CSV record"))
     } else {
@@ -221,7 +221,7 @@ class UnivocityParser(
         row
       } catch {
         case NonFatal(e) =>
-          throw BadRecordException(() => getCurrentInput(), () => None, e)
+          throw BadRecordException(() => getCurrentInput, () => None, e)
       }
     }
   }
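Since the series makes CSV behave like JSON for bad records, the same options apply on the CSV path touched above. A small end-to-end sketch, assuming a local SparkSession; the object name, sample rows and column names are made up for illustration:

// Illustrative sketch: corrupt-record handling for CSV after the unification.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object CsvCorruptRecordSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("csv-corrupt-record").getOrCreate()
    import spark.implicits._

    // One well-formed line and one malformed line.
    val csvLines = Seq("1,abc", "oops").toDS()

    val schema = new StructType()
      .add("id", IntegerType)
      .add("name", StringType)
      .add("_corrupt_record", StringType)

    // PERMISSIVE should keep the malformed line, nulling the data columns and
    // preserving the raw text in _corrupt_record; DROPMALFORMED should filter it
    // out; FAILFAST should fail the read.
    val permissive = spark.read
      .schema(schema)
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .csv(csvLines)

    permissive.show(false)
    spark.stop()
  }
}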