Modify JSONOptions/CSVOptions to take default timezone id argument.
ueshin committed Feb 1, 2017
commit d5ab37c7e1bdcd79586b802f3450bfbc7a9a8f36
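In effect, JSONOptions and CSVOptions now take the session-local timezone as an explicit second constructor argument, and a user-supplied timeZone option still overrides it. Below is a minimal sketch of that lookup behavior, using a hypothetical stand-in class rather than the real Spark types (which also wrap the map in CaseInsensitiveMap):

import java.util.TimeZone

// Stand-in for the new JSONOptions/CSVOptions constructor shape: the second
// argument is a caller-provided default, consulted only when the options map
// has no explicit "timeZone" key.
class OptionsSketch(parameters: Map[String, String], defaultTimeZoneId: String) {
  val timeZone: TimeZone =
    TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId))
}

object OptionsSketch extends App {
  println(new OptionsSketch(Map.empty, "GMT").timeZone.getID)                  // GMT (default applies)
  println(new OptionsSketch(Map("timeZone" -> "PST"), "GMT").timeZone.getID)   // PST (option wins)
}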
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, ParseModes}
+import org.apache.spark.sql.catalyst.util.ParseModes
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -493,23 +493,12 @@ case class JsonToStruct(
def this(schema: StructType, options: Map[String, String], child: Expression) =
this(schema, options, child, None)

-@transient
-lazy val optionsWithTimeZone = {
-val caseInsensitiveOptions: CaseInsensitiveMap =
-new CaseInsensitiveMap(options + ("mode" -> ParseModes.FAIL_FAST_MODE))
-if (caseInsensitiveOptions.contains("timeZone")) {
-caseInsensitiveOptions
-} else {
-new CaseInsensitiveMap(caseInsensitiveOptions + ("timeZone" -> timeZoneId.get))
-}
-}

@transient
lazy val parser =
new JacksonParser(
schema,
"invalid", // Not used since we force fail fast. Invalid rows will be set to `null`.
-new JSONOptions(optionsWithTimeZone))
+new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))

override def dataType: DataType = schema

@@ -537,16 +526,6 @@ case class StructToJson(

def this(options: Map[String, String], child: Expression) = this(options, child, None)

-@transient
-lazy val optionsWithTimeZone = {
-val caseInsensitiveOptions: CaseInsensitiveMap = new CaseInsensitiveMap(options)
-if (caseInsensitiveOptions.contains("timeZone")) {
-caseInsensitiveOptions
-} else {
-new CaseInsensitiveMap(caseInsensitiveOptions + ("timeZone" -> timeZoneId.get))
-}
-}

@transient
lazy val writer = new CharArrayWriter()

@@ -555,7 +534,7 @@
new JacksonGenerator(
child.dataType.asInstanceOf[StructType],
writer,
-new JSONOptions(optionsWithTimeZone))
+new JSONOptions(options, timeZoneId.get))

override def dataType: DataType = StringType

@@ -31,10 +31,11 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
private[sql] class JSONOptions(
-@transient private val parameters: CaseInsensitiveMap)
+@transient private val parameters: CaseInsensitiveMap, defaultTimeZoneId: String)
Contributor:
shouldn't the timeZoneId just be an option in parameters with key timeZoneId?

ueshin (Member, Author):
I was adding the timeZone option every time a JSONOptions (or CSVOptions) was created, but that meant repeating the same contains-key check logic in many places, as @HyukjinKwon mentioned.
So I changed the code to pass the default timezone id to JSONOptions and CSVOptions instead.

@HyukjinKwon (Member), Feb 8, 2017:
Ah, yes, it required introducing logic like the following before creating JSONOptions/CSVOptions:

val options = extraOptions.toMap
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val optionsWithTimeZone = if (caseInsensitiveOptions.contains("timeZone")) {
  caseInsensitiveOptions
} else {
  new CaseInsensitiveMap(
    options + ("timeZone" -> sparkSession.sessionState.conf.sessionLocalTimeZone))
}

val parsedOptions: JSONOptions = new JSONOptions(optionsWithTimeZone)

So I suggested this way because the default value of timeZone can vary. ParquetOptions.compressionCodecClassName also takes an extra argument for the same reason.

Another way I suggested was to make this an Option[TimeZone] to decouple it from the varying default (like JSONOptions.columnNameOfCorruptRecord), but timestampFormat in both options depends on timeZone, so it would have to become an Option too, which introduces more complexity. So the approach above seems better.

I am fine if we find a better, cleaner way.

Member:
To cut this short, I think we can follow the pattern of JSONOptions.columnNameOfCorruptRecord or ParquetOptions.compressionCodecClassName to deal with the varying default value.

It now resembles the latter.
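For contrast, here is a rough sketch of the Option[TimeZone] alternative mentioned above; this is illustrative only, not code from this PR, and the names are hypothetical. The drawback shows up in timestampFormat: since it is built from timeZone, it can no longer be a plain val, and every caller has to thread the resolved default through:

import java.util.TimeZone
import org.apache.commons.lang3.time.FastDateFormat

// Illustrative alternative: leave the timezone unresolved until use.
class OptionsAlternative(parameters: Map[String, String]) {
  // None means "resolve against the session default later".
  val timeZone: Option[TimeZone] =
    parameters.get("timeZone").map(TimeZone.getTimeZone)

  // The format depends on the zone, so it becomes a def taking the default;
  // this extra plumbing is what makes the adopted approach preferable.
  def timestampFormat(sessionDefault: TimeZone): FastDateFormat =
    FastDateFormat.getInstance(
      parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"),
      timeZone.getOrElse(sessionDefault))
}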

extends Logging with Serializable {

-def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
+def this(parameters: Map[String, String], defaultTimeZoneId: String) =
+this(new CaseInsensitiveMap(parameters), defaultTimeZoneId)

val samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
@@ -58,7 +59,7 @@ private[sql] class JSONOptions(
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord")

-val timeZone: TimeZone = TimeZone.getTimeZone(parameters("timeZone"))
+val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId))

// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
val dateFormat: FastDateFormat =
@@ -413,7 +413,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
JsonToStruct(
schema,
Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> tz.getID),
-Literal(jsonData2)),
+Literal(jsonData2),
+gmtId),
InternalRow.fromSeq(c.getTimeInMillis * 1000L :: Nil)
)
}
@@ -456,13 +457,15 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
StructToJson(
Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> gmtId.get),
-struct),
+struct,
+gmtId),
"""{"t":"2016-01-01T00:00:00"}"""
)
checkEvaluation(
StructToJson(
Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> "PST"),
-struct),
+struct,
+gmtId),
"""{"t":"2015-12-31T16:00:00"}"""
)
}
13 changes: 2 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -332,17 +332,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def json(jsonRDD: RDD[String]): DataFrame = {
-val optionsWithTimeZone = {
-val options = extraOptions.toMap
-val caseInsensitiveOptions = new CaseInsensitiveMap(options)
-if (caseInsensitiveOptions.contains("timeZone")) {
-caseInsensitiveOptions
-} else {
-new CaseInsensitiveMap(
-options + ("timeZone" -> sparkSession.sessionState.conf.sessionLocalTimeZone))
-}
-}
-val parsedOptions: JSONOptions = new JSONOptions(optionsWithTimeZone)
+val parsedOptions: JSONOptions =
+new JSONOptions(extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone)
val columnNameOfCorruptRecord =
parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2670,11 +2670,12 @@ class Dataset[T] private[sql](
*/
def toJSON: Dataset[String] = {
val rowSchema = this.schema
-val options = Map("timeZone" -> sparkSession.sessionState.conf.sessionLocalTimeZone)
+val sessionLocalTimeZone = sparkSession.sessionState.conf.sessionLocalTimeZone
val rdd: RDD[String] = queryExecution.toRdd.mapPartitions { iter =>
val writer = new CharArrayWriter()
// create the Generator without separator inserted between 2 records
-val gen = new JacksonGenerator(rowSchema, writer, new JSONOptions(options))
+val gen = new JacksonGenerator(rowSchema, writer,
+new JSONOptions(Map.empty[String, String], sessionLocalTimeZone))

new Iterator[String] {
override def hasNext: Boolean = iter.hasNext
@@ -58,8 +58,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
files: Seq[FileStatus]): Option[StructType] = {
require(files.nonEmpty, "Cannot infer schema from an empty set of files")

-val optionsWithTimeZone = getOptionsWithTimeZone(sparkSession, options)
-val csvOptions = new CSVOptions(optionsWithTimeZone)
+val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
val paths = files.map(_.getPath.toString)
val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
val firstLine: String = findFirstLine(csvOptions, lines)
@@ -128,8 +127,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
dataSchema: StructType): OutputWriterFactory = {
verifySchema(dataSchema)
val conf = job.getConfiguration
-val optionsWithTimeZone = getOptionsWithTimeZone(sparkSession, options)
-val csvOptions = new CSVOptions(optionsWithTimeZone)
+val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
csvOptions.compressionCodec.foreach { codec =>
CompressionCodecs.setCodecConfiguration(conf, codec)
}
@@ -156,8 +154,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-val optionsWithTimeZone = getOptionsWithTimeZone(sparkSession, options)
-val csvOptions = new CSVOptions(optionsWithTimeZone)
+val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
val commentPrefix = csvOptions.comment.toString

val broadcastedHadoopConf =
@@ -234,18 +231,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {

schema.foreach(field => verifyType(field.dataType))
}

-private def getOptionsWithTimeZone(
-sparkSession: SparkSession,
-options: Map[String, String]): CaseInsensitiveMap = {
-val caseInsensitiveOptions = new CaseInsensitiveMap(options)
-if (caseInsensitiveOptions.contains("timeZone")) {
-caseInsensitiveOptions
-} else {
-new CaseInsensitiveMap(
-options + ("timeZone" -> sparkSession.sessionState.conf.sessionLocalTimeZone))
-}
-}
}

private[csv] class CsvOutputWriter(
@@ -26,10 +26,12 @@ import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}

-private[csv] class CSVOptions(@transient private val parameters: CaseInsensitiveMap)
+private[csv] class CSVOptions(
+@transient private val parameters: CaseInsensitiveMap, defaultTimeZoneId: String)
extends Logging with Serializable {

-def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
+def this(parameters: Map[String, String], defaultTimeZoneId: String) =
+this(new CaseInsensitiveMap(parameters), defaultTimeZoneId)

private def getChar(paramName: String, default: Char): Char = {
val paramValue = parameters.get(paramName)
@@ -106,7 +108,7 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
name.map(CompressionCodecs.getCodecClassName)
}

-val timeZone: TimeZone = TimeZone.getTimeZone(parameters("timeZone"))
+val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId))

// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
val dateFormat: FastDateFormat =
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
private[csv] class UnivocityGenerator(
schema: StructType,
writer: Writer,
-options: CSVOptions = new CSVOptions(Map.empty[String, String])) {
+options: CSVOptions) {
private val writerSettings = options.asWriterSettings
writerSettings.setHeaders(schema.fieldNames: _*)
private val gen = new CsvWriter(writer, writerSettings)
@@ -30,7 +30,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs}
+import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -47,8 +47,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
if (files.isEmpty) {
None
} else {
-val optionsWithTimeZone = getOptionsWithTimeZone(sparkSession, options)
-val parsedOptions: JSONOptions = new JSONOptions(optionsWithTimeZone)
+val parsedOptions: JSONOptions =
+new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
val columnNameOfCorruptRecord =
parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
@@ -68,8 +68,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val conf = job.getConfiguration
-val optionsWithTimeZone = getOptionsWithTimeZone(sparkSession, options)
-val parsedOptions: JSONOptions = new JSONOptions(optionsWithTimeZone)
+val parsedOptions: JSONOptions =
+new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
parsedOptions.compressionCodec.foreach { codec =>
CompressionCodecs.setCodecConfiguration(conf, codec)
}
@@ -99,8 +99,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

-val optionsWithTimeZone = getOptionsWithTimeZone(sparkSession, options)
-val parsedOptions: JSONOptions = new JSONOptions(optionsWithTimeZone)
+val parsedOptions: JSONOptions =
+new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)

@@ -132,18 +132,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
classOf[Text]).map(_._2.toString) // get the text line
}

-private def getOptionsWithTimeZone(
-sparkSession: SparkSession,
-options: Map[String, String]): CaseInsensitiveMap = {
-val caseInsensitiveOptions = new CaseInsensitiveMap(options)
-if (caseInsensitiveOptions.contains("timeZone")) {
-caseInsensitiveOptions
-} else {
-new CaseInsensitiveMap(
-options + ("timeZone" -> sparkSession.sessionState.conf.sessionLocalTimeZone))
-}
-}

/** Constraints to be imposed on schema to be stored. */
private def checkConstraints(schema: StructType): Unit = {
if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
@@ -23,7 +23,7 @@ import org.apache.spark.sql.types._
class CSVInferSchemaSuite extends SparkFunSuite {

test("String fields types are inferred correctly from null types") {
-val options = new CSVOptions(Map("timeZone" -> "GMT"))
+val options = new CSVOptions(Map.empty[String, String], "GMT")
assert(CSVInferSchema.inferField(NullType, "", options) == NullType)
assert(CSVInferSchema.inferField(NullType, null, options) == NullType)
assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType)
@@ -41,7 +41,7 @@ class CSVInferSchemaSuite extends SparkFunSuite {
}

test("String fields types are inferred correctly from other types") {
-val options = new CSVOptions(Map("timeZone" -> "GMT"))
+val options = new CSVOptions(Map.empty[String, String], "GMT")
assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType)
assert(CSVInferSchema.inferField(LongType, "test", options) == StringType)
assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType)
@@ -60,21 +60,21 @@ class CSVInferSchemaSuite extends SparkFunSuite {
}

test("Timestamp field types are inferred correctly via custom data format") {
-var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm", "timeZone" -> "GMT"))
+var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), "GMT")
assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
-options = new CSVOptions(Map("timestampFormat" -> "yyyy", "timeZone" -> "GMT"))
+options = new CSVOptions(Map("timestampFormat" -> "yyyy"), "GMT")
assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType)
}

test("Timestamp field types are inferred correctly from other types") {
-val options = new CSVOptions(Map("timeZone" -> "GMT"))
+val options = new CSVOptions(Map.empty[String, String], "GMT")
assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType)
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType)
assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType)
}

test("Boolean fields types are inferred correctly from other types") {
-val options = new CSVOptions(Map("timeZone" -> "GMT"))
+val options = new CSVOptions(Map.empty[String, String], "GMT")
assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType)
assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType)
}
@@ -92,12 +92,12 @@ class CSVInferSchemaSuite extends SparkFunSuite {
}

test("Null fields are handled properly when a nullValue is specified") {
-var options = new CSVOptions(Map("nullValue" -> "null", "timeZone" -> "GMT"))
+var options = new CSVOptions(Map("nullValue" -> "null"), "GMT")
assert(CSVInferSchema.inferField(NullType, "null", options) == NullType)
assert(CSVInferSchema.inferField(StringType, "null", options) == StringType)
assert(CSVInferSchema.inferField(LongType, "null", options) == LongType)

-options = new CSVOptions(Map("nullValue" -> "\\N", "timeZone" -> "GMT"))
+options = new CSVOptions(Map("nullValue" -> "\\N"), "GMT")
assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType)
assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
@@ -111,12 +111,12 @@ class CSVInferSchemaSuite extends SparkFunSuite {
}

test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
-val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm", "timeZone" -> "GMT"))
+val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), "GMT")
assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
}

test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") {
-val options = new CSVOptions(Map("timeZone" -> "GMT"))
+val options = new CSVOptions(Map.empty[String, String], "GMT")

// 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) ==