Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/pyspark/ml/param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def hasParam(self, paramName):
Tests whether this instance contains a param with a given
(string) name.
"""
if isinstance(paramName, str):
if isinstance(paramName, basestring):
p = getattr(self, paramName, None)
return isinstance(p, Param)
else:
Expand Down Expand Up @@ -393,7 +393,7 @@ def _resolveParam(self, param):
if isinstance(param, Param):
self._shouldOwn(param)
return param
elif isinstance(param, str):
elif isinstance(param, basestring):
return self.getParam(param)
else:
raise ValueError("Cannot resolve %r as a param." % param)
Expand Down
15 changes: 15 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
Expand Down Expand Up @@ -319,6 +320,20 @@ def test_hasparam(self):
testParams = TestParams()
self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
self.assertFalse(testParams.hasParam("notAParameter"))
self.assertTrue(testParams.hasParam(u"maxIter"))

def test_resolveparam(self):
testParams = TestParams()
self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)

self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)
if sys.version_info[0] >= 3:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@holdenk, would this test address your concern enough?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! :)

# In Python 3, it is allowed to get/set attributes with non-ascii characters.
e_cls = AttributeError
else:
e_cls = UnicodeEncodeError
self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아"))

def test_params(self):
testParams = TestParams()
Expand Down
22 changes: 11 additions & 11 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def sampleBy(self, col, fractions, seed=None):
+---+-----+

"""
if not isinstance(col, str):
if not isinstance(col, basestring):
raise ValueError("col must be a string, but got %r" % type(col))
if not isinstance(fractions, dict):
raise ValueError("fractions must be a dict but got %r" % type(fractions))
Expand Down Expand Up @@ -1509,18 +1509,18 @@ def approxQuantile(self, col, probabilities, relativeError):
Added support for multiple columns.
"""

if not isinstance(col, (str, list, tuple)):
if not isinstance(col, (basestring, list, tuple)):
raise ValueError("col should be a string, list or tuple, but got %r" % type(col))

isStr = isinstance(col, str)
isStr = isinstance(col, basestring)

if isinstance(col, tuple):
col = list(col)
elif isinstance(col, str):
elif isStr:
col = [col]

for c in col:
if not isinstance(c, str):
if not isinstance(c, basestring):
raise ValueError("columns should be strings, but got %r" % type(c))
col = _to_list(self._sc, col)

Expand Down Expand Up @@ -1552,9 +1552,9 @@ def corr(self, col1, col2, method=None):
:param col2: The name of the second column
:param method: The correlation method. Currently only supports "pearson"
"""
if not isinstance(col1, str):
if not isinstance(col1, basestring):
raise ValueError("col1 should be a string.")
if not isinstance(col2, str):
if not isinstance(col2, basestring):
raise ValueError("col2 should be a string.")
if not method:
method = "pearson"
Expand All @@ -1572,9 +1572,9 @@ def cov(self, col1, col2):
:param col1: The name of the first column
:param col2: The name of the second column
"""
if not isinstance(col1, str):
if not isinstance(col1, basestring):
raise ValueError("col1 should be a string.")
if not isinstance(col2, str):
if not isinstance(col2, basestring):
raise ValueError("col2 should be a string.")
return self._jdf.stat().cov(col1, col2)

Expand All @@ -1594,9 +1594,9 @@ def crosstab(self, col1, col2):
:param col2: The name of the second column. Distinct items will make the column names
of the DataFrame.
"""
if not isinstance(col1, str):
if not isinstance(col1, basestring):
raise ValueError("col1 should be a string.")
if not isinstance(col2, str):
if not isinstance(col2, basestring):
raise ValueError("col2 should be a string.")
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)

Expand Down
22 changes: 14 additions & 8 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,11 +1079,12 @@ def test_first_last_ignorenulls(self):

def test_approxQuantile(self):
df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
self.assertTrue(isinstance(aq, list))
self.assertEqual(len(aq), 3)
for f in ["a", u"a"]:
aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
self.assertTrue(isinstance(aq, list))
self.assertEqual(len(aq), 3)
self.assertTrue(all(isinstance(q, float) for q in aq))
aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)
aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
self.assertTrue(isinstance(aqs, list))
self.assertEqual(len(aqs), 2)
self.assertTrue(isinstance(aqs[0], list))
Expand All @@ -1092,7 +1093,7 @@ def test_approxQuantile(self):
self.assertTrue(isinstance(aqs[1], list))
self.assertEqual(len(aqs[1]), 3)
self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1)
aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
self.assertTrue(isinstance(aqt, list))
self.assertEqual(len(aqt), 2)
self.assertTrue(isinstance(aqt[0], list))
Expand All @@ -1108,17 +1109,22 @@ def test_approxQuantile(self):
def test_corr(self):
import math
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
corr = df.stat.corr("a", "b")
corr = df.stat.corr(u"a", "b")
self.assertTrue(abs(corr - 0.95734012) < 1e-6)

def test_sampleby(self):
df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
self.assertTrue(sampled.count() == 3)

def test_cov(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
cov = df.stat.cov("a", "b")
cov = df.stat.cov(u"a", "b")
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)

def test_crosstab(self):
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
ct = df.stat.crosstab("a", "b").collect()
ct = df.stat.crosstab(u"a", "b").collect()
ct = sorted(ct, key=lambda x: x[0])
for i, row in enumerate(ct):
self.assertEqual(row[0], str(i))
Expand Down