Skip to content

Commit 8216b6b

Browse files
committed
wip
1 parent ea0a5ee commit 8216b6b

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

python/pyspark/sql/catalog.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,25 @@ def registerFunction(self, name, f, returnType=StringType()):
255255
>>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
256256
>>> spark.sql("SELECT stringLengthInt('test')").collect()
257257
[Row(stringLengthInt(test)=4)]
258+
259+
>>> import random
260+
>>> from pyspark.sql.functions import udf
261+
>>> from pyspark.sql.types import IntegerType, StringType
262+
>>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
263+
>>> spark.catalog.registerFunction("random_udf", random_udf, StringType())
264+
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
265+
[Row(random_udf()=u'82')]
258266
"""
259-
udf = UserDefinedFunction(f, returnType=returnType, name=name,
260-
evalType=PythonEvalType.SQL_BATCHED_UDF)
261-
self._jsparkSession.udf().registerPython(name, udf._judf)
262-
return udf._wrapped()
267+
268+
if hasattr(f, 'asNondeterministic'):
269+
udf = f._set_name(name, returnType)
270+
self._jsparkSession.udf().registerPython(name, udf._judf)
271+
return udf._wrapped()
272+
else:
273+
udf = UserDefinedFunction(f, returnType=returnType, name=name,
274+
evalType=PythonEvalType.SQL_BATCHED_UDF)
275+
self._jsparkSession.udf().registerPython(name, udf._judf)
276+
return udf._wrapped()
263277

264278
@since(2.0)
265279
def isCached(self, tableName):

python/pyspark/sql/udf.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,19 @@ def __call__(self, *cols):
138138
sc = SparkContext._active_spark_context
139139
return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
140140

141+
def _set_name(self, name, returnType=StringType()):
142+
"""
143+
Updates the name of UserDefinedFunction.
144+
"""
145+
# reset _judf
146+
self._judf_placeholder = None
147+
self._returnType_placeholder = None
148+
self._name = name or (
149+
func.__name__ if hasattr(func, '__name__')
150+
else func.__class__.__name__)
151+
self._returnType = returnType
152+
return self
153+
141154
def _wrapped(self):
142155
"""
143156
Wrap this udf with a function and attach docstring from func
@@ -163,6 +176,10 @@ def wrapper(*args):
163176
wrapper.returnType = self.returnType
164177
wrapper.evalType = self.evalType
165178
wrapper.asNondeterministic = self.asNondeterministic
179+
wrapper._judf = self._judf
180+
wrapper._create_judf = self._create_judf
181+
wrapper._wrapped = self._wrapped
182+
wrapper._set_name = self._set_name
166183

167184
return wrapper
168185

0 commit comments

Comments
 (0)