Skip to content
Closed
45 changes: 28 additions & 17 deletions python/pyspark/sql/readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def load(self, path=None, format=None, schema=None, **options):
def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None):
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
timeZone=None):
"""
Loads a JSON file (`JSON Lines text format or newline-delimited JSON
<http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per
Expand Down Expand Up @@ -204,11 +205,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param dateFormat: sets the string that indicates a date format. Custom date formats
follow the formats at ``java.text.SimpleDateFormat``. This
applies to date type. If None is set, it uses the
default value value, ``yyyy-MM-dd``.
default value, ``yyyy-MM-dd``.
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.

>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
Expand All @@ -225,7 +228,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat)
timestampFormat=timestampFormat, timeZone=timeZone)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
Expand Down Expand Up @@ -297,7 +300,7 @@ def text(self, paths):
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
negativeInf=None, dateFormat=None, timestampFormat=None, timeZone=None, maxColumns=None,
Copy link
Member

@HyukjinKwon HyukjinKwon Jan 31, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To my knowledge, this should be added at the end so that existing code passing these options as positional arguments does not break.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see, I'll move them to the end.

maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.

Expand Down Expand Up @@ -341,11 +344,13 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param dateFormat: sets the string that indicates a date format. Custom date formats
follow the formats at ``java.text.SimpleDateFormat``. This
applies to date type. If None is set, it uses the
default value value, ``yyyy-MM-dd``.
default value, ``yyyy-MM-dd``.
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.
:param maxColumns: defines a hard limit of how many columns a record can have. If None is
set, it uses the default value, ``20480``.
:param maxCharsPerColumn: defines the maximum number of characters allowed for any given
Expand All @@ -372,8 +377,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
dateFormat=dateFormat, timestampFormat=timestampFormat, timeZone=timeZone,
maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
if isinstance(path, basestring):
path = [path]
Expand Down Expand Up @@ -591,7 +596,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options)
self._jwrite.saveAsTable(name)

@since(1.4)
def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None):
def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None,
timeZone=None):
"""Saves the content of the :class:`DataFrame` in JSON format at the specified path.

:param path: the path in any Hadoop supported file system
Expand All @@ -607,17 +613,20 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm
:param dateFormat: sets the string that indicates a date format. Custom date formats
follow the formats at ``java.text.SimpleDateFormat``. This
applies to date type. If None is set, it uses the
default value value, ``yyyy-MM-dd``.
default value, ``yyyy-MM-dd``.
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to format timestamps.
If None is set, it uses the default value, session local timezone.

>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
self._set_opts(
compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat)
compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat,
timeZone=timeZone)
self._jwrite.json(path)

@since(1.4)
Expand Down Expand Up @@ -664,7 +673,7 @@ def text(self, path, compression=None):
@since(2.0)
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
timestampFormat=None):
timestampFormat=None, timeZone=None):
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.

:param path: the path in any Hadoop supported file system
Expand Down Expand Up @@ -699,18 +708,20 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
:param dateFormat: sets the string that indicates a date format. Custom date formats
follow the formats at ``java.text.SimpleDateFormat``. This
applies to date type. If None is set, it uses the
default value value, ``yyyy-MM-dd``.
default value, ``yyyy-MM-dd``.
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to format timestamps.
If None is set, it uses the default value, session local timezone.

>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll,
dateFormat=dateFormat, timestampFormat=timestampFormat)
dateFormat=dateFormat, timestampFormat=timestampFormat, timeZone=timeZone)
self._jwrite.csv(path)

