Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions python/pyspark/sql/readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
def _set_csv_opts(self, schema, sep, encoding, quote, escape,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function could be:

def _set_csv_opts(self, schema, **options):
     if schema is not None:
          self.schema(schema)
     for k in options:
          if options[k] is not None:
               self.option(k, options[k])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. There are a bunch of things I want to do to the readwrite.py (mainly break it apart). I will do it there and merge this to unblock the rc.

comment, header, inferSchema, ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
dateFormat, maxColumns, maxCharsPerColumn, mode):
dateFormat, maxColumns, maxCharsPerColumn, maxMalformedLogPerPartition, mode):
"""
Set options based on the CSV optional parameters
"""
Expand Down Expand Up @@ -115,6 +115,8 @@ def _set_csv_opts(self, schema, sep, encoding, quote, escape,
self.option("maxColumns", maxColumns)
if maxCharsPerColumn is not None:
self.option("maxCharsPerColumn", maxCharsPerColumn)
if maxMalformedLogPerPartition is not None:
self.option("maxMalformedLogPerPartition", maxMalformedLogPerPartition)
if mode is not None:
self.option("mode", mode)

Expand Down Expand Up @@ -268,10 +270,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
[('age', 'bigint'), ('name', 'string')]

"""
self._set_json_opts(schema, primitivesAsString, prefersDecimal,
allowComments, allowUnquotedFieldNames, allowSingleQuotes,
allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
mode, columnNameOfCorruptRecord)
self._set_json_opts(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @tdas

previously these options were too susceptible to positional change in the arg list.

schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
Expand Down Expand Up @@ -343,7 +347,8 @@ def text(self, paths):
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None, mode=None):
negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None,
maxMalformedLogPerPartition=None, mode=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.

This function will go through the input once to determine the input schema if
Expand Down Expand Up @@ -408,11 +413,13 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
"""

self._set_csv_opts(schema, sep, encoding, quote, escape,
comment, header, inferSchema, ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
dateFormat, maxColumns, maxCharsPerColumn, mode)
self._set_csv_opts(
schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
if isinstance(path, basestring):
path = [path]
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
Expand Down Expand Up @@ -958,10 +965,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
>>> json_sdf.schema == sdf_schema
True
"""
self._set_json_opts(schema, primitivesAsString, prefersDecimal,
allowComments, allowUnquotedFieldNames, allowSingleQuotes,
allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
mode, columnNameOfCorruptRecord)
self._set_json_opts(
schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
Expand Down Expand Up @@ -1019,7 +1028,8 @@ def text(self, path):
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None, mode=None):
negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None,
maxMalformedLogPerPartition=None, mode=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.

This function will go through the input once to determine the input schema if
Expand Down Expand Up @@ -1085,11 +1095,13 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
>>> csv_sdf.schema == sdf_schema
True
"""

self._set_csv_opts(schema, sep, encoding, quote, escape,
comment, header, inferSchema, ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
dateFormat, maxColumns, maxCharsPerColumn, mode)
self._set_csv_opts(
schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
Expand Down