diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index ade4864e1d78..e95f359d1f06 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -331,7 +331,7 @@ def hasParam(self, paramName):
         Tests whether this instance contains a param with a given
         (string) name.
         """
-        if isinstance(paramName, str):
+        if isinstance(paramName, basestring):
             p = getattr(self, paramName, None)
             return isinstance(p, Param)
         else:
@@ -405,7 +405,7 @@ def _resolveParam(self, param):
         if isinstance(param, Param):
             self._shouldOwn(param)
             return param
-        elif isinstance(param, str):
+        elif isinstance(param, basestring):
             return self.getParam(param)
         else:
             raise ValueError("Cannot resolve %r as a param." % param)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e23354985088..05fdc13d5516 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -320,6 +320,13 @@ def test_hasparam(self):
         testParams = TestParams()
         self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
         self.assertFalse(testParams.hasParam("notAParameter"))
+        self.assertTrue(testParams.hasParam(u"maxIter"))
+
+    def test_resolveparam(self):
+        testParams = TestParams()
+        self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
+        self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)
+        self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)
 
     def test_params(self):
         testParams = TestParams()
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0ac481a8a8b5..5a1b68730cd1 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -537,7 +537,7 @@ def sampleBy(self, col, fractions, seed=None):
         +---+-----+
 
         """
-        if not isinstance(col, str):
+        if not isinstance(col, basestring):
             raise ValueError("col must be a string, but got %r" % type(col))
         if not isinstance(fractions, dict):
             raise ValueError("fractions must be a dict but got %r" % type(fractions))
@@ -1263,7 +1263,7 @@ def approxQuantile(self, col, probabilities, relativeError):
           accepted but give the same result as 1.
         :return: the approximate quantiles at the given probabilities
         """
-        if not isinstance(col, str):
+        if not isinstance(col, basestring):
             raise ValueError("col should be a string.")
 
         if not isinstance(probabilities, (list, tuple)):
@@ -1293,9 +1293,9 @@ def corr(self, col1, col2, method=None):
         :param col2: The name of the second column
         :param method: The correlation method. Currently only supports "pearson"
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         if not method:
             method = "pearson"
@@ -1313,9 +1313,9 @@ def cov(self, col1, col2):
         :param col1: The name of the first column
         :param col2: The name of the second column
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         return self._jdf.stat().cov(col1, col2)
 
@@ -1335,9 +1335,9 @@ def crosstab(self, col1, col2):
         :param col2: The name of the second column. Distinct items will make the column names
             of the DataFrame.
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a9e455565a6c..3dace6c998a2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -799,7 +799,7 @@ def test_first_last_ignorenulls(self):
 
     def test_approxQuantile(self):
         df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF()
-        aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
+        aq = df.stat.approxQuantile(u"a", [0.1, 0.5, 0.9], 0.1)
         self.assertTrue(isinstance(aq, list))
         self.assertEqual(len(aq), 3)
         self.assertTrue(all(isinstance(q, float) for q in aq))
@@ -807,17 +807,22 @@ def test_corr(self):
         import math
         df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
-        corr = df.stat.corr("a", "b")
+        corr = df.stat.corr(u"a", "b")
         self.assertTrue(abs(corr - 0.95734012) < 1e-6)
 
+    def test_sampleby(self):
+        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
+        corr = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
+        self.assertTrue(corr.count() == 3)
+
     def test_cov(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
-        cov = df.stat.cov("a", "b")
+        cov = df.stat.cov(u"a", "b")
         self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
 
     def test_crosstab(self):
         df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
-        ct = df.stat.crosstab("a", "b").collect()
+        ct = df.stat.crosstab(u"a", "b").collect()
         ct = sorted(ct, key=lambda x: x[0])
         for i, row in enumerate(ct):
             self.assertEqual(row[0], str(i))
@@ -883,9 +888,9 @@ def test_between_function(self):
 
     def test_struct_type(self):
         from pyspark.sql.types import StructType, StringType, StructField
-        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+        struct1 = StructType().add(u"f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
-                              StructField("f2", StringType(), True, None)])
+                              StructField(u"f2", StringType(), True, None)])
         self.assertEqual(struct1, struct2)
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
@@ -916,6 +921,7 @@ def test_struct_type(self):
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         self.assertIs(struct1["f1"], struct1.fields[0])
+        self.assertIs(struct1[u"f1"], struct1.fields[0])
         self.assertIs(struct1[0], struct1.fields[0])
         self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
         with self.assertRaises(KeyError):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 4a023123b6ec..2cd4e5985651 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -509,10 +509,10 @@ def add(self, field, data_type=None, nullable=True, metadata=None):
             self.fields.append(field)
             self.names.append(field.name)
         else:
-            if isinstance(field, str) and data_type is None:
+            if isinstance(field, basestring) and data_type is None:
                 raise ValueError("Must specify DataType if passing name of struct_field to create.")
 
-            if isinstance(data_type, str):
+            if isinstance(data_type, basestring):
                 data_type_f = _parse_datatype_json_value(data_type)
             else:
                 data_type_f = data_type
@@ -531,7 +531,7 @@ def __len__(self):
 
     def __getitem__(self, key):
         """Access fields by name or slice."""
-        if isinstance(key, str):
+        if isinstance(key, basestring):
             for field in self:
                 if field.name == key:
                     return field
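
Note: the motivation behind every hunk above is Python 2's split string hierarchy. A u"..." literal is `unicode`, not `str`, so a check like `isinstance(col, str)` rejects perfectly valid unicode column and parameter names; `basestring` is the common base class of `str` and `unicode` and accepts both. Below is a minimal, self-contained sketch of that behavior (not pyspark code; `require_string_col` is a hypothetical helper, and the `basestring = str` alias mirrors the kind of Python 3 compatibility shim pyspark modules rely on, since Python 3 has no `basestring`):

import sys

if sys.version_info[0] >= 3:
    # Python 3 dropped basestring; alias it so the check below runs on both versions.
    basestring = str

def require_string_col(col):
    # Old check: isinstance(col, str) -- on Python 2 this rejects u"a".
    # Fixed check: basestring matches both str and unicode on Python 2.
    if not isinstance(col, basestring):
        raise ValueError("col must be a string, but got %r" % type(col))
    return col

require_string_col("a")   # accepted both before and after the patch
require_string_col(u"a")  # rejected by the old str-only check on Python 2; accepted now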