4 changes: 2 additions & 2 deletions python/pyspark/ml/param/__init__.py
@@ -331,7 +331,7 @@ def hasParam(self, paramName):
Tests whether this instance contains a param with a given
(string) name.
"""
- if isinstance(paramName, str):
+ if isinstance(paramName, basestring):
p = getattr(self, paramName, None)
return isinstance(p, Param)
else:
@@ -405,7 +405,7 @@ def _resolveParam(self, param):
if isinstance(param, Param):
self._shouldOwn(param)
return param
- elif isinstance(param, str):
+ elif isinstance(param, basestring):
return self.getParam(param)
else:
raise ValueError("Cannot resolve %r as a param." % param)
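For context (this block is not part of the PR), a minimal sketch of the Python 2 string hierarchy that motivates swapping str for basestring in these checks; the version guard lets it run under both Python 2 and 3:

import sys

if sys.version_info[0] == 2:
    # Python 2 has two text types: a u"" literal is unicode, not str,
    # so the old isinstance(paramName, str) check rejects u"maxIter".
    assert not isinstance(u"maxIter", str)
    # basestring is the common base class of str and unicode,
    # so it accepts both spellings of the same parameter name.
    assert isinstance("maxIter", basestring)   # noqa: F821 (Python 2 only)
    assert isinstance(u"maxIter", basestring)  # noqa: F821 (Python 2 only)
else:
    # Python 3 has a single text type, so the distinction disappears.
    assert isinstance(u"maxIter", str)
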
7 changes: 7 additions & 0 deletions python/pyspark/ml/tests.py
@@ -320,6 +320,13 @@ def test_hasparam(self):
testParams = TestParams()
self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
self.assertFalse(testParams.hasParam("notAParameter"))
+ self.assertTrue(testParams.hasParam(u"maxIter"))

+ def test_resolveparam(self):
+ testParams = TestParams()
+ self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
+ self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)
+ self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)

def test_params(self):
testParams = TestParams()
16 changes: 8 additions & 8 deletions python/pyspark/sql/dataframe.py
@@ -537,7 +537,7 @@ def sampleBy(self, col, fractions, seed=None):
+---+-----+

"""
- if not isinstance(col, str):
+ if not isinstance(col, basestring):
Member:
I am not sure this change is needed. In SQL, column names are only allowed to contain letters, digits, and underscores, so it is questionable whether users would pass a unicode string as a column name in particular.

Reply:
According to f958f27, it seems to be possible to use non-ascii characters in column names, and I think there are use cases that want non-ascii column names.

Member:
Ah, got it. I only meant from the SQL parser's point of view.
Similarly, since a unicode column name is encoded with name.encode('utf-8'), it ends up as a str instance. In other words, the schema still stores column names as str, while this change allows unicode input for col, so I think there will be a mismatch.

Member:
So I think we don't need to do this.

Reply:
Thank you for the explanation; I now understand why isinstance(col, basestring) is not strictly needed here.
Still, although column names are basically stored as str, they are stored as unicode in certain cases. See SPARK-15244 for details.

Contributor Author:
Is there any harm in allowing unicode here, though? If my column is 'a' and I call sampleBy(u'a'), it will work after this change; otherwise it throws an error. I think it's better to treat 'a' and u'a' as equivalent...

Reply:
I agree with you. There is no problem caused by allowing unicode here.
As you mentioned, it's better to handle both 'a' and u'a', because there are a few cases where a unicode string is passed (e.g. when __future__.unicode_literals is imported in Python 2).

raise ValueError("col must be a string, but got %r" % type(col))
if not isinstance(fractions, dict):
raise ValueError("fractions must be a dict but got %r" % type(fractions))
@@ -1263,7 +1263,7 @@ def approxQuantile(self, col, probabilities, relativeError):
accepted but give the same result as 1.
:return: the approximate quantiles at the given probabilities
"""
- if not isinstance(col, str):
+ if not isinstance(col, basestring):
raise ValueError("col should be a string.")

if not isinstance(probabilities, (list, tuple)):
@@ -1293,9 +1293,9 @@ def corr(self, col1, col2, method=None):
:param col2: The name of the second column
:param method: The correlation method. Currently only supports "pearson"
"""
- if not isinstance(col1, str):
+ if not isinstance(col1, basestring):
raise ValueError("col1 should be a string.")
- if not isinstance(col2, str):
+ if not isinstance(col2, basestring):
raise ValueError("col2 should be a string.")
if not method:
method = "pearson"
@@ -1313,9 +1313,9 @@ def cov(self, col1, col2):
:param col1: The name of the first column
:param col2: The name of the second column
"""
- if not isinstance(col1, str):
+ if not isinstance(col1, basestring):
raise ValueError("col1 should be a string.")
- if not isinstance(col2, str):
+ if not isinstance(col2, basestring):
raise ValueError("col2 should be a string.")
return self._jdf.stat().cov(col1, col2)

@@ -1335,9 +1335,9 @@ def crosstab(self, col1, col2):
:param col2: The name of the second column. Distinct items will make the column names
of the DataFrame.
"""
- if not isinstance(col1, str):
+ if not isinstance(col1, basestring):
raise ValueError("col1 should be a string.")
- if not isinstance(col2, str):
+ if not isinstance(col2, basestring):
raise ValueError("col2 should be a string.")
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)

18 changes: 12 additions & 6 deletions python/pyspark/sql/tests.py
@@ -799,25 +799,30 @@ def test_first_last_ignorenulls(self):

def test_approxQuantile(self):
df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF()
- aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
+ aq = df.stat.approxQuantile(u"a", [0.1, 0.5, 0.9], 0.1)
Member:
Basically, the field names in these tests are all ascii characters. Is it possible to add tests using non-ascii characters so we can make sure they work?

self.assertTrue(isinstance(aq, list))
self.assertEqual(len(aq), 3)
self.assertTrue(all(isinstance(q, float) for q in aq))

def test_corr(self):
import math
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
- corr = df.stat.corr("a", "b")
+ corr = df.stat.corr(u"a", "b")
self.assertTrue(abs(corr - 0.95734012) < 1e-6)

+ def test_sampleby(self):
+ df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
+ corr = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
+ self.assertTrue(corr.count() == 3)

def test_cov(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
- cov = df.stat.cov("a", "b")
+ cov = df.stat.cov(u"a", "b")
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)

def test_crosstab(self):
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
- ct = df.stat.crosstab("a", "b").collect()
+ ct = df.stat.crosstab(u"a", "b").collect()
ct = sorted(ct, key=lambda x: x[0])
for i, row in enumerate(ct):
self.assertEqual(row[0], str(i))
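The reviewer above asks whether tests with non-ascii field names could be added. A hedged sketch of what such a test might look like (the column name, data, and assertion are illustrative and not part of the PR):

def test_crosstab_non_ascii_column(self):
    df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
    # Rename a column to a non-ascii name, then exercise the stat API with it.
    df = df.withColumnRenamed("a", u"数")
    ct = df.stat.crosstab(u"数", "b").collect()
    self.assertEqual(len(ct), 3)  # one row per distinct value of the first column
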
@@ -883,9 +888,9 @@ def test_between_function(self):

def test_struct_type(self):
from pyspark.sql.types import StructType, StringType, StructField
- struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct1 = StructType().add(u"f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True),
- StructField("f2", StringType(), True, None)])
+ StructField(u"f2", StringType(), True, None)])
self.assertEqual(struct1, struct2)

struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
@@ -916,6 +921,7 @@ def test_struct_type(self):

struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
self.assertIs(struct1["f1"], struct1.fields[0])
+ self.assertIs(struct1[u"f1"], struct1.fields[0])
self.assertIs(struct1[0], struct1.fields[0])
self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
with self.assertRaises(KeyError):
6 changes: 3 additions & 3 deletions python/pyspark/sql/types.py
@@ -509,10 +509,10 @@ def add(self, field, data_type=None, nullable=True, metadata=None):
self.fields.append(field)
self.names.append(field.name)
else:
- if isinstance(field, str) and data_type is None:
+ if isinstance(field, basestring) and data_type is None:
raise ValueError("Must specify DataType if passing name of struct_field to create.")

- if isinstance(data_type, str):
+ if isinstance(data_type, basestring):
data_type_f = _parse_datatype_json_value(data_type)
else:
data_type_f = data_type
@@ -531,7 +531,7 @@ def __len__(self):

def __getitem__(self, key):
"""Access fields by name or slice."""
- if isinstance(key, str):
+ if isinstance(key, basestring):
for field in self:
if field.name == key:
return field