support chained Python UDFs
Davies Liu committed Mar 28, 2016
commit 024a82236f09a6b6ec09aeb41ca161e6c7c72dda
9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
@@ -305,6 +305,15 @@ def test_udf2(self):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_chained_python_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
[row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
self.assertEqual(row[0], 4)
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)

def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
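The tests above exercise chaining through SQL. For readers trying this out, the same chaining should also work through the DataFrame API; the following is a minimal, self-contained sketch (the local SparkContext setup and the column name "a" are illustrative, not part of this PR):

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

sc = SparkContext("local", "chained-udf-sketch")
sqlCtx = SQLContext(sc)

# Same UDF as in the test, made usable from the DataFrame API via functions.udf.
double = udf(lambda x: x + x, IntegerType())

df = sqlCtx.createDataFrame([(1,)], ["a"])
# Both calls are Python UDFs, so with this change the planner can collect them
# into a single chained evaluation on the Python worker side.
print(df.select(double(double(df["a"])).alias("d")).collect())  # [Row(d=4)]
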
8 changes: 6 additions & 2 deletions python/pyspark/worker.py
@@ -57,6 +57,11 @@ def read_command(serializer, file):
return command


def chain(f, g):
"""chain two function together """
return lambda x: g(f(x))


def main(infile, outfile):
try:
boot_time = time.time()
@@ -112,8 +117,7 @@ def main(infile, outfile):
if row_func is None:
row_func = f
else:
# chain multiple UDF together
row_func = lambda x: f(row_func(x))
row_func = chain(row_func, f)
serializer = deserializer
func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
else:
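The chain helper replaces the inline lambda row_func = lambda x: f(row_func(x)) shown in the removed line. In Python that lambda closes over the name row_func, which the assignment immediately rebinds to the lambda itself, so calling it recurses indefinitely (and f would likewise be rebound on later loop iterations). Passing both functions into chain fixes their bindings at composition time. A small standalone illustration, with made-up functions and no Spark dependency:

def chain(f, g):
    """Apply f first, then g; same shape as the helper in worker.py."""
    return lambda x: g(f(x))

inc = lambda x: x + 1
dbl = lambda x: x + x

# Broken pattern: the closure looks up the *name* row_func at call time,
# which by then refers to this very lambda, so it never terminates.
row_func = inc
row_func = lambda x: dbl(row_func(x))
# row_func(1)  # RecursionError

# Fixed pattern: the bindings are captured as arguments, so composition is stable.
row_func = inc
row_func = chain(row_func, dbl)
print(row_func(1))  # dbl(inc(1)) == 4
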
BatchPythonEvaluation.scala
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.PythonRunner
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:

def children: Seq[SparkPlan] = child :: Nil

private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (fs, children) = collectFunctions(u)
(fs ++ Seq(udf.func), children)
case children =>
// There should not be any other UDFs; otherwise the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(Seq(udf.func), udf.children)
}
}

protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = collectFunctions(udf)

val pickle = new Pickler
val currentRow = newMutableProjection(udf.children, child.output)()
val fields = udf.children.map(_.dataType)
val currentRow = newMutableProjection(children, child.output)()
val fields = children.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
Expand All @@ -75,7 +89,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(Seq(udf.func), bufferSize, reuseWorker, true)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
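collectFunctions above walks a chain of single-child Python UDFs from the outermost call inward and returns the stacked functions (innermost first) together with the leaf, non-UDF inputs, which is what lets PythonRunner receive the whole chain at once. A rough Python analogue of that traversal on toy nodes (PyUDF and Column here are stand-ins for illustration, not Spark classes):

class Column:
    """Toy leaf expression (stand-in for a Catalyst attribute)."""
    def __init__(self, name): self.name = name
    def __repr__(self): return self.name

class PyUDF:
    """Toy Python UDF node: a function name plus child expressions."""
    def __init__(self, func, children): self.func, self.children = func, children

def collect_functions(udf):
    """Mirror of collectFunctions: flatten a chain of single-child UDFs."""
    if len(udf.children) == 1 and isinstance(udf.children[0], PyUDF):
        funcs, leaves = collect_functions(udf.children[0])
        return funcs + [udf.func], leaves   # innermost function comes first
    # Base case: none of the children contain another Python UDF.
    return [udf.func], udf.children

# double(double(a)) flattens to (['inner double', 'outer double'], [a])
a = Column("a")
chained = PyUDF("outer double", [PyUDF("inner double", [a])])
print(collect_functions(chained))
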
ExtractPythonUDFs.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.python

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -29,13 +30,31 @@ import org.apache.spark.sql.catalyst.rules.Rule
* multiple child operators.
*/
private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {

private def hasUDF(e: Expression): Boolean = {
e.find(_.isInstanceOf[PythonUDF]).isDefined
}

private def canEvaluate(e: PythonUDF): Boolean = {
e.children match {
case Seq(u: PythonUDF) => canEvaluate(u)
case children => !children.exists(hasUDF)
}
}

private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
expr.collect {
case udf: PythonUDF if canEvaluate(udf) => udf
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan

case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
Review comment (Contributor): We should update the comments to explain our new strategy of extracting and evaluating Python UDFs.

Reply (Author): Updated.

val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
if (udfs.isEmpty) {
// If there aren't any, we are done.
plan
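For intuition about canEvaluate: a UDF qualifies when its children are either a single nested Python UDF (a pure chain) or contain no Python UDF at all; a UDF buried inside some other expression under the outer UDF disqualifies it, so the outer UDF is left for a later pass (the double(double(1) + 1) test above still passes, presumably via two separate extractions rather than one chained evaluation). A hedged Python rendering of the predicate on toy nodes, not Spark's classes:

class Expr:
    """Toy expression node with arbitrary children."""
    def __init__(self, *children): self.children = list(children)

class PyUDF(Expr):
    """Toy Python UDF node."""

def has_udf(e):
    """Analogue of e.find(_.isInstanceOf[PythonUDF]).isDefined."""
    return isinstance(e, PyUDF) or any(has_udf(c) for c in e.children)

def can_evaluate(udf):
    """Mirror of canEvaluate: a pure UDF chain, or children free of UDFs."""
    if len(udf.children) == 1 and isinstance(udf.children[0], PyUDF):
        return can_evaluate(udf.children[0])
    return not any(has_udf(c) for c in udf.children)

col = Expr()                        # plain column reference, no UDF inside
mixed = Expr(PyUDF(col), Expr())    # e.g. double(a) + 1

print(can_evaluate(PyUDF(PyUDF(col))))  # True:  double(double(a)) is a pure chain
print(can_evaluate(PyUDF(mixed)))       # False: double(double(a) + 1) mixes UDF and non-UDF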