diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index d417303bb147..557eefd4ad70 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -35,8 +35,9 @@ import org.apache.spark.util._
 private[spark] object PythonEvalType {
   val NON_UDF = 0
   val SQL_BATCHED_UDF = 1
-  val SQL_PANDAS_UDF = 2
-  val SQL_PANDAS_GROUPED_UDF = 3
+  val SQL_BATCHED_OPT_UDF = 2
+  val SQL_PANDAS_UDF = 3
+  val SQL_PANDAS_GROUPED_UDF = 4
 }

 /**
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a0adeed99445..c26323fde346 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -85,8 +85,9 @@ class SpecialLengths(object):
 class PythonEvalType(object):
     NON_UDF = 0
     SQL_BATCHED_UDF = 1
-    SQL_PANDAS_UDF = 2
-    SQL_PANDAS_GROUPED_UDF = 3
+    SQL_BATCHED_OPT_UDF = 2
+    SQL_PANDAS_UDF = 3
+    SQL_PANDAS_GROUPED_UDF = 4


 class Serializer(object):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 685eebcafefb..2be59f0f6783 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -350,6 +350,38 @@ def some_func(col, param):
         res = data.select(pudf(data['number']).alias('plus_four'))
         self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

+    def test_udf_with_conditional_expr_when(self):
+        from pyspark.sql.functions import col, udf, when
+
+        df = self.sc.parallelize([Row(x=5), Row(x=0)]).toDF()
+        f = udf(lambda value: 10 // int(value), IntegerType())
+        whenExpr1 = when((col('x') > 0), f(col('x')))
+
+        results1 = df.select(whenExpr1).collect()
+        self.assertEqual(results1[0][0], 2)
+        self.assertEqual(results1[1][0], None)
+
+        whenExpr2 = when((col('x') <= 0), None).otherwise(f(col('x')))
+
+        results2 = df.select(whenExpr2).collect()
+        self.assertEqual(results2[0][0], 2)
+        self.assertEqual(results2[1][0], None)
+
+    def test_udf_with_conditional_expr_if(self):
+        self.spark.createDataFrame(self.sc.parallelize([Row(a=0), Row(a=2)]))\
+            .createOrReplaceTempView("test")
+
+        self.spark.catalog.registerFunction("divideByVal",
+                                            lambda value: 10 // int(value), IntegerType())
+
+        results1 = self.spark.sql("SELECT if(a > 0, divideByVal(a), 0) FROM test").collect()
+        self.assertEqual(results1[0][0], 0)
+        self.assertEqual(results1[1][0], 5)
+
+        results2 = self.spark.sql("SELECT if(a <= 0, 0, divideByVal(a)) FROM test").collect()
+        self.assertEqual(results2[0][0], 0)
+        self.assertEqual(results2[1][0], 5)
+
     def test_udf(self):
         self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
         [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 5e100e0a9a95..b90f87cb3a96 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -105,8 +105,14 @@ def read_single_udf(pickleSer, infile, eval_type):
     elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF:
         # a groupby apply udf has already been wrapped under apply()
         return arg_offsets, row_func
-    else:
+    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
         return arg_offsets, wrap_udf(row_func, return_type)
+    elif eval_type == PythonEvalType.SQL_BATCHED_OPT_UDF:
+        udf = wrap_udf(row_func, return_type)
+        opt_udf = lambda *a: udf(*a[:-1]) if a[-1] is True else None
+        return arg_offsets, opt_udf
+    else:
+        raise Exception(("Unknown python evaluation type: %d") % (eval_type))


 def read_udfs(pickleSer, infile, eval_type):
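The worker.py hunk above is the heart of the Python-side change: a SQL_BATCHED_OPT_UDF receives one extra trailing argument carrying the evaluated condition, and the wrapped UDF runs only when that flag is True, returning None (null) otherwise. A minimal standalone sketch of that wrapping, outside Spark (make_opt_udf is an illustrative name, not part of the patch):

    # Mirrors the opt_udf lambda in read_single_udf: the last argument is the
    # evaluated condition; the real UDF is called only when it is True.
    def make_opt_udf(udf):
        return lambda *a: udf(*a[:-1]) if a[-1] is True else None

    div = make_opt_udf(lambda x: 10 // int(x))
    print(div(5, True))    # 2
    print(div(0, False))   # None (the guarded UDF is never called, so no ZeroDivisionError)
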
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 26ee25f633ea..a8f1b7b487d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -28,11 +28,12 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructField, StructType}

-/**
- * A physical plan that evaluates a [[PythonUDF]]
- */
-case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
-  extends EvalPythonExec(udfs, output, child) {
+abstract class BatchEvalPythonExecBase(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: SparkPlan) extends EvalPythonExec(udfs, output, child) {
+
+  protected val evalType: Int

   protected override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
@@ -69,7 +70,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi

     // Output iterator for results from Python.
     val outputIterator = new PythonUDFRunner(
-        funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
+        funcs, bufferSize, reuseWorker, evalType, argOffsets)
       .compute(inputIterator, context.partitionId(), context)

     val unpickle = new Unpickler
@@ -93,3 +94,11 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
     }
   }
 }
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]]
+ */
+case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
+  extends BatchEvalPythonExecBase(udfs, output, child) {
+  protected override val evalType: Int = PythonEvalType.SQL_BATCHED_UDF
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchOptEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchOptEvalPythonExec.scala
new file mode 100644
index 000000000000..21fdac50ad97
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchOptEvalPythonExec.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]]. Different to [[BatchEvalPythonExec]], this plan
+ * overrides how argument offsets are computed, appending the offset of a UDF's conditional
+ * expression, if any, to that UDF's argument offsets. On the Python side, the UDF is then run
+ * only when the conditional expression evaluates to true.
+ */
+case class BatchOptEvalPythonExec(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: SparkPlan,
+    udfConditionsMap: Map[PythonUDF, Seq[Expression]])
+  extends BatchEvalPythonExecBase(udfs, output, child) {
+
+  protected override val evalType: Int = PythonEvalType.SQL_BATCHED_OPT_UDF
+
+  protected override def computeArgOffsets(
+      inputs: Seq[Seq[Expression]],
+      allInputs: ArrayBuffer[Expression],
+      dataTypes: ArrayBuffer[DataType]): Array[Array[Int]] = {
+    inputs.zipWithIndex.map { case (input, idx) =>
+      var funcArgs = input.map(mapExpressionIntoFuncInputs(_, allInputs, dataTypes)).toArray
+      udfConditionsMap.get(udfs(idx)).foreach { conditions =>
+        conditions.reduceOption(Or).foreach { cond =>
+          val condArgOffset = mapExpressionIntoFuncInputs(cond, allInputs, dataTypes)
+          funcArgs = funcArgs :+ condArgOffset
+        }
+      }
+      funcArgs
+    }.toArray
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index 860dc78c1dd1..bef99f75fe7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -85,6 +85,36 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
       schema: StructType,
       context: TaskContext): Iterator[InternalRow]

+  private def preparePyFuncsAndArgOffsets(
+      allInputs: ArrayBuffer[Expression],
+      dataTypes: ArrayBuffer[DataType]): (Seq[ChainedPythonFunctions], Array[Array[Int]]) = {
+    val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+    val argOffsets = computeArgOffsets(inputs, allInputs, dataTypes)
+    (pyFuncs, argOffsets)
+  }
+
+  protected def mapExpressionIntoFuncInputs(
+      expr: Expression,
+      allInputs: ArrayBuffer[Expression],
+      dataTypes: ArrayBuffer[DataType]): Int = {
+    if (allInputs.exists(_.semanticEquals(expr))) {
+      allInputs.indexWhere(_.semanticEquals(expr))
+    } else {
+      allInputs += expr
+      dataTypes += expr.dataType
+      allInputs.length - 1
+    }
+  }
+
+  protected def computeArgOffsets(
+      inputs: Seq[Seq[Expression]],
+      allInputs: ArrayBuffer[Expression],
+      dataTypes: ArrayBuffer[DataType]): Array[Array[Int]] = {
+    inputs.map { input =>
+      input.map(mapExpressionIntoFuncInputs(_, allInputs, dataTypes)).toArray
+    }.toArray
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute().map(_.copy())
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -101,22 +131,10 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
       queue.close()
     }

-    val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-
-    // flatten all the arguments
     val allInputs = new ArrayBuffer[Expression]
     val dataTypes = new ArrayBuffer[DataType]
-    val argOffsets = inputs.map { input =>
-      input.map { e =>
-        if (allInputs.exists(_.semanticEquals(e))) {
-          allInputs.indexWhere(_.semanticEquals(e))
-        } else {
-          allInputs += e
-          dataTypes += e.dataType
-          allInputs.length - 1
-        }
-      }.toArray
-    }.toArray
+    val (pyFuncs, argOffsets) = preparePyFuncsAndArgOffsets(allInputs, dataTypes)
+
     val projection = newMutableProjection(allInputs, child.output)
     val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
       StructField(s"_$i", dt)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index d6825369f737..7d0cd66f3db3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -116,6 +116,77 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     case plan: SparkPlan => extract(plan)
   }

+  private def pickUDFIntoMap(
+      expr: Expression,
+      condition: Expression,
+      exprMap: mutable.HashMap[PythonUDF, Seq[Expression]]): Unit = {
+    expr.foreachUp {
+      case udf: PythonUDF => exprMap.update(udf, exprMap.getOrElse(udf, Seq()) :+ condition)
+      case _ =>
+    }
+  }
+
+  private def updateConditionForUDFInBranch(
+      branches: Seq[(Expression, Expression)],
+      exprMap: mutable.HashMap[PythonUDF, Seq[Expression]]): Unit = {
+    branches.foreach(branch => pickUDFIntoMap(branch._2, branch._1, exprMap))
+  }
+
+  private def updateConditionForUDFInElse(
+      branches: Seq[(Expression, Expression)],
+      elseValue: Expression,
+      exprMap: mutable.HashMap[PythonUDF, Seq[Expression]]): Unit = {
+    assert(branches.length > 0)
+
+    val elseCond = branches.map(_._1).reduce(Or)
+    pickUDFIntoMap(elseValue, Not(elseCond), exprMap)
+  }
+
+  private def updateConditionForUDFInCaseWhen(
+      branches: Seq[(Expression, Expression)],
+      elseValue: Option[Expression],
+      exprMap: mutable.HashMap[PythonUDF, Seq[Expression]]): Unit = {
+    updateConditionForUDFInBranch(branches, exprMap)
+    elseValue.foreach { elseExpr =>
+      updateConditionForUDFInElse(branches, elseExpr, exprMap)
+    }
+  }
+
+  /**
+   * Extracts the conditions associated with PythonUDFs.
+   * Not all PythonUDFs need to be evaluated. For example, for a case when expression like
+   * `when(x > 1, pyUDF(x)).when(x > 2, pyUDF2(x))`, we don't need to evaluate both PythonUDFs
+   * for every row. Besides the performance impact, in some cases eagerly evaluating all
+   * PythonUDFs can cause failures, e.g., a PythonUDF that divides by an expression and must
+   * only run when that expression is greater than zero.
+   *
+   * Returns a map in which the value for a PythonUDF key is the sequence of boolean expressions
+   * that must hold for the PythonUDF to be evaluated.
+   */
+  private def extractConditionForUDF(
+      expressions: Seq[Expression],
+      udfs: Seq[PythonUDF]): mutable.HashMap[PythonUDF, Seq[Expression]] = {
+    val conditionMap = mutable.HashMap[PythonUDF, Seq[Expression]]()
+    expressions.map { expr =>
+      expr.foreachUp {
+        case e @ CaseWhenCodegen(branches, elseValue)
+            if branches.exists(x => hasPythonUDF(x._2)) ||
+              elseValue.map(hasPythonUDF).getOrElse(false) =>
+          updateConditionForUDFInCaseWhen(branches, elseValue, conditionMap)
+        case e @ CaseWhen(branches, elseValue)
+            if branches.exists(x => hasPythonUDF(x._2)) ||
+              elseValue.map(hasPythonUDF).getOrElse(false) =>
+          updateConditionForUDFInCaseWhen(branches, elseValue, conditionMap)
+        case If(predicate, trueValue, falseValue)
+            if hasPythonUDF(trueValue) || hasPythonUDF(falseValue) =>
+          pickUDFIntoMap(trueValue, predicate, conditionMap)
+          pickUDFIntoMap(falseValue, Not(predicate), conditionMap)
+        case _ =>
+      }
+    }
+    conditionMap
+  }
+
   /**
    * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
@@ -127,6 +198,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
       // If there aren't any, we are done.
       plan
     } else {
+      val udfConditionMap = extractConditionForUDF(plan.expressions, udfs)
+
       val attributeMap = mutable.HashMap[PythonUDF, Expression]()
       val splitFilter = trySplitFilter(plan)
       // Rewrite the child that has the input required for the UDF
@@ -136,6 +209,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
         // Check to make sure that the UDF can be evaluated with only the input of this child.
         udf.references.subsetOf(child.outputSet)
       }
+      // Whether any UDFs to evaluate are used with conditional expressions.
+      val foundConditionalUdfs = validUdfs.exists(udfConditionMap.contains(_))
+
       if (validUdfs.nonEmpty) {
         if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) {
           throw new IllegalArgumentException("Can not use grouped vectorized UDFs")
@@ -148,8 +224,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
         val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match {
           case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
             ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
-          case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
+          case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty && !foundConditionalUdfs =>
             BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
+          case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
+            BatchOptEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child,
+              udfConditionMap.toMap)
           case _ =>
             throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
         }
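
End to end, the planner now routes any plain Python UDF that appears under a CASE WHEN/when() or if() expression through BatchOptEvalPythonExec, so the guard is shipped to the worker alongside the UDF arguments and the UDF body is skipped for rows that fail it. A hedged usage sketch of the behaviour the new tests assert (the local SparkSession setup and the column name x are illustrative assumptions, not part of the patch):

    from pyspark.sql import Row, SparkSession
    from pyspark.sql.functions import col, udf, when
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([Row(x=5), Row(x=0)])

    # Without the conditional routing this UDF would be evaluated for every row and
    # raise ZeroDivisionError on x == 0; with it, the guarded row simply yields null.
    f = udf(lambda value: 10 // int(value), IntegerType())
    print(df.select(when(col('x') > 0, f(col('x'))).alias('r')).collect())
    # expected, per test_udf_with_conditional_expr_when: [Row(r=2), Row(r=None)]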