8 changes: 8 additions & 0 deletions python/pyspark/sql/functions.py
@@ -1949,6 +1949,14 @@ def _create_judf(self):
        return judf

    def __call__(self, *cols):
        for c in cols:
            if not isinstance(c, (Column, str)):
Member
Doesn't this break unicode support in Python 2?

from pyspark.sql.functions import udf
udf(lambda x: x)(u"a")

Before:

Column<<lambda>(a)>

After:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../spark/python/pyspark/sql/functions.py", line 1970, in wrapper
    return self(*args)
  File ".../spark/python/pyspark/sql/functions.py", line 1958, in __call__
    "lit, array, struct or create_map.".format(c, type(c)))
TypeError: Invalid UDF argument, not a str or Column: a of type <type 'unicode'>. For Column literals use sql.functions lit, array, struct or create_map.
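
A minimal sketch of a Python 2/3 compatible version of the check (the string_types helper is illustrative and not part of the patch); on Python 2, basestring covers both str and unicode:

try:
    string_types = (basestring,)  # Python 2: matches both str and unicode
except NameError:
    string_types = (str,)  # Python 3: str is already unicode

for c in cols:
    if not isinstance(c, (Column,) + string_types):
        ...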

Member Author
@HyukjinKwon Sorry for the delayed response; I am seldom online these days. You're right, it looks like an issue. I'll take a look at this when I have more time.

                raise TypeError(
                    "Invalid UDF argument, not a str or Column: "
                    "{0} of type {1}. "
                    "For Column literals use sql.functions "
                    "lit, array, struct or create_map.".format(c, type(c)))

judf = self._judf
sc = SparkContext._active_spark_context
return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests.py
@@ -636,6 +636,11 @@ def f(x):
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)

    def test_udf_should_validate_input_args(self):
        from pyspark.sql.functions import udf

        self.assertRaises(TypeError, udf(lambda x: x), None)
Contributor
I think this should have positive tests for a column and a string as well as a negative test.
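
For reference, positive cases of the kind suggested here might look like the following sketch (the test name, column name, and data are illustrative, not part of the patch):

    def test_udf_should_accept_str_and_column(self):
        from pyspark.sql.functions import udf, col
        df = self.spark.createDataFrame([("a",)], ["name"])
        identity = udf(lambda x: x)
        # Both a column name (str) and a Column object should pass validation.
        self.assertEqual(df.select(identity("name")).first()[0], "a")
        self.assertEqual(df.select(identity(col("name"))).first()[0], "a")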

Member Author
It is pretty well covered by the existing udf tests. The more the merrier, but I am not sure what could be added without duplicating other test cases.

Do you think we should try some validation of the number of arguments?

Pros:

  • It is easy to implement with inspect or func.__code__ for plain Python objects (see the sketch after this list).
  • It is nice to fail without starting a complex job.

Cons:

  • It most likely won't work well for C extensions and such.
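
A minimal sketch of such a check with inspect (the helper name is illustrative, not part of the patch):

import inspect

def _expected_argcount(f):
    # Best-effort positional-argument count; returns None when the callable
    # cannot be introspected, e.g. C extensions or builtins.
    try:
        spec = inspect.getargspec(f)
    except TypeError:
        return None
    if spec.varargs is not None:
        return None  # *args accepts any number of positional arguments
    return len(spec.args)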

Contributor
If this is covered by existing tests, then that's fine. Good point.

As for validating the number of args, I think it is a good idea, as long as we know that it won't fail for C extensions (though it may be inconclusive there).

Member Author
Yeah. I am afraid it can actually cause more trouble than it's worth:

  • If we throw an exception there is a chance we hit some border cases.
  • Issuing a warning doesn't prevent task failure so it doesn't provide the same advantages as failing early.

Maybe it is better to leave it as is. Right now users get clear feedback if there is an incorrect type, and for additional safety one can always use annotations and a type checker.

Contributor
Sounds good. It's probably worth exploring eventually, but there's no need to hold up this PR.

Member Author
In the meantime I removed [WIP] and hopefully it will get merged :)


    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)