diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 99d8fa3a5b73..3bfda3bfa673 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -321,7 +321,7 @@ def hasParam(self, paramName): Tests whether this instance contains a param with a given (string) name. """ - if isinstance(paramName, str): + if isinstance(paramName, basestring): p = getattr(self, paramName, None) return isinstance(p, Param) else: @@ -393,7 +393,7 @@ def _resolveParam(self, param): if isinstance(param, Param): self._shouldOwn(param) return param - elif isinstance(param, str): + elif isinstance(param, basestring): return self.getParam(param) else: raise ValueError("Cannot resolve %r as a param." % param) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 17a39472e1fe..401350984dd2 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -319,6 +320,20 @@ def test_hasparam(self): testParams = TestParams() self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) self.assertFalse(testParams.hasParam("notAParameter")) + self.assertTrue(testParams.hasParam(u"maxIter")) + + def test_resolveparam(self): + testParams = TestParams() + self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) + self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) + + self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) + if sys.version_info[0] >= 3: + # In Python 3, it is allowed to get/set attributes with non-ascii characters. + e_cls = AttributeError + else: + e_cls = UnicodeEncodeError + self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) def test_params(self): testParams = TestParams() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0649271ed224..c1452c90dbdc 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -697,7 +697,7 @@ def sampleBy(self, col, fractions, seed=None): +---+-----+ """ - if not isinstance(col, str): + if not isinstance(col, basestring): raise ValueError("col must be a string, but got %r" % type(col)) if not isinstance(fractions, dict): raise ValueError("fractions must be a dict but got %r" % type(fractions)) @@ -1509,18 +1509,18 @@ def approxQuantile(self, col, probabilities, relativeError): Added support for multiple columns. """ - if not isinstance(col, (str, list, tuple)): + if not isinstance(col, (basestring, list, tuple)): raise ValueError("col should be a string, list or tuple, but got %r" % type(col)) - isStr = isinstance(col, str) + isStr = isinstance(col, basestring) if isinstance(col, tuple): col = list(col) - elif isinstance(col, str): + elif isStr: col = [col] for c in col: - if not isinstance(c, str): + if not isinstance(c, basestring): raise ValueError("columns should be strings, but got %r" % type(c)) col = _to_list(self._sc, col) @@ -1552,9 +1552,9 @@ def corr(self, col1, col2, method=None): :param col2: The name of the second column :param method: The correlation method. Currently only supports "pearson" """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") if not method: method = "pearson" @@ -1572,9 +1572,9 @@ def cov(self, col1, col2): :param col1: The name of the first column :param col2: The name of the second column """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") return self._jdf.stat().cov(col1, col2) @@ -1594,9 +1594,9 @@ def crosstab(self, col1, col2): :param col2: The name of the second column. Distinct items will make the column names of the DataFrame. """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0a1cd6856b8e..6791446265dd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1079,11 +1079,12 @@ def test_first_last_ignorenulls(self): def test_approxQuantile(self): df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF() - aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aq, list)) - self.assertEqual(len(aq), 3) + for f in ["a", u"a"]: + aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aq, list)) + self.assertEqual(len(aq), 3) self.assertTrue(all(isinstance(q, float) for q in aq)) - aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1) + aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1) self.assertTrue(isinstance(aqs, list)) self.assertEqual(len(aqs), 2) self.assertTrue(isinstance(aqs[0], list)) @@ -1092,7 +1093,7 @@ def test_approxQuantile(self): self.assertTrue(isinstance(aqs[1], list)) self.assertEqual(len(aqs[1]), 3) self.assertTrue(all(isinstance(q, float) for q in aqs[1])) - aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1) + aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1) self.assertTrue(isinstance(aqt, list)) self.assertEqual(len(aqt), 2) self.assertTrue(isinstance(aqt[0], list)) @@ -1108,17 +1109,22 @@ def test_approxQuantile(self): def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() - corr = df.stat.corr("a", "b") + corr = df.stat.corr(u"a", "b") self.assertTrue(abs(corr - 0.95734012) < 1e-6) + def test_sampleby(self): + df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF() + sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0) + self.assertTrue(sampled.count() == 3) + def test_cov(self): df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() - cov = df.stat.cov("a", "b") + cov = df.stat.cov(u"a", "b") self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) def test_crosstab(self): df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() - ct = df.stat.crosstab("a", "b").collect() + ct = df.stat.crosstab(u"a", "b").collect() ct = sorted(ct, key=lambda x: x[0]) for i, row in enumerate(ct): self.assertEqual(row[0], str(i))