25 changes: 24 additions & 1 deletion python/pyspark/sql/readwriter.py
@@ -336,6 +336,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
``inferSchema`` option or specify the schema explicitly using ``schema``.

:param path: string, or list of strings, for input path(s).
Member:

nit: . -> ,

Contributor Author:

ok thanks :)

+                     or RDD of Strings storing CSV rows.
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param sep: sets the single character as a separator for each field and value.
@@ -408,6 +409,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
+        >>> rdd = sc.textFile('python/test_support/sql/ages.csv')
+        >>> df2 = spark.read.csv(rdd)
+        >>> df2.dtypes
+        [('_c0', 'string'), ('_c1', 'string')]
"""
self._set_opts(
schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
@@ -420,7 +425,25 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine)
if isinstance(path, basestring):
path = [path]
-        return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+        if type(path) == list:
+            return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+        elif isinstance(path, RDD):
+            def func(iterator):
+                for x in iterator:
+                    if not isinstance(x, basestring):
+                        x = unicode(x)
+                    if isinstance(x, unicode):
+                        x = x.encode("utf-8")
+                    yield x
+            keyed = path.mapPartitions(func)
+            keyed._bypass_serializer = True
+            jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
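For context, this is the same pattern DataFrameReader.json already uses when given an RDD: each element is coerced to UTF-8 encoded bytes on the Python side, the usual pickle serializer is bypassed so the raw bytes ship to the JVM unchanged, and the JVM-side BytesToString function decodes each byte array back into a java.lang.String. A minimal standalone sketch of the Python-side coercion, under the Python 2 semantics the patch assumes (to_csv_rows is a hypothetical name, not part of the patch):

def to_csv_rows(iterator):
    # Same logic as `func` above: convert non-strings with unicode(),
    # then encode every unicode value as UTF-8 bytes so the JVM can
    # decode all elements uniformly.
    for x in iterator:
        if not isinstance(x, basestring):
            x = unicode(x)
        if isinstance(x, unicode):
            x = x.encode("utf-8")
        yield x

print(list(to_csv_rows(iter([u"Joe,20", 42]))))  # ['Joe,20', '42']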
Member:

I tried a way within Python and it seems to work:
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 1ed452d895b..0f54065b3ee 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -438,7 +438,10 @@ class DataFrameReader(OptionUtils):
             keyed = path.mapPartitions(func)
             keyed._bypass_serializer = True
             jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
-            return self._df(self._jreader.csv(jrdd))
+            jdataset = self._spark._jsqlContext.createDataset(
+                jrdd.rdd(),
+                self._spark._sc._jvm.Encoders.STRING())
+            return self._df(self._jreader.csv(jdataset))
         else:
             raise TypeError("path can be only string, list or RDD")

Member:

@goldmedal, it'd be great if you could double-check whether this really works and whether it can be shortened or made cleaner. This was just my rough try to reach the goal, so I am not sure it is the best way.

Contributor Author (@goldmedal, Sep 26, 2017):

OK, this way looks good. I'll try it. Thanks for your suggestion.
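For background on why a Dataset is needed at all: the Scala-side DataFrameReader.csv has no overload that accepts an RDD; besides file paths it only accepts a Dataset[String] of CSV rows, which is why the suggestion wraps the JavaRDD in a JVM Dataset first. A quick way to double-check the suggestion, mirroring the doctest above (this assumes a build with the patch applied and Spark's bundled test file):

rdd = sc.textFile('python/test_support/sql/ages.csv')
df = spark.read.csv(rdd)
df.dtypes    # expected: [('_c0', 'string'), ('_c1', 'string')]
df.collect() # expected to match spark.read.csv('python/test_support/sql/ages.csv').collect()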

+            jdataset = self._spark._ssql_ctx.createDataset(
Member:

Let's add a small comment here to explain why we should create the dataset (which could look a bit weird in PySpark I believe).
+                jrdd.rdd(),
+                self._spark._sc._jvm.Encoders.STRING())
Member:

Could we replace _spark._sc._jvm with _spark._jvm?

Contributor Author:

Yes, it works. I'll modify it.
+            return self._df(self._jreader.csv(jdataset))
+        else:
+            raise TypeError("path can be only string, list or RDD")
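Taken together, the new branch lets the usual CSV reader options apply to in-memory rows as well as files. A minimal usage sketch, assuming an active SparkSession and SparkContext bound to spark and sc; the schema and column names are illustrative, not part of the patch:

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

rows = sc.parallelize(["Joe,20", "Tom,30"])
schema = StructType([
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
])
df = spark.read.csv(rows, schema=schema)
df.show()
# expected:
# +----+---+
# |name|age|
# +----+---+
# | Joe| 20|
# | Tom| 30|
# +----+---+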

@since(1.5)
def orc(self, path):