Commits in this pull request (28); the diff below shows the changes from a single commit.
2a1cd27
optimization + test
bogdanrdc Aug 23, 2018
421ee20
debug benchmark + early batch
bogdanrdc Aug 23, 2018
d7e49e7
revert benchmark
bogdanrdc Aug 23, 2018
ba6d91e
Merge remote-tracking branch 'upstream/master' into local-relation-fi…
bogdanrdc Aug 24, 2018
326e5d7
test fix
bogdanrdc Aug 24, 2018
4263bd2
[SPARK-25073][YARN] AM and Executor Memory validation message is not …
sujith71955 Aug 24, 2018
f84d256
[SPARK-25214][SS] Fix the issue that Kafka v2 source may return dupli…
zsxwing Aug 24, 2018
c721895
[SPARK-25174][YARN] Limit the size of diagnostic message for am to un…
yaooqinn Aug 24, 2018
f8536e3
[SPARK-25234][SPARKR] avoid integer overflow in parallelize
mengxr Aug 24, 2018
af6a91e
Correct missing punctuation in the documentation
Aug 25, 2018
c613c6b
[MINOR] Fix Scala 2.12 build
dbtsai Aug 25, 2018
ee1c0e8
[SPARK-24688][EXAMPLES] Modify the comments about LabeledPoint
huangweizhe123 Aug 25, 2018
77fb55e
[SPARK-25214][SS][FOLLOWUP] Fix the issue that Kafka v2 source may re…
zsxwing Aug 25, 2018
b00824c
[SPARK-23792][DOCS] Documentation improvements for datetime functions
abradbury Aug 26, 2018
ee6cb6c
[SPARK-23698][PYTHON][FOLLOWUP] Resolve undefiend names in setup.py
HyukjinKwon Aug 27, 2018
c129176
[SPARK-19355][SQL][FOLLOWUP] Remove the child.outputOrdering check in…
viirya Aug 27, 2018
368b42f
[SPARK-24978][SQL] Add spark.sql.fast.hash.aggregate.row.max.capacity…
heary-cao Aug 27, 2018
0378b1f
[SPARK-25249][CORE][TEST] add a unit test for OpenHashMap
10110346 Aug 27, 2018
d5a953a
[SPARK-24882][FOLLOWUP] Fix flaky synchronization in Kafka tests.
jose-torres Aug 27, 2018
3598483
[SPARK-24149][YARN][FOLLOW-UP] Only get the delegation tokens of the …
wangyum Aug 27, 2018
397fa62
[SPARK-24090][K8S] Update running-on-kubernetes.md
liyinan926 Aug 27, 2018
dcd001b
[SPARK-24721][SQL] Exclude Python UDFs filters in FileSourceStrategy
icexelloss Aug 28, 2018
b23538b
[SPARK-25218][CORE] Fix potential resource leaks in TransportServer a…
zsxwing Aug 28, 2018
f769a94
[SPARK-25005][SS] Support non-consecutive offsets for Kafka
zsxwing Aug 28, 2018
68c41ff
comment
bogdanrdc Aug 28, 2018
dad6a7f
Merge remote-tracking branch 'upstream/master' into local-relation-fi…
bogdanrdc Aug 28, 2018
cb067c3
Merge remote-tracking branch 'upstream/master' into local-relation-fi…
bogdanrdc Aug 28, 2018
d552cc1
space
bogdanrdc Aug 28, 2018
Changes from commit dcd001b:
[SPARK-24721][SQL] Exclude Python UDFs filters in FileSourceStrategy
## What changes were proposed in this pull request?
This PR excludes Python UDF filters in FileSourceStrategy so that they don't cause the ExtractPythonUDFs rule to throw an exception. It doesn't make sense to pass Python UDF filters to FileSourceStrategy anyway, because they cannot be used as push-down filters.
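For illustration, a minimal sketch of the user-facing pattern this change addresses, adapted from the regression test added below (the output path is an assumption):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col

spark = SparkSession.builder.getOrCreate()

# Write a small CSV-backed table, then read it back through the file source path.
spark.range(3).write.mode("overwrite").format("csv").save("/tmp/spark-24721-demo")
df = spark.read.option("inferSchema", True).csv("/tmp/spark-24721-demo").toDF("i")

# A Python UDF used as a filter predicate cannot be pushed down to the file
# source, so FileSourceStrategy should simply never receive it.
is_even = udf(lambda x: x % 2 == 0, "boolean")
df.filter(is_even(col("i"))).show()
```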

## How was this patch tested?
Added a new regression test.

Closes #22104 from icexelloss/SPARK-24721-udf-filter.

Authored-by: Li Jin <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
icexelloss authored and bogdanrdc committed Aug 28, 2018
commit dcd001b74d50830594694df73b5502c3e17647bb
94 changes: 94 additions & 0 deletions python/pyspark/sql/tests.py
@@ -68,8 +68,16 @@
# If Arrow version requirement is not satisfied, skip related tests.
_pyarrow_requirement_message = _exception_message(e)

_test_not_compiled_message = None
try:
from pyspark.sql.utils import require_test_compiled
require_test_compiled()
except Exception as e:
_test_not_compiled_message = _exception_message(e)

_have_pandas = _pandas_requirement_message is None
_have_pyarrow = _pyarrow_requirement_message is None
_test_compiled = _test_not_compiled_message is None

from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
@@ -3367,6 +3375,47 @@ def test_ignore_column_of_all_nulls(self):
finally:
shutil.rmtree(path)

# SPARK-24721
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
def test_datasource_with_udf(self):
from pyspark.sql.functions import udf, lit, col

path = tempfile.mkdtemp()
shutil.rmtree(path)

try:
self.spark.range(1).write.mode("overwrite").format('csv').save(path)
filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
datasource_df = self.spark.read \
.format("org.apache.spark.sql.sources.SimpleScanSource") \
.option('from', 0).option('to', 1).load().toDF('i')
datasource_v2_df = self.spark.read \
.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
.load().toDF('i', 'j')

c1 = udf(lambda x: x + 1, 'int')(lit(1))
c2 = udf(lambda x: x + 1, 'int')(col('i'))

f1 = udf(lambda x: False, 'boolean')(lit(1))
f2 = udf(lambda x: False, 'boolean')(col('i'))

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c1)
expected = df.withColumn('c', lit(2))
self.assertEquals(expected.collect(), result.collect())

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c2)
expected = df.withColumn('c', col('i') + 1)
self.assertEquals(expected.collect(), result.collect())

for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
result = df.filter(f)
self.assertEquals(0, result.count())
finally:
shutil.rmtree(path)

def test_repr_behaviors(self):
import re
pattern = re.compile(r'^ *\|', re.MULTILINE)
@@ -5269,6 +5318,51 @@ def f3(x):

self.assertEquals(expected.collect(), df1.collect())

# SPARK-24721
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
def test_datasource_with_udf(self):
# Same as SQLTests.test_datasource_with_udf, but with Pandas UDF
# This needs to be a separate test because the Arrow dependency is optional
import pandas as pd
import numpy as np
from pyspark.sql.functions import pandas_udf, lit, col

path = tempfile.mkdtemp()
shutil.rmtree(path)

try:
self.spark.range(1).write.mode("overwrite").format('csv').save(path)
filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
datasource_df = self.spark.read \
.format("org.apache.spark.sql.sources.SimpleScanSource") \
.option('from', 0).option('to', 1).load().toDF('i')
datasource_v2_df = self.spark.read \
.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
.load().toDF('i', 'j')

c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1))
c2 = pandas_udf(lambda x: x + 1, 'int')(col('i'))

f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i'))

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c1)
expected = df.withColumn('c', lit(2))
self.assertEquals(expected.collect(), result.collect())

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c2)
expected = df.withColumn('c', col('i') + 1)
self.assertEquals(expected.collect(), result.collect())

for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
result = df.filter(f)
self.assertEquals(0, result.count())
finally:
shutil.rmtree(path)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
19 changes: 19 additions & 0 deletions python/pyspark/sql/utils.py
@@ -152,6 +152,25 @@ def require_minimum_pyarrow_version():
"your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))


def require_test_compiled():
""" Raise Exception if test classes are not compiled
"""
import os
import glob
try:
spark_home = os.environ['SPARK_HOME']
except KeyError:
raise RuntimeError('SPARK_HOME is not defined in environment')

test_class_path = os.path.join(
spark_home, 'sql', 'core', 'target', '*', 'test-classes')
paths = glob.glob(test_class_path)

if len(paths) == 0:
raise RuntimeError(
"%s doesn't exist. Spark sql test classes are not compiled." % test_class_path)


class ForeachBatchFunction(object):
"""
This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
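As a usage note, require_test_compiled is meant as a guard for tests that depend on compiled Spark SQL test classes, mirroring the skipIf guards added in tests.py above. A minimal sketch (the test class and method names are illustrative):

```python
import unittest

from pyspark.sql.utils import require_test_compiled

_test_not_compiled_message = None
try:
    require_test_compiled()
except Exception as e:
    _test_not_compiled_message = str(e)


class DataSourceUDFTest(unittest.TestCase):  # illustrative name
    @unittest.skipIf(_test_not_compiled_message is not None, _test_not_compiled_message)
    def test_udf_against_test_data_source(self):
        # Runs only when the sql/core test classes (e.g. SimpleScanSource) are built.
        pass
```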
sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {

/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}

class SparkOptimizer(
catalog: SessionCatalog,
@@ -31,7 +31,8 @@

override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
Batch("Extract Python UDFs", Once,
Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++
postHocOptimizationBatches :+
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -36,6 +36,7 @@ class SparkPlanner(
override def strategies: Seq[Strategy] =
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
PythonEvals ::
DataSourceV2Strategy ::
FileSourceStrategy ::
DataSourceStrategy(conf) ::
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableS
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
import org.apache.spark.sql.internal.SQLConf
@@ -517,6 +518,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

/**
* Strategy to convert EvalPython logical operator to physical operator.
*/
object PythonEvals extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ArrowEvalPython(udfs, output, child) =>
ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil
case BatchEvalPython(udfs, output, child) =>
BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
case _ =>
Nil
}
}

object BasicOperators extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
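To see the new two-phase handling end to end, one can inspect the plans from PySpark. A rough sketch (the file path is an assumption, and the exact node names printed by explain may differ slightly):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col

spark = SparkSession.builder.getOrCreate()

spark.range(5).write.mode("overwrite").format("csv").save("/tmp/spark-24721-plan")
df = spark.read.option("inferSchema", True).csv("/tmp/spark-24721-plan").toDF("i")

flag = udf(lambda x: x > 1, "boolean")

# ExtractPythonUDFs now runs as an optimizer rule, so the UDF is pulled out into
# a BatchEvalPython logical node; the PythonEvals strategy then plans it as a
# BatchEvalPythonExec physical node, and the file scan never sees the UDF filter.
df.filter(flag(col("i"))).explain(extended=True)
```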
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.StructType
@@ -57,7 +58,13 @@
}

/**
* A physical plan that evaluates a [[PythonUDF]],
* A logical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode

/**
* A physical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, output, child) {
sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -25,9 +25,16 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

/**
* A logical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode

/**
* A physical plan that evaluates a [[PythonUDF]]
*/
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -24,9 +24,8 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}


/**
@@ -93,7 +92,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
* This has the limitation that the input to the Python UDF is not allowed include attributes from
* multiple child operators.
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {

private type EvalType = Int
private type EvalTypeChecker = EvalType => Boolean
@@ -132,14 +131,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
expressions.flatMap(collectEvaluableUDFs)
}

def apply(plan: SparkPlan): SparkPlan = plan transformUp {
case plan: SparkPlan => extract(plan)
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case plan: LogicalPlan => extract(plan)
}

/**
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
private def extract(plan: SparkPlan): SparkPlan = {
private def extract(plan: LogicalPlan): LogicalPlan = {
val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
// ignore the PythonUDF that come from second/third aggregate, which is not used
.filter(udf => udf.references.subsetOf(plan.inputSet))
@@ -151,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val prunedChildren = plan.children.map { child =>
val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
if (allNeededOutput.length != child.output.length) {
ProjectExec(allNeededOutput, child)
Project(allNeededOutput, child)
} else {
child
}
@@ -180,9 +179,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
_.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
) match {
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
case _ =>
throw new AnalysisException(
"Expected either Scalar Pandas UDFs or Batched UDFs but got both")
@@ -209,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val newPlan = extract(rewritten)
if (newPlan.output != plan.output) {
// Trim away the new UDF value if it was only used for filtering or something.
ProjectExec(plan.output, newPlan)
Project(plan.output, newPlan)
} else {
newPlan
}
@@ -218,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

// Split the original FilterExec to two FilterExecs. Only push down the first few predicates
// that are all deterministic.
private def trySplitFilter(plan: SparkPlan): SparkPlan = {
private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
plan match {
case filter: FilterExec =>
case filter: Filter =>
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
val newChild = Filter(pushDown.reduceLeft(And), filter.child)
Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
} else {
filter
}
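For intuition about trySplitFilter above, a hedged PySpark sketch (the path and data are assumptions): only the deterministic, UDF-free part of a conjunctive filter remains eligible for push-down, while the Python UDF predicate is evaluated afterwards.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col

spark = SparkSession.builder.getOrCreate()

spark.range(10).selectExpr("id AS x", "id % 3 AS y") \
    .write.mode("overwrite").parquet("/tmp/spark-24721-split")
df = spark.read.parquet("/tmp/spark-24721-split")

keep = udf(lambda v: v != 1, "boolean")

# The conjunction below is split: `x > 2` is deterministic and UDF-free, so it can
# still be pushed down to the Parquet scan, while the Python UDF predicate stays in
# a separate Filter evaluated after the BatchEvalPython node.
df.filter((col("x") > 2) & keep(col("y"))).explain(extended=True)
```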
@@ -28,6 +28,8 @@ import org.apache.spark.sql.types._

class DefaultSource extends SimpleScanSource

// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
// tests still pass.
class SimpleScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
@@ -370,7 +370,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider {
}
}


// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
// tests still pass.
class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {

class ReadSupport extends SimpleReadSupport {