Making the JSON reader respect the caseSensitive parameter
MaxGekk committed Apr 14, 2018
commit 863ace7afdcdc3b51aab1fa6320b16559112c02c
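
For reviewers, a minimal repro of the behavior this commit targets (a sketch, not part of the patch; the object name and local-mode session are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object CaseSensitiveJsonDemo extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("caseSensitive-json-demo")
    .config("spark.sql.caseSensitive", "false") // the default
    .getOrCreate()
  import spark.implicits._

  // Two records whose field names differ only by case.
  val ds = Seq("""{"Field": 1}""", """{"field": 2}""").toDS()

  // With spark.sql.caseSensitive = false, schema inference lowercases
  // field names, so both records land in a single column `field`.
  spark.read.json(ds).printSchema() // root |-- field: long (nullable = true)

  spark.stop()
}
```

With `spark.sql.caseSensitive = true`, the same input should instead infer two distinct columns, `Field` and `field`.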
@@ -524,6 +524,8 @@ case class JsonToStructs(
   // can generate incorrect files if values are missing in columns declared as non-nullable.
   val nullableSchema = if (forceNullableSchema) schema.asNullable else schema
 
+  val caseSensitive = SQLConf.get.getConf(SQLConf.CASE_SENSITIVE)
+
   override def nullable: Boolean = true
 
   // Used in `FunctionRegistry`
@@ -567,7 +569,8 @@ case class JsonToStructs(
   lazy val parser =
     new JacksonParser(
       rowSchema,
-      new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get))
+      new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get),
+      caseSensitive)
 
   override def dataType: DataType = nullableSchema
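
A usage sketch for the expression path changed above: under `spark.sql.caseSensitive = false`, `from_json` should now match a JSON key against a schema field whose name differs only in case (assumes the `spark` session and implicits from the first sketch):

```scala
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.{LongType, StructType}

val schema = new StructType().add("id", LongType)
val df = Seq("""{"ID": 1}""").toDF("json")

// JacksonParser lowercases the incoming key "ID" before the schema lookup,
// so it should match the lowercase field "id".
df.select(from_json($"json", schema).as("parsed")).show()
// expected under this patch: [1]
```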
@@ -37,7 +37,8 @@ import org.apache.spark.util.Utils
  */
 class JacksonParser(
     schema: StructType,
-    val options: JSONOptions) extends Logging {
+    val options: JSONOptions,
+    caseSensitive: Boolean) extends Logging {
 
   import JacksonUtils._
   import com.fasterxml.jackson.core.JsonToken._
@@ -281,6 +282,14 @@ class JacksonParser(
         s"Failed to parse a value for data type ${dataType.catalogString} (current token: $token).")
   }
 
+  private def getCurrentName(parser: JsonParser): String = {
+    if (caseSensitive) {
+      parser.getCurrentName
+    } else {
+      parser.getCurrentName.toLowerCase
+    }
+  }
+
   /**
    * Parse an object from the token stream into a new Row representing the schema.
    * Fields in the json that are not defined in the requested schema will be dropped.
@@ -291,7 +300,7 @@
       fieldConverters: Array[ValueConverter]): InternalRow = {
     val row = new GenericInternalRow(schema.length)
     while (nextUntil(parser, JsonToken.END_OBJECT)) {
-      schema.getFieldIndex(parser.getCurrentName) match {
+      schema.getFieldIndex(getCurrentName(parser)) match {
         case Some(index) =>
           row.update(index, fieldConverters(index).apply(parser))
@@ -312,7 +321,7 @@
     val keys = ArrayBuffer.empty[UTF8String]
     val values = ArrayBuffer.empty[Any]
     while (nextUntil(parser, JsonToken.END_OBJECT)) {
-      keys += UTF8String.fromString(parser.getCurrentName)
+      keys += UTF8String.fromString(getCurrentName(parser))
       values += fieldConverter.apply(parser)
     }
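
A self-contained toy sketch (plain Jackson, not Spark code) of what `getCurrentName` above does when `caseSensitive = false`:

```scala
import com.fasterxml.jackson.core.{JsonFactory, JsonToken}

object KeyNormalizationSketch extends App {
  val caseSensitive = false
  val parser = new JsonFactory().createParser("""{"UserName": "a", "userAge": 1}""")
  parser.nextToken() // START_OBJECT
  while (parser.nextToken() != JsonToken.END_OBJECT) {
    val raw = parser.getCurrentName // current token is FIELD_NAME
    val name = if (caseSensitive) raw else raw.toLowerCase
    parser.nextToken() // advance to the field's value
    println(s"$raw -> $name : ${parser.getText}")
  }
  parser.close()
}
```

Note that in the map-conversion hunk the normalization reaches the data itself: with `caseSensitive = false`, the keys of a `MapType` column come back lowercased, not merely matched case-insensitively.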
@@ -419,9 +419,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       extraOptions.toMap,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
 
     val schema = userSpecifiedSchema.getOrElse {
-      TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
+      TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions, caseSensitive)
     }
 
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
@@ -430,7 +431,7 @@
 
     val createParser = CreateJacksonParser.string _
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
-      val rawParser = new JacksonParser(actualSchema, parsedOptions)
+      val rawParser = new JacksonParser(actualSchema, parsedOptions, caseSensitive)
       val parser = new FailureSafeParser[String](
         input => rawParser.parse(input, createParser, UTF8String.fromString),
         parsedOptions.parseMode,
@@ -94,13 +94,16 @@ object TextInputJsonDataSource extends JsonDataSource {
       parsedOptions: JSONOptions): StructType = {
     val json: Dataset[String] = createBaseDataset(
       sparkSession, inputPaths, parsedOptions.lineSeparator)
-    inferFromDataset(json, parsedOptions)
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+
+    inferFromDataset(json, parsedOptions, caseSensitive)
   }
 
-  def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
+  def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions,
+      caseSensitive: Boolean): StructType = {
     val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
     val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
-    JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
+    JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String, caseSensitive)
   }
 
   private def createBaseDataset(
@@ -153,7 +156,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
       parsedOptions: JSONOptions): StructType = {
     val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
     val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
-    JsonInferSchema.infer(sampled, parsedOptions, createParser)
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+
+    JsonInferSchema.infer(sampled, parsedOptions, createParser, caseSensitive)
   }
 
   private def createBaseRdd(
@@ -126,9 +126,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
           "df.filter($\"_corrupt_record\".isNotNull).count()."
         )
     }
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
 
     (file: PartitionedFile) => {
-      val parser = new JacksonParser(actualSchema, parsedOptions)
+      val parser = new JacksonParser(actualSchema, parsedOptions, caseSensitive)
       JsonDataSource(parsedOptions).readFile(
         broadcastedHadoopConf.value.value,
         file,
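
The flag now reaches the file-based scan path as well, for both the line-delimited and multiLine readers; a quick sketch (hypothetical paths, same `spark` session as above):

```scala
// Both readers pick up the session's spark.sql.caseSensitive value
// when the scan is built.
spark.read.json("/tmp/events.json").printSchema()
spark.read.option("multiLine", "true").json("/tmp/events-multiline.json").printSchema()
```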
@@ -41,7 +41,8 @@ private[sql] object JsonInferSchema {
   def infer[T](
       json: RDD[T],
       configOptions: JSONOptions,
-      createParser: (JsonFactory, T) => JsonParser): StructType = {
+      createParser: (JsonFactory, T) => JsonParser,
+      caseSensitive: Boolean): StructType = {
     val parseMode = configOptions.parseMode
     val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
@@ -53,7 +54,7 @@
       try {
         Utils.tryWithResource(createParser(factory, row)) { parser =>
           parser.nextToken()
-          Some(inferField(parser, configOptions))
+          Some(inferField(parser, configOptions, caseSensitive))
         }
       } catch {
         case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
@@ -98,14 +99,15 @@ private[sql] object JsonInferSchema {
   /**
    * Infer the type of a json document from the parser's token stream
    */
-  private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
+  private def inferField(parser: JsonParser, configOptions: JSONOptions,
+      caseSensitive: Boolean): DataType = {
     import com.fasterxml.jackson.core.JsonToken._
     parser.getCurrentToken match {
       case null | VALUE_NULL => NullType
 
       case FIELD_NAME =>
         parser.nextToken()
-        inferField(parser, configOptions)
+        inferField(parser, configOptions, caseSensitive)
 
       case VALUE_STRING if parser.getTextLength < 1 =>
         // Zero length strings and nulls have special handling to deal
@@ -121,8 +123,8 @@
         val builder = Array.newBuilder[StructField]
         while (nextUntil(parser, END_OBJECT)) {
           builder += StructField(
-            parser.getCurrentName,
-            inferField(parser, configOptions),
+            if (caseSensitive) parser.getCurrentName else parser.getCurrentName.toLowerCase,
+            inferField(parser, configOptions, caseSensitive),
             nullable = true)
         }
         val fields: Array[StructField] = builder.result()
@@ -137,7 +139,7 @@
         var elementType: DataType = NullType
         while (nextUntil(parser, END_ARRAY)) {
           elementType = compatibleType(
-            elementType, inferField(parser, configOptions))
+            elementType, inferField(parser, configOptions, caseSensitive))
         }
 
         ArrayType(elementType)
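
One consequence of lowercasing during inference worth illustrating: when two keys that differ only by case carry conflicting value types, the merged field goes through compatibleType, which falls back to string for incompatible types (a sketch, same `spark` session and implicits as above):

```scala
val merged = spark.read.json(Seq("""{"Num": 1}""", """{"num": "two"}""").toDS())
merged.printSchema()
// expected under this patch: root |-- num: string (nullable = true)
```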