@since(1.5)
Expand Down
24 changes: 14 additions & 10 deletions python/pyspark/sql/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None,
timestampFormat=None):
timestampFormat=None, timeZone=None):
"""
Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON
<http://jsonlines.org/>`_) and returns a :class:`DataFrame`.
Expand Down Expand Up @@ -476,11 +476,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param dateFormat: sets the string that indicates a date format. Custom date formats
follow the formats at ``java.text.SimpleDateFormat``. This
applies to date type. If None is set, it uses the
default value value, ``yyyy-MM-dd``.
default value, ``yyyy-MM-dd``.
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.

>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
Expand All @@ -494,7 +496,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat)
timestampFormat=timestampFormat, timeZone=timeZone)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
Expand Down Expand Up @@ -551,8 +553,8 @@ def text(self, path):
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None):
negativeInf=None, dateFormat=None, timestampFormat=None, timeZone=None,
maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.

This function will go through the input once to determine the input schema if
Expand Down Expand Up @@ -597,11 +599,13 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param dateFormat: sets the string that indicates a date format. Custom date formats
follow the formats at ``java.text.SimpleDateFormat``. This
applies to date type. If None is set, it uses the
default value value, ``yyyy-MM-dd``.
default value, ``yyyy-MM-dd``.
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.
:param maxColumns: defines a hard limit of how many columns a record can have. If None is
set, it uses the default value, ``20480``.
:param maxCharsPerColumn: defines the maximum number of characters allowed for any given
Expand All @@ -626,8 +630,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
dateFormat=dateFormat, timestampFormat=timestampFormat, timeZone=timeZone,
maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.util.ParseModes
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, ParseModes}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -482,19 +482,40 @@ case class JsonTuple(children: Seq[Expression])
/**
* Converts a JSON input string to a [[StructType]] with the specified schema.
*/
case class JsonToStruct(schema: StructType, options: Map[String, String], child: Expression)
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
case class JsonToStruct(
schema: StructType,
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

def this(schema: StructType, options: Map[String, String], child: Expression) =
this(schema, options, child, None)

@transient
lazy val optionsWithTimeZone = {
val caseInsensitiveOptions: CaseInsensitiveMap =
new CaseInsensitiveMap(options + ("mode" -> ParseModes.FAIL_FAST_MODE))
if (caseInsensitiveOptions.contains("timeZone")) {
caseInsensitiveOptions
} else {
new CaseInsensitiveMap(caseInsensitiveOptions + ("timeZone" -> timeZoneId.get))
}
}

@transient
lazy val parser =
new JacksonParser(
schema,
"invalid", // Not used since we force fail fast. Invalid rows will be set to `null`.
new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE)))
new JSONOptions(optionsWithTimeZone))

override def dataType: DataType = schema

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

override def nullSafeEval(json: Any): Any = {
try parser.parse(json.toString).head catch {
case _: SparkSQLJsonProcessingException => null
Expand All @@ -507,10 +528,25 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
/**
* Converts a [[StructType]] to a json output string.
*/
case class StructToJson(options: Map[String, String], child: Expression)
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
case class StructToJson(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

def this(options: Map[String, String], child: Expression) = this(options, child, None)

@transient
lazy val optionsWithTimeZone = {
val caseInsensitiveOptions: CaseInsensitiveMap = new CaseInsensitiveMap(options)
if (caseInsensitiveOptions.contains("timeZone")) {
caseInsensitiveOptions
} else {
new CaseInsensitiveMap(caseInsensitiveOptions + ("timeZone" -> timeZoneId.get))
}
}

@transient
lazy val writer = new CharArrayWriter()

Expand All @@ -519,7 +555,7 @@ case class StructToJson(options: Map[String, String], child: Expression)
new JacksonGenerator(
child.dataType.asInstanceOf[StructType],
writer,
new JSONOptions(options))
new JSONOptions(optionsWithTimeZone))

override def dataType: DataType = StringType

Expand All @@ -538,6 +574,9 @@ case class StructToJson(options: Map[String, String], child: Expression)
}
}

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

override def nullSafeEval(row: Any): Any = {
gen.write(row.asInstanceOf[InternalRow])
gen.flush()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.json

import java.util.Locale
import java.util.{Locale, TimeZone}

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.commons.lang3.time.FastDateFormat
Expand Down Expand Up @@ -58,13 +58,15 @@ private[sql] class JSONOptions(
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord")

val timeZone: TimeZone = TimeZone.getTimeZone(parameters("timeZone"))

// Uses `FastDateFormat`, which is a direct replacement for `SimpleDateFormat` and is thread-safe.
val dateFormat: FastDateFormat =
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we need the timezone here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Date parsing is a combination of the dateFormat and DateTimeUtils.millisToDays() (see JacksonParser.scala#L251 or UnivocityParser.scala#L137).

As long as the dateFormat and DateTimeUtils.millisToDays() use the same timezone, the days are calculated correctly.
Here the dateFormat parses with the default timezone, and DateTimeUtils.millisToDays() also uses the default timezone to calculate the days.


val timestampFormat: FastDateFormat =
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US)
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US)

// Parse mode flags
if (!ParseModes.isValidMode(parseMode)) {
Expand Down
Loading