diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 66d993a81488..02c2350dc2d6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1826,25 +1826,38 @@ class UserDefinedFunction(object):
     def __init__(self, func, returnType, name=None):
         self.func = func
         self.returnType = returnType
-        self._judf = self._create_judf(name)
-
-    def _create_judf(self, name):
+        # Stores the UserDefinedPythonFunction jobj, once initialized
+        self._judf_placeholder = None
+        self._name = name or (
+            func.__name__ if hasattr(func, '__name__')
+            else func.__class__.__name__)
+
+    @property
+    def _judf(self):
+        # It is possible that concurrent access to a newly created UDF
+        # will initialize multiple UserDefinedPythonFunctions.
+        # This is unlikely, doesn't affect correctness,
+        # and should have a minimal performance impact.
+        if self._judf_placeholder is None:
+            self._judf_placeholder = self._create_judf()
+        return self._judf_placeholder
+
+    def _create_judf(self):
         from pyspark.sql import SparkSession
-        sc = SparkContext.getOrCreate()
-        wrapped_func = _wrap_function(sc, self.func, self.returnType)
+
         spark = SparkSession.builder.getOrCreate()
+        sc = spark.sparkContext
+
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
-        if name is None:
-            f = self.func
-            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            name, wrapped_func, jdt)
+            self._name, wrapped_func, jdt)
         return judf
 
     def __call__(self, *cols):
+        judf = self._judf
         sc = SparkContext._active_spark_context
-        jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
-        return Column(jc)
+        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
 
 
 @since(1.3)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a88e5a1cfb3c..2fea4ac41f0d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -468,6 +468,27 @@ def filename(path):
         row2 = df2.select(sameText(df2['file'])).first()
         self.assertTrue(row2[0].find("people.json") != -1)
 
+    def test_udf_defers_judf_initialization(self):
+        # This test is kept separate from UDFInitializationTests
+        # to avoid context initialization
+        # when the UDF is called
+
+        from pyspark.sql.functions import UserDefinedFunction
+
+        f = UserDefinedFunction(lambda x: x, StringType())
+
+        self.assertIsNone(
+            f._judf_placeholder,
+            "judf should not be initialized before the first call."
+        )
+
+        self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")
+
+        self.assertIsNotNone(
+            f._judf_placeholder,
+            "judf should be initialized after UDF has been called."
+        )
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
@@ -1947,6 +1968,29 @@ def test_sparksession_with_stopped_sparkcontext(self):
         df.collect()
 
 
+class UDFInitializationTests(unittest.TestCase):
+    def tearDown(self):
+        if SparkSession._instantiatedSession is not None:
+            SparkSession._instantiatedSession.stop()
+
+        if SparkContext._active_spark_context is not None:
+            SparkContext._active_spark_context.stop()
+
+    def test_udf_init_shouldnt_initialize_context(self):
+        from pyspark.sql.functions import UserDefinedFunction
+
+        UserDefinedFunction(lambda x: x, StringType())
+
+        self.assertIsNone(
+            SparkContext._active_spark_context,
+            "SparkContext shouldn't be initialized when UserDefinedFunction is created."
+        )
+        self.assertIsNone(
+            SparkSession._instantiatedSession,
+            "SparkSession shouldn't be initialized when UserDefinedFunction is created."
+        )
+
+
 class HiveContextSQLTests(ReusedPySparkTestCase):
     @classmethod
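
For reference, here is a minimal sketch of the behavior this patch introduces (a hypothetical example, not part of the change; the `to_upper` UDF, the sample DataFrame, and the `local[1]` master are illustrative, and a local PySpark installation is assumed). Constructing a `UserDefinedFunction` is now a pure-Python operation; the JVM-side `UserDefinedPythonFunction` is built and cached by the `_judf` property only when the UDF is first applied:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import StringType

# With this patch, creating the UDF does not start a SparkContext or
# touch the JVM; _judf_placeholder stays None until the first call.
to_upper = UserDefinedFunction(lambda s: s.upper(), StringType(), name="to_upper")
assert to_upper._judf_placeholder is None

# A session is needed only once the UDF is actually used; the first
# application builds and caches the JVM-side UserDefinedPythonFunction.
spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("foo",)], ["value"])
df.select(to_upper("value")).show()
assert to_upper._judf_placeholder is not None
```

Note that the lazy `_judf` property deliberately skips locking: as the in-code comment says, two threads racing on the first call may each build a `UserDefinedPythonFunction`, but the resulting objects are equivalent, so the last write simply wins.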