Skip to content

Commit a1e4f62

Browse files
committed
Add a check that the length of the returned value is the same as that of the input value.
1 parent a2a3f82 commit a1e4f62

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3245,6 +3245,15 @@ def test_vectorized_udf_exception(self):
32453245
with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
32463246
df.select(raise_exception(col('id'))).collect()
32473247

3248+
def test_vectorized_udf_invalid_length(self):
    # A pandas UDF that always returns a one-element Series is applied to a
    # ten-row DataFrame; collecting the result is expected to fail with the
    # length-mismatch error message.
    import pandas as pd

    expected_message = 'The length of returned value should be the same as input value'
    ten_rows = self.spark.range(10)
    wrong_length_udf = pandas_udf(lambda size: pd.Series(1), LongType())

    with QuietTest(self.sc), self.assertRaisesRegexp(Exception, expected_message):
        ten_rows.select(wrong_length_udf()).collect()
3256+
32483257

32493258
if __name__ == "__main__":
32503259
from pyspark.sql.tests import *

python/pyspark/worker.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,16 @@ def read_vectorized_udfs(pickleSer, infile):
138138
else:
139139
args = ["a[0]"]
140140
call_udfs.append("f%d(%s)" % (i, ", ".join(args)))
141+
def chk_len(v, size):
    """Return ``v`` unchanged when its length equals ``size``.

    :param v: value produced by a UDF call; it must support ``len()``.
    :param size: expected length, compared against ``len(v)`` with ``==``.
        NOTE(review): the caller passes ``a[0]`` here; presumably that is the
        batch size as an integer -- confirm against the serializer.
    :return: ``v`` itself when the lengths agree.
    :raises Exception: when ``len(v)`` differs from ``size``.
    """
    actual = len(v)
    if actual == size:
        return v
    # The leading sentence is kept verbatim because callers and tests match on
    # it; the expected/actual values are appended so a failing UDF can be
    # diagnosed from the error message alone.
    raise Exception(
        "The length of returned value should be the same as input value: "
        "expected %s, but got %d" % (size, actual))
146+
call_and_chk_len = ['chk_len(%s, a[0])' % call_udf for call_udf in call_udfs]
147+
udfs['chk_len'] = chk_len
141148
# Create function like this:
142-
# lambda a: [f0(a[0]), f1(a[1], a[2]), f2(a[3])]
143-
mapper_str = "lambda a: [%s]" % (", ".join(call_udfs))
149+
# lambda a: [chk_len(f0(a[0]), a[0]), chk_len(f1(a[1], a[2]), a[0]), ...]
150+
mapper_str = "lambda a: [%s]" % (", ".join(call_and_chk_len))
144151
mapper = eval(mapper_str, udfs)
145152

146153
func = lambda _, it: map(mapper, it)

0 commit comments

Comments
 (0)