Commits (42)
5244aaf
[SPARK-22897][CORE] Expose stageAttemptId in TaskContext
advancedxy Jan 2, 2018
b96a213
[SPARK-22938] Assert that SQLConf.get is accessed only on the driver.
juliuszsompolski Jan 3, 2018
a05e85e
[SPARK-22934][SQL] Make optional clauses order insensitive for CREATE…
gatorsmile Jan 3, 2018
b962488
[SPARK-20236][SQL] dynamic partition overwrite
cloud-fan Jan 3, 2018
27c949d
[SPARK-22932][SQL] Refactor AnalysisContext
gatorsmile Jan 2, 2018
79f7263
[SPARK-22896] Improvement in String interpolation
chetkhatri Jan 3, 2018
a51212b
[SPARK-20960][SQL] make ColumnVector public
cloud-fan Jan 3, 2018
f51c8fd
[SPARK-22944][SQL] improve FoldablePropagation
cloud-fan Jan 4, 2018
1860a43
[SPARK-22933][SPARKR] R Structured Streaming API for withWatermark, t…
felixcheung Jan 4, 2018
a7cfd6b
[SPARK-22950][SQL] Handle ChildFirstURLClassLoader's parent
yaooqinn Jan 4, 2018
eb99b8a
[SPARK-22945][SQL] add java UDF APIs in the functions object
cloud-fan Jan 4, 2018
1f5e354
[SPARK-22939][PYSPARK] Support Spark UDF in registerFunction
gatorsmile Jan 4, 2018
bcfeef5
[SPARK-22771][SQL] Add a missing return statement in Concat.checkInpu…
maropu Jan 4, 2018
cd92913
[SPARK-21475][CORE][2ND ATTEMPT] Change to use NIO's Files API for ex…
jerryshao Jan 4, 2018
bc4bef4
[SPARK-22850][CORE] Ensure queued events are delivered to all event q…
Jan 4, 2018
2ab4012
[SPARK-22948][K8S] Move SparkPodInitContainer to correct package.
Jan 4, 2018
84707f0
[SPARK-22953][K8S] Avoids adding duplicated secret volumes when init-…
liyinan926 Jan 4, 2018
ea9da61
[SPARK-22960][K8S] Make build-push-docker-images.sh more dev-friendly.
Jan 5, 2018
158f7e6
[SPARK-22957] ApproxQuantile breaks if the number of rows exceeds MaxInt
juliuszsompolski Jan 5, 2018
145820b
[SPARK-22825][SQL] Fix incorrect results of Casting Array to String
maropu Jan 5, 2018
5b524cc
[SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed…
MrBago Jan 5, 2018
f9dcdbc
[SPARK-22757][K8S] Enable spark.jars and spark.files in KUBERNETES mode
liyinan926 Jan 5, 2018
fd4e304
[SPARK-22961][REGRESSION] Constant columns should generate QueryPlanC…
adrian-ionescu Jan 5, 2018
0a30e93
[SPARK-22940][SQL] HiveExternalCatalogVersionsSuite should succeed on…
bersprockets Jan 5, 2018
d1f422c
[SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator
jkbradley Jan 5, 2018
55afac4
[SPARK-22914][DEPLOY] Register history.ui.port
gerashegalov Jan 6, 2018
bf85301
[SPARK-22937][SQL] SQL elt output binary for binary inputs
maropu Jan 6, 2018
3e3e938
[SPARK-22960][K8S] Revert use of ARG base_image in images
liyinan926 Jan 6, 2018
7236914
[SPARK-22930][PYTHON][SQL] Improve the description of Vectorized UDFs…
icexelloss Jan 6, 2018
e6449e8
[SPARK-22793][SQL] Memory leak in Spark Thrift Server
Jan 6, 2018
0377755
[SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'Par…
fjh100456 Jan 6, 2018
b66700a
[SPARK-22901][PYTHON][FOLLOWUP] Adds the doc for asNondeterministic f…
HyukjinKwon Jan 6, 2018
f9e7b0c
[HOTFIX] Fix style checking failure
gatorsmile Jan 6, 2018
285d342
[SPARK-22973][SQL] Fix incorrect results of Casting Map to String
maropu Jan 7, 2018
516c0a1
Merge pull request #1 from apache/master
fjh100456 Jan 8, 2018
bd1a80a
Merge remote-tracking branch 'upstream/branch-2.3'
fjh100456 Jan 8, 2018
51f4418
Merge branch 'master' of https://github.com/fjh100456/spark
fjh100456 Jan 9, 2018
cf73803
Merge pull request #3 from apache/master
fjh100456 Apr 20, 2018
6515fb1
Merge remote-tracking branch 'origin/master'
fjh100456 Apr 20, 2018
0c39ead
Merge pull request #4 from apache/master
fjh100456 Aug 29, 2018
61a1028
Merge remote-tracking branch 'origin/master'
fjh100456 Aug 29, 2018
a98d1a1
[SPARK-21786][SQL][FOLLOWUP] Add compressionCodec test for CTAS
fjh100456 Aug 31, 2018
[SPARK-22939][PYSPARK] Support Spark UDF in registerFunction
## What changes were proposed in this pull request?
Previously, `registerFunction` accepted only a plain Python function, so passing in a UDF created by `udf()` (a wrapped `UserDefinedFunction`) failed. For example:
```Python
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType
random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
spark.catalog.registerFunction("random_udf", random_udf, StringType())
spark.sql("SELECT random_udf()").collect()
```

This failed with the following error:
```
Py4JError: An error occurred while calling o29.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
	at py4j.Gateway.invoke(Gateway.java:274)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:745)
```

This PR adds support for registering a wrapped/native `UserDefinedFunction` through `registerFunction`, preserving its `deterministic` flag.
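
With this change, the snippet above succeeds. A minimal usage sketch, mirroring the doctests added in this patch (assuming an active `spark` session; row values are illustrative, since the UDF is nondeterministic):

```Python
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType

# A nondeterministic UDF created via udf() ...
random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()

# ... can now be passed to registerFunction directly. The registered UDF keeps the
# original's deterministic flag, and the call returns a wrapped UDF that is also
# usable in the DataFrame API.
new_random_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())

spark.sql("SELECT random_udf()").collect()         # e.g. [Row(random_udf()=u'82')]
spark.range(1).select(new_random_udf()).collect()  # e.g. [Row(random_udf()=u'62')]
```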

## How was this patch tested?
WIP

Author: gatorsmile <[email protected]>

Closes #20137 from gatorsmile/registerFunction.

(cherry picked from commit 5aadbc9)
Signed-off-by: gatorsmile <[email protected]>
gatorsmile committed Jan 4, 2018
commit 1f5e3540c7535ceaea66ebd5ee2f598e8b3ba1a5
27 changes: 22 additions & 5 deletions python/pyspark/sql/catalog.py
@@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName):
@ignore_unicode_prefix
@since(2.0)
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a python function (including lambda function) as a UDF
so it can be used in SQL statements.
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statement.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.

:param name: name of the UDF
-:param f: python function
+:param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`

@@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]

>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType, StringType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=u'82')]
>>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
[Row(random_udf()=u'62')]
"""
-udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                          evalType=PythonEvalType.SQL_BATCHED_UDF)
+# This is to check whether the input function is a wrapped/native UserDefinedFunction
+if hasattr(f, 'asNondeterministic'):
+    udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
+                              evalType=PythonEvalType.SQL_BATCHED_UDF,
+                              deterministic=f.deterministic)
+else:
+    udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                              evalType=PythonEvalType.SQL_BATCHED_UDF)
self._jsparkSession.udf().registerPython(name, udf._judf)
return udf._wrapped()
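
The `hasattr(f, 'asNondeterministic')` check above duck-types the input: both a native `UserDefinedFunction` and the plain-function wrapper returned by `udf()` expose `asNondeterministic`, `func`, and `deterministic`, so a single test covers both, and the UDF is re-created from the underlying callable with its determinism preserved. A standalone sketch of the dispatch (plain Python, hypothetical names, not Spark API):

```Python
def resolve_callable(f):
    """Illustrative only: split UDF-like inputs from bare callables."""
    if hasattr(f, 'asNondeterministic'):
        # UDF-like: reuse the wrapped callable and its deterministic flag.
        return f.func, f.deterministic
    # Bare callable: deterministic by default.
    return f, True
```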

16 changes: 13 additions & 3 deletions python/pyspark/sql/context.py
@@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None):
@ignore_unicode_prefix
@since(1.2)
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a python function (including lambda function) as a UDF
so it can be used in SQL statements.
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statement.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.

:param name: name of the UDF
-:param f: python function
+:param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`

@@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]

>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType, StringType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
>>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=u'82')]
>>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
[Row(random_udf()=u'62')]
"""
return self.sparkSession.catalog.registerFunction(name, f, returnType)

49 changes: 35 additions & 14 deletions python/pyspark/sql/tests.py
@@ -378,6 +378,41 @@ def test_udf2(self):
[res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_udf3(self):
twoargs = self.spark.catalog.registerFunction(
"twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
self.assertEqual(twoargs.deterministic, True)
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)

def test_nondeterministic_udf(self):
from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
self.assertEqual(udf_random_col.deterministic, False)
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
self.assertEqual(row[0] + 10, row[1])

def test_nondeterministic_udf2(self):
import random
from pyspark.sql.functions import udf
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
self.assertEqual(random_udf.deterministic, False)
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
self.assertEqual(random_udf1.deterministic, False)
[row] = self.spark.sql("SELECT randInt()").collect()
self.assertEqual(row[0], "6")
[row] = self.spark.range(1).select(random_udf1()).collect()
self.assertEqual(row[0], "6")
[row] = self.spark.range(1).select(random_udf()).collect()
self.assertEqual(row[0], 6)
# render_doc() reproduces the help() exception without printing output
pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
pydoc.render_doc(random_udf)
pydoc.render_doc(random_udf1)

def test_chained_udf(self):
self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.spark.sql("SELECT double(1)").collect()
@@ -435,15 +470,6 @@ def test_udf_with_array_type(self):
self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)

-def test_nondeterministic_udf(self):
-    from pyspark.sql.functions import udf
-    import random
-    udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
-    df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
-    udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
-    [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
-    self.assertEqual(row[0] + 10, row[1])

def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
@@ -567,15 +593,13 @@ def test_read_multiple_orc_file(self):

def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
-from pyspark.sql.types import StringType
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
self.assertTrue(row[0].find("people1.json") != -1)

def test_udf_with_input_file_name_for_hadooprdd(self):
from pyspark.sql.functions import udf, input_file_name
-from pyspark.sql.types import StringType

def filename(path):
return path
@@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self):

def test_udf_shouldnt_accept_noncallable_object(self):
from pyspark.sql.functions import UserDefinedFunction
-from pyspark.sql.types import StringType

non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
@@ -1299,7 +1322,6 @@ def test_between_function(self):
df.filter(df.a.between(df.b, df.c)).collect())

def test_struct_type(self):
-from pyspark.sql.types import StructType, StringType, StructField
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True),
StructField("f2", StringType(), True, None)])
@@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self):
_parse_datatype_string("a INT, c DOUBLE"))

def test_metadata_null(self):
-from pyspark.sql.types import StructType, StringType, StructField
schema = StructType([StructField("f1", StringType(), True, None),
StructField("f2", StringType(), True, {'a': None})])
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
21 changes: 14 additions & 7 deletions python/pyspark/sql/udf.py
@@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType):
)

# Set the name of the UserDefinedFunction object to be the name of function f
-udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType)
+udf_obj = UserDefinedFunction(
+    f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
return udf_obj._wrapped()


@@ -67,8 +68,10 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
def __init__(self, func,
-             returnType=StringType(), name=None,
-             evalType=PythonEvalType.SQL_BATCHED_UDF):
+             returnType=StringType(),
+             name=None,
+             evalType=PythonEvalType.SQL_BATCHED_UDF,
+             deterministic=True):
if not callable(func):
raise TypeError(
"Invalid function: not a function or callable (__call__ is not defined): "
@@ -92,7 +95,7 @@ def __init__(self, func,
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self.evalType = evalType
-self._deterministic = True
+self.deterministic = deterministic

@property
def returnType(self):
@@ -130,14 +133,17 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-    self._name, wrapped_func, jdt, self.evalType, self._deterministic)
+    self._name, wrapped_func, jdt, self.evalType, self.deterministic)
return judf

def __call__(self, *cols):
judf = self._judf
sc = SparkContext._active_spark_context
return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

# This function is for improving the online help system in the interactive interpreter.
# For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
# argument annotation. (See: SPARK-19161)
def _wrapped(self):
"""
Wrap this udf with a function and attach docstring from func
@@ -162,7 +168,8 @@ def wrapper(*args):
wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.evalType = self.evalType
-wrapper.asNondeterministic = self.asNondeterministic
+wrapper.deterministic = self.deterministic
+wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped()

return wrapper

@@ -172,5 +179,5 @@ def asNondeterministic(self):

.. versionadded:: 2.3
"""
-self._deterministic = False
+self.deterministic = False
return self
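
One subtlety in `_wrapped` above: previously `wrapper.asNondeterministic` was bound to the unwrapped method, so chaining handed back a bare `UserDefinedFunction` without the docstring-carrying wrapper; it is now `lambda: self.asNondeterministic()._wrapped()`, which re-wraps the result. The `pydoc.render_doc` calls in `test_nondeterministic_udf2` exercise exactly this, reproducing the `help()` failure mode without printing output. A minimal sketch of the re-wrapping pattern (standalone toy code, not Spark API):

```Python
class ToyUdf(object):
    def __init__(self, func, deterministic=True):
        self.func = func
        self.deterministic = deterministic

    def as_nondeterministic(self):
        self.deterministic = False
        return self

    def wrapped(self):
        def wrapper(*args):
            return self.func(*args)
        wrapper.deterministic = self.deterministic
        # Re-wrap after mutating, so callers always hold the documented wrapper.
        wrapper.as_nondeterministic = lambda: self.as_nondeterministic().wrapped()
        return wrapper


u = ToyUdf(lambda: 42).wrapped()
v = u.as_nondeterministic()      # still a wrapper, now with deterministic == False
assert v.deterministic is False and v() == 42
```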