Skip to content
31 changes: 29 additions & 2 deletions python/pyspark/sql/readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
``inferSchema`` is enabled. To avoid going through the entire data once, disable
``inferSchema`` option or specify the schema explicitly using ``schema``.

:param path: string, or list of strings, for input path(s).
:param path: string, or list of strings, for input path(s),
or RDD of Strings storing CSV rows.
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param sep: sets the single character as a separator for each field and value.
Expand Down Expand Up @@ -408,6 +409,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
>>> rdd = sc.textFile('python/test_support/sql/ages.csv')
>>> df2 = spark.read.csv(rdd)
>>> df2.dtypes
[('_c0', 'string'), ('_c1', 'string')]
"""
self._set_opts(
schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
Expand All @@ -420,7 +425,29 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine)
if isinstance(path, basestring):
path = [path]
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
if type(path) == list:
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
elif isinstance(path, RDD):
def func(iterator):
for x in iterator:
if not isinstance(x, basestring):
x = unicode(x)
if isinstance(x, unicode):
x = x.encode("utf-8")
yield x
keyed = path.mapPartitions(func)
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried a way within Python and this seems working:

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 1ed452d895b..0f54065b3ee 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -438,7 +438,10 @@ class DataFrameReader(OptionUtils):
             keyed = path.mapPartitions(func)
             keyed._bypass_serializer = True
             jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
-            return self._df(self._jreader.csv(jrdd))
+            jdataset = self._spark._jsqlContext.createDataset(
+                jrdd.rdd(),
+                self._spark._sc._jvm.Encoders.STRING())
+            return self._df(self._jreader.csv(jdataset))
         else:
             raise TypeError("path can be only string, list or RDD")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldmedal, it'd be great if you could double check whether this really works and it can be shorten or cleaner. This was just my rough try only to reach the goal so I am not sure if it is the best way.

Copy link
Contributor Author

@goldmedal goldmedal Sep 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, This way looks good. I'll try it. Thanks for your suggestion.

# see SPARK-22112
# There aren't any jvm api for creating a dataframe from rdd storing csv.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's fix these comments like,

SPARK-22112: There aren't any jvm api for creating a dataframe from rdd storing csv.
...

or

There aren't any jvm api ...
...
for creating a dataframe from dataset storing csv. See SPARK-22112.

when we happened to fix some code around here or review other PRs fixing some codes around here in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thanks

# We can do it through creating a jvm dataset firstly and using the jvm api
# for creating a dataframe from dataset storing csv.
jdataset = self._spark._ssql_ctx.createDataset(
jrdd.rdd(),
self._spark._jvm.Encoders.STRING())
return self._df(self._jreader.csv(jdataset))
else:
raise TypeError("path can be only string, list or RDD")

@since(1.5)
def orc(self, path):
Expand Down