Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address comments
  • Loading branch information
HyukjinKwon committed Feb 8, 2018
commit 36617e4bd864e0fbca5c617d009de45a8231a5d6
23 changes: 12 additions & 11 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3718,7 +3718,7 @@ def foo(x):
@pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
def foo(x):
return x
self.assertEqual(foo.returnType, schema[0].dataType)
self.assertEqual(foo.returnType, DoubleType())
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

@pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
Expand Down Expand Up @@ -4032,7 +4032,7 @@ def test_vectorized_udf_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*MapType'):
'Invalid returnType.*scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))

def test_vectorized_udf_return_scalar(self):
Expand Down Expand Up @@ -4072,13 +4072,13 @@ def test_vectorized_udf_unsupported_types(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*MapType'):
'Invalid returnType.*scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))

with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a scalar Pandas UDF.*BinaryType'):
'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
pandas_udf(lambda x: x, BinaryType())

def test_vectorized_udf_dates(self):
Expand Down Expand Up @@ -4296,7 +4296,7 @@ def data(self):
.withColumn("vs", array([lit(i) for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))).drop('vs')

def test_simple(self):
def test_supported_types(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
df = self.data.withColumn("arr", array(col("id")))
Copy link
Contributor

@icexelloss icexelloss Feb 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: It seems a bit arbitrary to mix the array type into this test. Array probably belongs in a new test (if one doesn't exist yet), such as test_array or test_complex_types, or something like test_all_types.


Expand Down Expand Up @@ -4412,7 +4412,7 @@ def test_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*MapType'):
'Invalid returnType.*grouped map Pandas UDF.*MapType'):
pandas_udf(
lambda pdf: pdf,
'id long, v map<int, int>',
Expand Down Expand Up @@ -4448,7 +4448,7 @@ def test_unsupported_types(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*MapType'):
'Invalid returnType.*grouped map Pandas UDF.*MapType'):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)

schema = StructType(
Expand All @@ -4457,7 +4457,7 @@ def test_unsupported_types(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*a grouped map Pandas UDF.*ArrayType.*TimestampType'):
'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)


Expand Down Expand Up @@ -4590,9 +4590,10 @@ def test_unsupported_types(self):

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
@pandas_udf(ArrayType(ArrayType(TimestampType())), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return v
pandas_udf(
lambda x: x,
ArrayType(ArrayType(TimestampType())),
PandasUDFType.GROUPED_AGG)

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
Expand Down
24 changes: 12 additions & 12 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,30 +112,30 @@ def returnType(self):
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)

if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with scalar Pandas UDFs: %s is "
"not supported" % str(self._returnType_placeholder))
elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
if isinstance(self._returnType_placeholder, StructType):
try:
to_arrow_schema(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a grouped map Pandas UDF: "
"Invalid returnType with grouped map Pandas UDFs: "
"%s is not supported" % str(self._returnType_placeholder))
else:
raise TypeError("Invalid returnType for a grouped map Pandas "
"UDF: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a scalar Pandas UDF: %s is "
"not supported" % str(self._returnType_placeholder))
raise TypeError("Invalid returnType for grouped map Pandas "
"UDFs: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with a grouped aggregate Pandas UDF: "
"Invalid returnType with grouped aggregate Pandas UDFs: "
"%s is not supported" % str(self._returnType_placeholder))

return self._returnType_placeholder
Expand Down