[SPARK-16134][SQL] optimizer rules for typed filter #13846
Changes from 1 commit
Changes to the Optimizer rule batches (the file containing `abstract class Optimizer`):

```diff
@@ -110,8 +110,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
     Batch("Typed Filter Optimization", fixedPoint,
-      EmbedSerializerInFilter,
-      RemoveAliasOnlyProject) ::
+      CombineTypedFilters) ::
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation) ::
     Batch("OptimizeCodegen", Once,
```
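For orientation, a minimal end-to-end sketch of what the new batch buys at the Dataset API level (a hypothetical demo, not code from the PR; assumes a local `SparkSession`):

```scala
import org.apache.spark.sql.SparkSession

// Two adjacent typed filters over the same object type. With the new
// "Typed Filter Optimization" batch, CombineTypedFilters should collapse
// both predicates into a single Filter that deserializes each row once.
object TypedFilterBatchDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("typed-filter-demo").getOrCreate()
    import spark.implicits._

    val ds = spark.range(10).as[Long].filter(_ > 2).filter(_ < 8)
    ds.explain(true) // the optimized plan should contain one Filter, not two
    spark.stop()
  }
}
```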
```diff
@@ -206,15 +205,33 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
 object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
-        if d.outputObjectType == s.inputObjectType =>
+        if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
       // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
       // We will remove it later in the RemoveAliasOnlyProject rule.
-      val objAttr =
-        Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
+      val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
       Project(objAttr :: Nil, s.child)

     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
-        if a.deserializer.dataType == s.inputObjectType =>
+        if a.deserializer.dataType == s.inputObjAttr.dataType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)

+    // If there is a `SerializeFromObject` under the typed filter and its input object type is
+    // the same as the typed filter's deserializer type, we can convert the typed filter to a
+    // normal filter without deserialization in the condition, and push it down through
+    // `SerializeFromObject`.
+    // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra
+    // deserialization, but `ds.map(...).as[AnotherType].filter(...)` cannot be optimized.
+    case f @ TypedFilter(_, _, s: SerializeFromObject)
+        if f.deserializer.dataType == s.inputObjAttr.dataType =>
+      s.copy(child = f.withObject(s.child))
+
+    // If there is a `DeserializeToObject` above the typed filter and its output object type is
+    // the same as the typed filter's deserializer type, we can convert the typed filter to a
+    // normal filter without deserialization in the condition, and pull it up through
+    // `DeserializeToObject`.
+    // e.g. `ds.filter(...).map(...)` can be optimized by this rule to save extra
+    // deserialization, but `ds.filter(...).as[AnotherType].map(...)` cannot be optimized.
+    case d @ DeserializeToObject(_, _, f: TypedFilter)
+        if d.outputObjAttr.dataType == f.deserializer.dataType =>
+      f.withObject(d.copy(child = f.child))
   }
 }
```
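To make the two new `EliminateSerialization` cases concrete, here is a hedged Dataset-level sketch (hypothetical demo code, not from the PR): the first chain qualifies because the filter consumes exactly the object type that `map` serialized; the second inserts an `.as[...]` that changes the object type, so the guard fails and the plan is left alone.

```scala
import org.apache.spark.sql.SparkSession

object EliminateSerializationDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("eliminate-serde-demo").getOrCreate()
    import spark.implicits._

    // Eligible: the filter's deserializer type ((Int, Int)) matches the
    // SerializeFromObject input produced by map, so the typed filter is
    // rewritten as a normal Filter pushed below the serialization.
    val optimizable = spark.range(5).as[Long].map(i => (i.toInt, i.toInt)).filter(_._1 > 0)

    // Not eligible: .as[(Long, Long)] changes the object type between map
    // and filter, so `f.deserializer.dataType == s.inputObjAttr.dataType`
    // no longer holds.
    val notOptimizable =
      spark.range(5).as[Long].map(i => (i.toInt, i.toInt)).as[(Long, Long)].filter(_._1 > 0L)

    optimizable.explain(true)
    notOptimizable.explain(true)
    spark.stop()
  }
}
```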
```diff
@@ -1645,55 +1662,31 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
 }

 /**
- * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
- * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should
- * embed the deserializer in the filter condition to save the extra serialization at last.
+ * Combines all adjacent [[TypedFilter]]s, which operate on the same type of object in the
+ * condition, into a single [[Filter]].
  */
-object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+object CombineTypedFilters extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject))
-      // SPARK-15632: Conceptually, a filter operator should never introduce a schema change.
-      // This optimization rule also relies on this assumption. However, the Dataset typed
-      // filter operator does introduce schema changes in some cases. Thus, we only enable this
-      // optimization when
-      //
-      // 1. either input and output schemata are exactly the same, or
-      // 2. both input and output schemata are single-field schemas that share the same type.
-      //
-      // The 2nd case is included because encoders for primitive types always have only a single
-      // field with the hard-coded field name "value".
-      // TODO Clean this up after fixing SPARK-15632.
-      if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) =>
-
-      val numObjects = condition.collect {
-        case a: Attribute if a == d.output.head => a
-      }.length
-
-      if (numObjects > 1) {
-        // If the filter condition references the object more than once, we should not embed
-        // the deserializer in it, as the deserialization will happen many times and slow down
-        // the execution.
-        // TODO: we can still embed it if we can make sure subexpression elimination works here.
-        s
+    case t @ TypedFilter(_, deserializer, child) =>
+      val filters = collectTypedFiltersOnSameTypeObj(child, deserializer.dataType, ArrayBuffer(t))
+      if (filters.length > 1) {
+        val objHolder = BoundReference(0, deserializer.dataType, nullable = false)
+        val condition = filters.map(_.getCondition(objHolder)).reduce(And)
+        Filter(ReferenceToExpressions(condition, deserializer :: Nil), filters.last.child)
       } else {
-        val newCondition = condition transform {
-          case a: Attribute if a == d.output.head => d.deserializer
-        }
-        val filter = Filter(newCondition, d.child)
-
-        // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`.
-        // We will remove it later in the RemoveAliasOnlyProject rule.
-        val objAttrs = filter.output.zip(s.output).map { case (fout, sout) =>
-          Alias(fout, fout.name)(exprId = sout.exprId)
-        }
-        Project(objAttrs, filter)
+        t
       }
   }

-  def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = {
-    (lhs, rhs) match {
-      case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
-      case _ => false
-    }
+  @tailrec
+  private def collectTypedFiltersOnSameTypeObj(
+      plan: LogicalPlan,
+      objType: DataType,
+      filters: ArrayBuffer[TypedFilter]): Array[TypedFilter] = plan match {
+    case t: TypedFilter if t.deserializer.dataType == objType =>
+      filters += t
+      collectTypedFiltersOnSameTypeObj(t.child, objType, filters)
+    case _ => filters.toArray
   }
 }
```
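The shape of the rule can be modeled without any Catalyst machinery: walk down a chain of adjacent filters over the same object type, then fold their predicates into one conjunction so each row is deserialized only once. Below is a minimal, self-contained sketch in plain Scala (all names hypothetical; the real rule builds an `And` of `getCondition` expressions over a shared `BoundReference` and wraps it in `ReferenceToExpressions`):

```scala
import scala.annotation.tailrec

object CombineTypedFiltersSketch {
  // Stand-in for a TypedFilter node: a predicate plus an optional child filter.
  case class TypedPred[T](f: T => Boolean, child: Option[TypedPred[T]] = None)

  // Mirrors collectTypedFiltersOnSameTypeObj: walk down the chain of
  // adjacent filters, accumulating their predicates.
  @tailrec
  def collect[T](t: TypedPred[T], acc: List[T => Boolean] = Nil): List[T => Boolean] =
    t.child match {
      case Some(c) => collect(c, t.f :: acc)
      case None    => t.f :: acc
    }

  // Mirrors filters.map(_.getCondition(objHolder)).reduce(And): one combined
  // predicate evaluating every condition against a single deserialized object.
  def combine[T](preds: List[T => Boolean]): T => Boolean =
    obj => preds.forall(_(obj))

  def main(args: Array[String]): Unit = {
    val chain = TypedPred[(Int, Int)](_._1 > 0, Some(TypedPred(_._2 > 0)))
    val combined = combine(collect(chain))
    println(combined((1, 2)))  // true: both predicates hold
    println(combined((1, -2))) // false: the inner predicate fails
  }
}
```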
Changes to the object operators in `org.apache.spark.sql.catalyst.plans.logical`:

```diff
@@ -17,11 +17,15 @@

 package org.apache.spark.sql.catalyst.plans.logical

 import scala.language.existentials

+import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.{Encoder, Row}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.types._

 object CatalystSerde {
```
```diff
@@ -45,13 +49,11 @@ object CatalystSerde {
  */
 trait ObjectProducer extends LogicalPlan {
   // The attribute that references the single object field this operator outputs.
-  protected def outputObjAttr: Attribute
+  def outputObjAttr: Attribute

   override def output: Seq[Attribute] = outputObjAttr :: Nil

   override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
-
-  def outputObjectType: DataType = outputObjAttr.dataType
 }

 /**
```
```diff
@@ -64,7 +66,7 @@ trait ObjectConsumer extends UnaryNode {
   // This operator always needs all columns of its child, even those it doesn't reference.
   override def references: AttributeSet = child.outputSet

-  def inputObjectType: DataType = child.output.head.dataType
+  def inputObjAttr: Attribute = child.output.head
 }

 /**
```
```diff
@@ -167,6 +169,43 @@ case class MapElements(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends ObjectConsumer with ObjectProducer

+object TypedFilter {
+  def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
+    TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each element of the `child` and filtering them by
+ * the resulting boolean value.
+ *
+ * This is logically equal to a normal [[Filter]] operator whose condition expression decodes
+ * the input row to an object and applies the given function to the decoded object. However, we
+ * need the encapsulation of [[TypedFilter]] to make the concept clearer and to make it easier
+ * to write optimizer rules.
+ */
+case class TypedFilter(
+    func: AnyRef,
+    deserializer: Expression,
+    child: LogicalPlan) extends UnaryNode {
+
+  override def output: Seq[Attribute] = child.output
+
+  def withObject(obj: LogicalPlan): Filter = {
+    assert(obj.output.length == 1)
+    Filter(getCondition(obj.output.head), obj)
+  }
+
+  def getCondition(input: Expression): Expression = {
+    val (funcClass, methodName) = func match {
+      case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
+      case _ => classOf[Any => Boolean] -> "apply"
+    }
+    val funcObj = Literal.create(func, ObjectType(funcClass))
+    Invoke(funcObj, methodName, BooleanType, input :: Nil)
+  }
+}
+
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumns {
   def apply[T : Encoder, U : Encoder](
```
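The interesting mechanics live in `getCondition`: it dispatches on whether `func` is a Java `FilterFunction` (invoked via `call`) or a Scala closure (invoked via `apply`). A stripped-down, runnable sketch of just that dispatch (the trait below is a local stand-in for `org.apache.spark.api.java.function.FilterFunction`; the real code wraps the selected method in a Catalyst `Invoke` expression):

```scala
object GetConditionSketch {
  // Local stand-in for org.apache.spark.api.java.function.FilterFunction.
  trait FilterFunction[T] { def call(value: T): Boolean }

  // Mirrors the (funcClass, methodName) selection in TypedFilter.getCondition.
  def dispatch(func: AnyRef): (Class[_], String) = func match {
    case _: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
    case _                    => classOf[Any => Boolean]    -> "apply"
  }

  def main(args: Array[String]): Unit = {
    val javaStyle: AnyRef = new FilterFunction[Int] { def call(value: Int): Boolean = value > 0 }
    val scalaStyle: AnyRef = (i: Int) => i > 0
    println(dispatch(javaStyle))  // selects method "call" on FilterFunction
    println(dispatch(scalaStyle)) // selects method "apply" on Function1
  }
}
```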
Changes to `TypedFilterOptimizationSuite`:

```diff
@@ -23,54 +23,111 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, ReferenceToExpressions}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.BooleanType
+import org.apache.spark.sql.types.{BooleanType, ObjectType}

 class TypedFilterOptimizationSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("EliminateSerialization", FixedPoint(50),
         EliminateSerialization) ::
-      Batch("EmbedSerializerInFilter", FixedPoint(50),
-        EmbedSerializerInFilter) :: Nil
+      Batch("CombineTypedFilters", FixedPoint(50),
+        CombineTypedFilters) :: Nil
   }

   implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()

-  test("back to back filter") {
+  test("filter after serialize") {
     val input = LocalRelation('_1.int, '_2.int)
-    val f1 = (i: (Int, Int)) => i._1 > 0
-    val f2 = (i: (Int, Int)) => i._2 > 0
+    val f = (i: (Int, Int)) => i._1 > 0

-    val query = input.filter(f1).filter(f2).analyze
+    val query = input
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)]
+      .filter(f).analyze

     val optimized = Optimize.execute(query)

-    val expected = input.deserialize[(Int, Int)]
-      .where(callFunction(f1, BooleanType, 'obj))
-      .select('obj.as("obj"))
-      .where(callFunction(f2, BooleanType, 'obj))
+    val expected = input
+      .deserialize[(Int, Int)]
+      .where(callFunction(f, BooleanType, 'obj))
       .serialize[(Int, Int)].analyze

     comparePlans(optimized, expected)
   }

-  // TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules
-  // for typed filters.
-  ignore("embed deserializer in typed filter condition if there is only one filter") {
+  test("filter after serialize with object change") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f = (i: OtherTuple) => i._1 > 0
+
+    val query = input
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)]
+      .filter(f).analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
+
+  test("filter before deserialize") {
     val input = LocalRelation('_1.int, '_2.int)
     val f = (i: (Int, Int)) => i._1 > 0

-    val query = input.filter(f).analyze
+    val query = input
+      .filter(f)
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)].analyze

     val optimized = Optimize.execute(query)

-    val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
-    val condition = callFunction(f, BooleanType, deserializer)
-    val expected = input.where(condition).select('_1.as("_1"), '_2.as("_2")).analyze
+    val expected = input
+      .deserialize[(Int, Int)]
+      .where(callFunction(f, BooleanType, 'obj))
+      .serialize[(Int, Int)].analyze

     comparePlans(optimized, expected)
   }
+
+  test("filter before deserialize with object change") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f = (i: OtherTuple) => i._1 > 0
+
+    val query = input
+      .filter(f)
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)].analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
+
+  test("back to back filter") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f1 = (i: (Int, Int)) => i._1 > 0
+    val f2 = (i: (Int, Int)) => i._2 > 0
+
+    val query = input.filter(f1).filter(f2).analyze
+
+    val optimized = Optimize.execute(query)
+
+    val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
+    val boundReference = BoundReference(0, ObjectType(classOf[(Int, Int)]), nullable = false)
+    val callFunc1 = callFunction(f1, BooleanType, boundReference)
+    val callFunc2 = callFunction(f2, BooleanType, boundReference)
+    val condition = ReferenceToExpressions(callFunc2 && callFunc1, deserializer :: Nil)
+    val expected = input.where(condition).analyze
+
+    comparePlans(optimized, expected)
+  }
+
+  test("back to back filter with object change") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f1 = (i: (Int, Int)) => i._1 > 0
+    val f2 = (i: OtherTuple) => i._2 > 0
+
+    val query = input.filter(f1).filter(f2).analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
 }
```
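One note on the tests: `OtherTuple` is referenced but never defined in this hunk. For lambdas like `(i: OtherTuple) => i._1 > 0` to compile and pick up the suite's `productEncoder`, it is presumably a small case class declared elsewhere in the file, shaped roughly like:

```scala
// Hypothetical shape, inferred from usage in the tests; not shown in this diff.
case class OtherTuple(_1: Int, _2: Int)
```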
Review comment: So this is actually a bug fix? Before, we could only use a single `BoundReference` as the result, right?

Reply: yup