35 changes: 24 additions & 11 deletions python/pyspark/sql/functions.py
@@ -1826,25 +1826,38 @@ class UserDefinedFunction(object):
     def __init__(self, func, returnType, name=None):
         self.func = func
         self.returnType = returnType
-        self._judf = self._create_judf(name)
-
-    def _create_judf(self, name):
+        # Stores UserDefinedPythonFunctions jobj, once initialized
+        self._judf_placeholder = None
+        self._name = name or (
+            func.__name__ if hasattr(func, '__name__')
+            else func.__class__.__name__)
+
+    @property
+    def _judf(self):
+        # It is possible that concurrent access, to newly created UDF,
+        # will initialize multiple UserDefinedPythonFunctions.
+        # This is unlikely, doesn't affect correctness,
+        # and should have a minimal performance impact.
+        if self._judf_placeholder is None:
+            self._judf_placeholder = self._create_judf()
+        return self._judf_placeholder
+
+    def _create_judf(self):
         from pyspark.sql import SparkSession
-        sc = SparkContext.getOrCreate()
-        wrapped_func = _wrap_function(sc, self.func, self.returnType)
+
         spark = SparkSession.builder.getOrCreate()
+        sc = spark.sparkContext
+
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
-        if name is None:
-            f = self.func
-            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            name, wrapped_func, jdt)
+            self._name, wrapped_func, jdt)
         return judf
 
     def __call__(self, *cols):
+        judf = self._judf
         sc = SparkContext._active_spark_context
-        jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
-        return Column(jc)
+        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))


@since(1.3)
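To make the user-facing effect of the lazy `_judf` property concrete, here is a rough sketch (not part of the patch; the module layout and names are illustrative) of a script that defines a UDF at import time. With the old eager `_create_judf` call in `__init__`, constructing the UDF would already spin up a SparkSession; with the property above, no JVM work happens until the UDF is first used in a query:

```python
# Sketch only: assumes a local PySpark installation.
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Defining the UDF no longer touches the JVM or creates a SparkSession.
normalize = udf(lambda s: s.strip().lower(), StringType())

if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([("  Foo ",)], ["raw"])
    # The first call builds the Java-side UserDefinedPythonFunction via _create_judf.
    df.select(normalize("raw")).show()
    spark.stop()
```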
44 changes: 44 additions & 0 deletions python/pyspark/sql/tests.py
@@ -468,6 +468,27 @@ def filename(path):
         row2 = df2.select(sameText(df2['file'])).first()
         self.assertTrue(row2[0].find("people.json") != -1)
 
+    def test_udf_defers_judf_initalization(self):
Contributor:

This seems like a good test, but maybe a bit too focused on testing the implementation specifics?

Maybe it would make more sense to also have a test which verifies that creating a UDF doesn't create a SparkSession, since that is the intended purpose (we don't really care that much per se about delaying the initialization of _judf, but we do care about verifying that we don't eagerly create the SparkSession on import). What do you think?

Member Author (@zero323):

I thought about it, but I have the impression, maybe incorrect, that we avoid creating new contexts to keep total execution time manageable. If you think this justifies a separate TestCase, I am more than fine with that (SPARK-19224 and [PYSPARK] Python tests organization, right?).

If not, we could mock this and assert on the number of calls.
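For reference, a rough sketch of that mocking alternative (not part of this PR; it patches the private `_create_judf` helper, so it is at least as implementation-specific as the test above, but it needs no SparkContext):

```python
# Sketch only; relies on the private _create_judf/_judf_placeholder internals.
# Python 3: unittest.mock; on Python 2 the external mock package would be needed.
from unittest import mock

from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import StringType

f = UserDefinedFunction(lambda x: x, StringType())
with mock.patch.object(UserDefinedFunction, "_create_judf") as create_judf:
    f._judf  # first access builds (here: mocks) the Java-side UDF
    f._judf  # second access reuses the cached placeholder
assert create_judf.call_count == 1
```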

Contributor:

I think a separate test case would work and could be pretty lightweight, since it doesn't need to create a SparkContext or anything else that traditionally takes longer to set up. What do you think?

Member Author (@zero323):

@holdenk Separate case it is. As long as the implementation is correct, the overhead is negligible.

Member Author (@zero323, Jan 30, 2017):

Let's keep these tests, to make sure that _judf is initialized when necessary.

+        # This is separate of UDFInitializationTests
+        # to avoid context initialization
+        # when udf is called
+
+        from pyspark.sql.functions import UserDefinedFunction
+
+        f = UserDefinedFunction(lambda x: x, StringType())
+
+        self.assertIsNone(
+            f._judf_placeholder,
+            "judf should not be initialized before the first call."
+        )
+
+        self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")
+
+        self.assertIsNotNone(
+            f._judf_placeholder,
+            "judf should be initialized after UDF has been called."
+        )

     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
@@ -1947,6 +1968,29 @@ def test_sparksession_with_stopped_sparkcontext(self):
         df.collect()


+class UDFInitializationTests(unittest.TestCase):
Member Author (@zero323, Jan 30, 2017):

And add a separate test case checking SparkContext and SparkSession state.

+    def tearDown(self):
+        if SparkSession._instantiatedSession is not None:
+            SparkSession._instantiatedSession.stop()
+
+        if SparkContext._active_spark_context is not None:
+            SparkContext._active_spark_context.stop()
+
+    def test_udf_init_shouldnt_initalize_context(self):
+        from pyspark.sql.functions import UserDefinedFunction
+
+        UserDefinedFunction(lambda x: x, StringType())
+
+        self.assertIsNone(
+            SparkContext._active_spark_context,
+            "SparkContext shouldn't be initialized when UserDefinedFunction is created."
+        )
+        self.assertIsNone(
+            SparkSession._instantiatedSession,
+            "SparkSession shouldn't be initialized when UserDefinedFunction is created."
+        )


 class HiveContextSQLTests(ReusedPySparkTestCase):
 
     @classmethod
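For what it's worth, a minimal way to exercise just the new lightweight test case locally might look like the following sketch (it assumes `$SPARK_HOME/python` and the bundled py4j zip are on `PYTHONPATH` so that `pyspark.sql.tests` is importable):

```python
import unittest

# UDFInitializationTests needs no pre-existing SparkContext, so this run stays cheap.
from pyspark.sql.tests import UDFInitializationTests

suite = unittest.TestLoader().loadTestsFromTestCase(UDFInitializationTests)
unittest.TextTestRunner(verbosity=2).run(suite)
```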