Move ExtractPythonUDFs to end of optimize stage
icexelloss committed Aug 27, 2018
commit d440cbffee135da42a54da95388b01cf17ab16df
20 changes: 9 additions & 11 deletions python/pyspark/sql/tests.py
@@ -3388,15 +3388,14 @@ def test_datasource_with_udf_filter_lit_input(self):
datasource_df = self.spark.read \
.format("org.apache.spark.sql.sources.SimpleScanSource") \
.option('from', 0).option('to', 1).load()
# TODO: Enable data source v2 after SPARK-25213 is fixed
# datasource_v2_df = self.spark.read \
# .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
# .load()
datasource_v2_df = self.spark.read \
.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
.load()

filter1 = udf(lambda: False, 'boolean')()
filter2 = udf(lambda x: False, 'boolean')(lit(1))

for df in [filesource_df, datasource_df]:
for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [filter1, filter2]:
result = df.filter(f)
self.assertEquals(0, result.count())
@@ -5309,7 +5308,7 @@ def f3(x):
# SPARK-24721
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
def test_datasource_with_udf_filter_lit_input(self):
# Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pandas UDF
# This needs to be a separate test because the Arrow dependency is optional
import pandas as pd
import numpy as np
@@ -5323,14 +5322,13 @@ def test_datasource_with_udf_filter_lit_input(self):
datasource_df = self.spark.read \
.format("org.apache.spark.sql.sources.SimpleScanSource") \
.option('from', 0).option('to', 1).load()
# TODO: Enable data source v2 after SPARK-25213 is fixed
# datasource_v2_df = self.spark.read \
# .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
# .load()
datasource_v2_df = self.spark.read \
.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
.load()

f = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))

for df in [filesource_df, datasource_df]:
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.filter(f)
self.assertEquals(0, result.count())

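For readers without Spark's compiled test classes on the classpath, here is a minimal, self-contained sketch of the behavior these tests exercise; `spark.range` stands in for the SimpleScanSource / SimpleDataSourceV2 test sources, and the session settings are illustrative only.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, lit

spark = SparkSession.builder.master("local[1]").appName("udf-filter-sketch").getOrCreate()

# A Python UDF used purely as a filter condition, mirroring filter1/filter2 above.
always_false = udf(lambda x: False, 'boolean')(lit(1))

# With ExtractPythonUDFs running at the end of the optimizer, the UDF condition is
# evaluated in its own Python eval node rather than being offered to the data source
# for filter pushdown, which is why the data source v2 reader can rejoin the loop above.
assert spark.range(10).filter(always_false).count() == 0

spark.stop()
```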
QueryExecution.scala
@@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {

/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
SparkOptimizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}

class SparkOptimizer(
catalog: SessionCatalog,
@@ -31,7 +31,8 @@ class SparkOptimizer(

override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
Batch("Extract Python UDFs", Once,
Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
Member:
It looks weird to add this rule in our optimizer batch. We need at least some comments to explain the reason in the code.

Contributor:
but we already have ExtractPythonUDFFromAggregate here...

Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++
postHocOptimizationBatches :+
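A hedged illustration of what the new batch placement means in practice (not part of this PR; exact explain() output differs across Spark versions): because extraction now happens in the optimizer, the Python eval operator should already appear in the optimized logical plan, and the PythonEvals strategy added further below maps it to the physical node.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col

spark = SparkSession.builder.master("local[1]").getOrCreate()

plus_one = udf(lambda x: x + 1, 'long')

df = spark.range(5).select(plus_one(col('id')).alias('id_plus_one'))

# extended=True prints the parsed/analyzed/optimized logical plans and the physical plan.
# Expect a BatchEvalPython node already in the optimized logical plan (the effect of this
# change) and the corresponding BatchEvalPython exec node in the physical plan.
df.explain(extended=True)

spark.stop()
```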
SparkPlanner.scala
@@ -36,6 +36,7 @@ class SparkPlanner(
override def strategies: Seq[Strategy] =
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
PythonEvals ::
DataSourceV2Strategy ::
FileSourceStrategy ::
DataSourceStrategy(conf) ::
SparkStrategies.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
import org.apache.spark.sql.internal.SQLConf
@@ -517,6 +518,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

/**
* Strategy to convert EvalPython logical operator to physical operator.
*/
object PythonEvals extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ArrowEvalPython(udfs, output, child) =>
ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil
case BatchEvalPython(udfs, output, child) =>
BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
case _ =>
Nil
}
}

object BasicOperators extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
ArrowEvalPythonExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.StructType
@@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
}

/**
* A physical plan that evaluates a [[PythonUDF]],
* A logical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode

/**
* A physical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, output, child) {
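For the Arrow path specifically, a hedged sketch (requires pandas and pyarrow installed; not taken from this PR): a scalar pandas UDF should be pulled into the new ArrowEvalPython logical node and planned as ArrowEvalPythonExec.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Scalar pandas UDF (the default pandas UDF type): receives and returns a pandas Series per batch.
pandas_plus_one = pandas_udf(lambda s: s + 1, 'long')

df = spark.range(5).select(pandas_plus_one(col('id')).alias('id_plus_one'))

# Expect an ArrowEvalPython node in the optimized logical plan and the corresponding
# Arrow-based exec node in the physical plan; exact plan text varies by version.
df.explain(extended=True)
assert df.count() == 5

spark.stop()
```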
BatchEvalPythonExec.scala
@@ -25,9 +25,16 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

/**
* A logical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode

/**
* A physical plan that evaluates a [[PythonUDF]]
*/
ExtractPythonUDFs.scala
@@ -24,10 +24,8 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec


/**
@@ -94,7 +92,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
* This has the limitation that the input to the Python UDF is not allowed to include attributes from
* multiple child operators.
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {

private type EvalType = Int
private type EvalTypeChecker = EvalType => Boolean
@@ -133,17 +131,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
expressions.flatMap(collectEvaluableUDFs)
}

def apply(plan: SparkPlan): SparkPlan = plan transformUp {
// SPARK-24721: Ignore Python UDFs in DataSourceScan and DataSourceV2Scan
case plan: DataSourceScanExec => plan
case plan: DataSourceV2ScanExec => plan
case plan: SparkPlan => extract(plan)
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case plan: LogicalPlan => extract(plan)
}

/**
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
private def extract(plan: SparkPlan): SparkPlan = {
private def extract(plan: LogicalPlan): LogicalPlan = {
val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
// ignore the PythonUDFs that come from the second/third aggregate, which are not used
.filter(udf => udf.references.subsetOf(plan.inputSet))
@@ -155,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val prunedChildren = plan.children.map { child =>
val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
if (allNeededOutput.length != child.output.length) {
ProjectExec(allNeededOutput, child)
Project(allNeededOutput, child)
} else {
child
}
@@ -184,9 +179,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
_.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
) match {
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
case _ =>
throw new AnalysisException(
"Expected either Scalar Pandas UDFs or Batched UDFs but got both")
@@ -213,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val newPlan = extract(rewritten)
if (newPlan.output != plan.output) {
// Trim away the new UDF value if it was only used for filtering or something.
ProjectExec(plan.output, newPlan)
Project(plan.output, newPlan)
} else {
newPlan
}
@@ -222,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

// Split the original Filter into two Filters. Only push down the first few predicates
// that are all deterministic.
private def trySplitFilter(plan: SparkPlan): SparkPlan = {
private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
plan match {
case filter: FilterExec =>
case filter: Filter =>
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
val newChild = Filter(pushDown.reduceLeft(And), filter.child)
Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
} else {
filter
}
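A hedged sketch of the filter-splitting behavior in trySplitFilter (assumes a local session; the plan shape may differ across versions): the deterministic, UDF-free part of a conjunctive filter can be split into its own Filter and still pushed down, while the Python UDF predicate is evaluated afterwards.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col

spark = SparkSession.builder.master("local[1]").getOrCreate()

is_even = udf(lambda x: x % 2 == 0, 'boolean')

# One conjunctive filter mixing a plain predicate with a Python UDF predicate.
df = spark.range(10).filter((col('id') < 5) & is_even(col('id')))

# trySplitFilter should keep `id < 5` in a lower Filter (eligible for pushdown), with the
# UDF predicate applied after the Python eval node.
df.explain(extended=True)
assert df.count() == 3  # ids 0, 2 and 4

spark.stop()
```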