Add test for data source and data source v2; Fix projection in EvalPythonExec; Move logic to ExtractPythonUDFs
icexelloss committed Aug 27, 2018
commit cfd568e2fe429c7959264a759bc2fd0b34b03eea
29 changes: 20 additions & 9 deletions python/pyspark/sql/tests.py
@@ -3367,21 +3367,32 @@ def test_ignore_column_of_all_nulls(self):
         finally:
             shutil.rmtree(path)
 
-    # SPARK-24721
     def test_datasource_with_udf_filter_lit_input(self):
+        # SPARK-24721
+        import pandas as pd
+        import numpy as np
+        from pyspark.sql.functions import udf, pandas_udf, lit, col
+
         path = tempfile.mkdtemp()
         shutil.rmtree(path)
         try:
-            from pyspark.sql.functions import udf, lit, col
-
             self.spark.range(1).write.mode("overwrite").format('csv').save(path)
-            df = self.spark.read.csv(path)
-            # Test that filter with lit inputs works with data source
-            result1 = df.filter(udf(lambda x: False, 'boolean')(lit(1)))
-            result2 = df.filter(udf(lambda: False, 'boolean')())
+            filesource_df = self.spark.read.csv(path)
+            datasource_df = self.spark.read.format("org.apache.spark.sql.sources.SimpleScanSource") \
+                .option('from', 0).option('to', 1).load()
+            datasource_v2_df = self.spark.read.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
+                .load()
+
+            filter1 = udf(lambda: False, 'boolean')()
+            filter2 = udf(lambda x: False, 'boolean')(lit(1))
+            filter3 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                for f in [filter1, filter2, filter3]:
+                    result = df.filter(f)
+                    result.explain(True)
+                    self.assertEquals(0, result.count())
 
-            self.assertEquals(0, result1.count())
-            self.assertEquals(0, result2.count())
         finally:
             shutil.rmtree(path)
 
FileSourceStrategy.scala
@@ -147,13 +147,10 @@ object FileSourceStrategy extends Strategy with Logging {
       // - filters that need to be evaluated again after the scan
       val filterSet = ExpressionSet(filters)
 
-      // SPARK-24721: Filter out Python UDFs, otherwise ExtractPythonUDF rule will throw exception
-      val validFilters = filters.filter(_.collectFirst{ case e: PythonUDF => e }.isEmpty)
-
       // The attribute name of predicate could be different than the one in schema in case of
       // case insensitive, we should change them to match the one in schema, so we do not need to
       // worry about case sensitivity anymore.
-      val normalizedFilters = validFilters.map { e =>
+      val normalizedFilters = filters.map { e =>
         e transform {
           case a: AttributeReference =>
             a.withName(l.output.find(_.semanticEquals(a)).get.name)
EvalPythonExec.scala
@@ -117,15 +117,16 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
           }
         }.toArray
       }.toArray
-      val projection = newMutableProjection(allInputs, child.output)
+      val projection = UnsafeProjection.create(allInputs, child.output)
       val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
         StructField(s"_$i", dt)
       })
 
       // Add rows to queue to join later with the result.
       val projectedRowIter = iter.map { inputRow =>
-        queue.add(inputRow.asInstanceOf[UnsafeRow])
-        projection(inputRow)
+        val unsafeRow = projection(inputRow)
+        queue.add(unsafeRow.asInstanceOf[UnsafeRow])
icexelloss (Contributor, Author) commented:
This is probably another bug I found while testing this - if the input node to EvalPythonExec doesn't produce UnsafeRow, the cast here will fail.

I hit this when passing in a test data source scan node, which produces GenericInternalRow and therefore throws an exception here.

I am happy to submit this as a separate patch if people think it's necessary.
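A minimal sketch of the failure mode and the fix, using Spark's catalyst internals directly (assumes spark-catalyst on the classpath; the demo object name is made up, and EvalPythonExec itself builds its projection from allInputs and child.output rather than from a bare data type):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, LongType}

object UnsafeRowCastDemo extends App {
  // A non-file scan (such as the test data sources above) can hand
  // EvalPythonExec a GenericInternalRow rather than an UnsafeRow.
  val row: InternalRow = new GenericInternalRow(Array[Any](1L))

  // Old behavior: cast the incoming row directly. This throws
  // ClassCastException for any non-unsafe InternalRow implementation:
  // row.asInstanceOf[UnsafeRow]

  // Fixed behavior: run the row through an UnsafeProjection first; its
  // apply() returns an UnsafeRow regardless of the input representation.
  val toUnsafe = UnsafeProjection.create(Array[DataType](LongType))
  val unsafeRow: UnsafeRow = toUnsafe(row)
  println(unsafeRow.getLong(0)) // 1
}
```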

icexelloss (Contributor, Author) commented:
Ok.. This seems to break existing tests. Need to look into it.

+        unsafeRow
       }
 
       val outputRowIterator = evaluate(
ExtractPythonUDFs.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
 
 
 /**
@@ -133,6 +134,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+    // SPARK-24721: Ignore Python UDFs in DataSourceScan and DataSourceV2Scan
+    case plan: DataSourceScanExec => plan
icexelloss (Contributor, Author) commented:
I got rid of the logic previously in FileSourceStrategy that excluded PythonUDF from the filters, in favor of this fix - I think this approach is cleaner.
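For context, a minimal sketch of the transformUp mechanics the new rule relies on, demonstrated on catalyst Expressions rather than physical plans (assumes spark-catalyst 2.x on the classpath; the object name and the literal rewrite are made up for illustration). At each node the first matching case of the partial function wins, so an identity case shields a node type from the general case below it, which is exactly how the scan nodes are shielded from extract(...) above:

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}
import org.apache.spark.sql.types.IntegerType

object TransformUpOrderingDemo extends App {
  val expr: Expression = Add(Literal(1), Add(Literal(2), Literal(3)))

  // transformUp rewrites the tree bottom-up and applies the first matching
  // case at each node: the identity case below shields Literal(2) from the
  // general rewrite, just as the identity cases shield the scan nodes.
  val rewritten = expr transformUp {
    case shielded @ Literal(2, IntegerType) => shielded    // left untouched
    case Literal(v: Int, IntegerType) => Literal(v * 10)   // rewritten
  }

  println(rewritten) // yields 10 + (2 + 30): only the shielded literal is unchanged
}
```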

+    case plan: DataSourceV2ScanExec => plan
     case plan: SparkPlan => extract(plan)
   }
