optimizer rules for typed filter
cloud-fan committed Jun 29, 2016
commit 8adf6027076c6c8e69d3c5e5a587d48a6557b288
@@ -293,11 +293,7 @@ package object dsl {

def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)

def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
val deserialized = logicalPlan.deserialize[T]
val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
Filter(condition, deserialized).serialize[T]
}
def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan)

def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)

@@ -45,6 +45,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
var maxOrdinal = -1
result foreach {
case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
case _ =>
Contributor:

So this is actually a bug fix? Before, we could only use a single BoundReference as the result, right?

Contributor Author:

yup

}
if (maxOrdinal > children.length) {
return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
@@ -110,8 +110,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
EmbedSerializerInFilter,
RemoveAliasOnlyProject) ::
CombineTypedFilters) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("OptimizeCodegen", Once,
@@ -206,15 +205,33 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
// We will remove it later in RemoveAliasOnlyProject rule.
val objAttr =
Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
Project(objAttr :: Nil, s.child)

case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
if a.deserializer.dataType == s.inputObjAttr.dataType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)

// If there is a `SerializeFromObject` under a typed filter and its input object type is the same as
// the typed filter's deserializer, we can convert the typed filter to a normal filter without
// deserialization in the condition, and push it down through `SerializeFromObject`.
// e.g. `ds.map(...).filter(...)` can be optimized by this rule to save the extra deserialization,
// but `ds.map(...).as[AnotherType].filter(...)` cannot be optimized.
case f @ TypedFilter(_, _, s: SerializeFromObject)
if f.deserializer.dataType == s.inputObjAttr.dataType =>
s.copy(child = f.withObject(s.child))
Member:

I think we can also push down TypedFilter( Filter( SerializeFromObject(child) ) ) into Filter( SerializeFromObject( TypedFilter(child) ) ),
e.g. for ds.map(...).filter(byExpr).filter(byFunc).
What do you think?

Contributor Author:

Well, that's true, and the Filter could be any other unary operator whose output is derived from its child, e.g. Sort.

However, I don't think ds.map(...).filter(byExpr).filter(byFunc), i.e. interleaving typed and untyped operations, is a common case. If there is an easy and general way to optimize it, I'm happy to add it; otherwise I'd rather leave it as is.

What do you think?

Member:

Hmm, I'm not sure that mixing typed and untyped operations is an uncommon case, but I don't have an idea for an easy and general way to optimize it, so I think we can leave it for now.


// If there is a `DeserializeToObject` above a typed filter and its output object type is the same as
// the typed filter's deserializer, we can convert the typed filter to a normal filter without
// deserialization in the condition, and pull it up through `DeserializeToObject`.
// e.g. `ds.filter(...).map(...)` can be optimized by this rule to save the extra deserialization,
// but `ds.filter(...).as[AnotherType].map(...)` cannot be optimized.
case d @ DeserializeToObject(_, _, f: TypedFilter)
if d.outputObjAttr.dataType == f.deserializer.dataType =>
f.withObject(d.copy(child = f.child))
}
}
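A minimal sketch (not part of this diff) of the query shapes the two new `TypedFilter` cases above target, assuming a local SparkSession and this branch's Dataset API; the object name `SerdeEliminationDemo` and the sample data are invented for illustration, and the comments describe what the rule is expected to do rather than a verified plan.

import org.apache.spark.sql.SparkSession

object SerdeEliminationDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("serde-elimination-demo").getOrCreate()
    import spark.implicits._

    val ds = Seq((1, -1), (2, 3)).toDS()

    // map(...).filter(...): the typed filter sits on top of a SerializeFromObject, so the first
    // new case should turn it into a normal Filter and push it below the serialization.
    val pushedDown = ds.map(i => (i._1 + 1, i._2)).filter(i => i._1 > 0)

    // filter(...).map(...): the typed filter sits below a DeserializeToObject, so the second
    // new case should convert it and pull it above the deserialization.
    val pulledUp = ds.filter(i => i._1 > 0).map(i => i._1 + i._2)

    // A type change in between (e.g. via .as[SomeOtherType]) breaks the dataType guards above,
    // so such plans are expected to be left alone.
    println(pushedDown.queryExecution.optimizedPlan)
    println(pulledUp.queryExecution.optimizedPlan)
    spark.stop()
  }
}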

@@ -1645,55 +1662,31 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
}

/**
* Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
* [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
* the deserializer in filter condition to save the extra serialization at last.
* Combines all adjacent [[TypedFilter]]s that operate on the same type of object in their conditions
* into a single [[Filter]].
*/
object EmbedSerializerInFilter extends Rule[LogicalPlan] {
object CombineTypedFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject))
// SPARK-15632: Conceptually, filter operator should never introduce schema change. This
// optimization rule also relies on this assumption. However, Dataset typed filter operator
// does introduce schema changes in some cases. Thus, we only enable this optimization when
//
// 1. either input and output schemata are exactly the same, or
// 2. both input and output schemata are single-field schema and share the same type.
//
// The 2nd case is included because encoders for primitive types always have only a single
// field with hard-coded field name "value".
// TODO Cleans this up after fixing SPARK-15632.
if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) =>

val numObjects = condition.collect {
case a: Attribute if a == d.output.head => a
}.length

if (numObjects > 1) {
// If the filter condition references the object more than one times, we should not embed
// deserializer in it as the deserialization will happen many times and slow down the
// execution.
// TODO: we can still embed it if we can make sure subexpression elimination works here.
s
case t @ TypedFilter(_, deserializer, child) =>
val filters = collectTypedFiltersOnSameTypeObj(child, deserializer.dataType, ArrayBuffer(t))
if (filters.length > 1) {
val objHolder = BoundReference(0, deserializer.dataType, nullable = false)
val condition = filters.map(_.getCondition(objHolder)).reduce(And)
Filter(ReferenceToExpressions(condition, deserializer :: Nil), filters.last.child)
} else {
val newCondition = condition transform {
case a: Attribute if a == d.output.head => d.deserializer
}
val filter = Filter(newCondition, d.child)

// Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`.
// We will remove it later in RemoveAliasOnlyProject rule.
val objAttrs = filter.output.zip(s.output).map { case (fout, sout) =>
Alias(fout, fout.name)(exprId = sout.exprId)
}
Project(objAttrs, filter)
t
}
}

def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = {
(lhs, rhs) match {
case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
case _ => false
}
@tailrec
private def collectTypedFiltersOnSameTypeObj(
plan: LogicalPlan,
objType: DataType,
filters: ArrayBuffer[TypedFilter]): Array[TypedFilter] = plan match {
case t: TypedFilter if t.deserializer.dataType == objType =>
filters += t
Contributor:

Shall we prepend rather than append the found filters here? Otherwise the filter predicates will be evaluated in reverse order after being combined. It would also be nice to add a comment about this.

collectTypedFiltersOnSameTypeObj(t.child, objType, filters)
case _ => filters.toArray
}
}
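A minimal sketch (not part of this diff) of what CombineTypedFilters is meant to catch, again assuming a local SparkSession; the object name `CombineTypedFiltersDemo` and the data are invented for illustration.

import org.apache.spark.sql.SparkSession

object CombineTypedFiltersDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("combine-typed-filters-demo").getOrCreate()
    import spark.implicits._

    val ds = Seq((1, -1), (2, 3)).toDS()

    // Two adjacent typed filters on the same object type: the rule should collapse them into a
    // single Filter whose condition deserializes the row once and applies both predicates
    // (see the review comment above about the order in which the combined predicates run).
    val combined = ds.filter(i => i._1 > 0).filter(i => i._2 > 0)
    println(combined.queryExecution.optimizedPlan)

    // If the object type changes between the filters (e.g. via .as[SomeOtherCaseClass]), the
    // deserializer types no longer match and the filters are expected to be left as they are.
    spark.stop()
  }
}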

@@ -17,11 +17,15 @@

package org.apache.spark.sql.catalyst.plans.logical

import scala.language.existentials

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._

object CatalystSerde {
@@ -45,13 +49,11 @@
*/
trait ObjectProducer extends LogicalPlan {
// The attribute that references the single object field this operator outputs.
protected def outputObjAttr: Attribute
def outputObjAttr: Attribute

override def output: Seq[Attribute] = outputObjAttr :: Nil

override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)

def outputObjectType: DataType = outputObjAttr.dataType
}

/**
@@ -64,7 +66,7 @@ trait ObjectConsumer extends UnaryNode {
// This operator always needs all columns of its child, even if it doesn't reference them.
override def references: AttributeSet = child.outputSet

def inputObjectType: DataType = child.output.head.dataType
def inputObjAttr: Attribute = child.output.head
}

/**
@@ -167,6 +169,43 @@ case class MapElements(
outputObjAttr: Attribute,
child: LogicalPlan) extends ObjectConsumer with ObjectProducer

object TypedFilter {
def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
}
}

/**
* A relation produced by applying `func` to each element of the `child` and filtering them by the
* resulting boolean value.
*
* This is logically equivalent to a normal [[Filter]] operator whose condition expression decodes
* the input row to an object and applies the given function to the decoded object. However, we need
* the encapsulation of [[TypedFilter]] to make the concept clearer and to make it easier to write
* optimizer rules.
*/
case class TypedFilter(
func: AnyRef,
deserializer: Expression,
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output

def withObject(obj: LogicalPlan): Filter = {
Contributor:

How about renaming this method to withObjectProducerChild?

assert(obj.output.length == 1)
Filter(getCondition(obj.output.head), obj)
}

def getCondition(input: Expression): Expression = {
Contributor:

How about renaming it to typedCondition?

val (funcClass, methodName) = func match {
case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
case _ => classOf[Any => Boolean] -> "apply"
}
val funcObj = Literal.create(func, ObjectType(funcClass))
Invoke(funcObj, methodName, BooleanType, input :: Nil)
}
}
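For readers unfamiliar with the object expressions involved, here is a standalone sketch (not part of this diff) of the condition `getCondition` builds for a plain Scala predicate, assuming spark-catalyst is on the classpath; the object name `GetConditionSketch` is invented for illustration.

import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.{BooleanType, ObjectType}

object GetConditionSketch {
  def main(args: Array[String]): Unit = {
    val func: ((Int, Int)) => Boolean = i => i._1 > 0

    // Stand-in for the deserialized input object, like the BoundReference used by CombineTypedFilters.
    val input = BoundReference(0, ObjectType(classOf[(Int, Int)]), nullable = false)

    // Mirrors the Scala-closure branch of getCondition: wrap the function object in a Literal and
    // invoke its `apply` method on the input object; a FilterFunction would use "call" instead.
    val funcObj = Literal.create(func, ObjectType(classOf[((Int, Int)) => Boolean]))
    val condition = Invoke(funcObj, "apply", BooleanType, input :: Nil)

    println(condition)
  }
}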

/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -23,54 +23,111 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, ReferenceToExpressions}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.{BooleanType, ObjectType}

class TypedFilterOptimizationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("EliminateSerialization", FixedPoint(50),
EliminateSerialization) ::
Batch("EmbedSerializerInFilter", FixedPoint(50),
EmbedSerializerInFilter) :: Nil
Batch("CombineTypedFilters", FixedPoint(50),
CombineTypedFilters) :: Nil
}

implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()

test("back to back filter") {
test("filter after serialize") {
val input = LocalRelation('_1.int, '_2.int)
val f1 = (i: (Int, Int)) => i._1 > 0
val f2 = (i: (Int, Int)) => i._2 > 0
val f = (i: (Int, Int)) => i._1 > 0

val query = input.filter(f1).filter(f2).analyze
val query = input
.deserialize[(Int, Int)]
.serialize[(Int, Int)]
.filter(f).analyze

val optimized = Optimize.execute(query)

val expected = input.deserialize[(Int, Int)]
.where(callFunction(f1, BooleanType, 'obj))
.select('obj.as("obj"))
.where(callFunction(f2, BooleanType, 'obj))
val expected = input
.deserialize[(Int, Int)]
.where(callFunction(f, BooleanType, 'obj))
.serialize[(Int, Int)].analyze

comparePlans(optimized, expected)
}

// TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules
// for typed filters.
ignore("embed deserializer in typed filter condition if there is only one filter") {
test("filter after serialize with object change") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: OtherTuple) => i._1 > 0

val query = input
.deserialize[(Int, Int)]
.serialize[(Int, Int)]
.filter(f).analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, query)
}

test("filter before deserialize") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: (Int, Int)) => i._1 > 0

val query = input.filter(f).analyze
val query = input
.filter(f)
.deserialize[(Int, Int)]
.serialize[(Int, Int)].analyze

val optimized = Optimize.execute(query)

val expected = input
.deserialize[(Int, Int)]
.where(callFunction(f, BooleanType, 'obj))
.serialize[(Int, Int)].analyze

comparePlans(optimized, expected)
}

test("filter before deserialize with object change") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: OtherTuple) => i._1 > 0

val query = input
.filter(f)
.deserialize[(Int, Int)]
.serialize[(Int, Int)].analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, query)
}

test("back to back filter") {
Contributor:

Nit: "back to back filters with the same object type"

val input = LocalRelation('_1.int, '_2.int)
val f1 = (i: (Int, Int)) => i._1 > 0
val f2 = (i: (Int, Int)) => i._2 > 0

val query = input.filter(f1).filter(f2).analyze

val optimized = Optimize.execute(query)

val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
val condition = callFunction(f, BooleanType, deserializer)
val expected = input.where(condition).select('_1.as("_1"), '_2.as("_2")).analyze
val boundReference = BoundReference(0, ObjectType(classOf[(Int, Int)]), nullable = false)
val callFunc1 = callFunction(f1, BooleanType, boundReference)
val callFunc2 = callFunction(f2, BooleanType, boundReference)
val condition = ReferenceToExpressions(callFunc2 && callFunc1, deserializer :: Nil)
val expected = input.where(condition).analyze

comparePlans(optimized, expected)
}

test("back to back filter with object change") {
Contributor:

Nit: "back to back filters with different object types"

val input = LocalRelation('_1.int, '_2.int)
val f1 = (i: (Int, Int)) => i._1 > 0
val f2 = (i: OtherTuple) => i._2 > 0

val query = input.filter(f1).filter(f2).analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, query)
}
}
12 changes: 2 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1997,11 +1997,7 @@ class Dataset[T] private[sql](
*/
@Experimental
def filter(func: T => Boolean): Dataset[T] = {
val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil)
val filter = Filter(condition, logicalPlan)
withTypedPlan(filter)
withTypedPlan(TypedFilter(func, logicalPlan))
}

/**
@@ -2014,11 +2010,7 @@ class Dataset[T] private[sql](
*/
@Experimental
def filter(func: FilterFunction[T]): Dataset[T] = {
val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
val condition = Invoke(function, "call", BooleanType, deserializer :: Nil)
val filter = Filter(condition, logicalPlan)
withTypedPlan(filter)
withTypedPlan(TypedFilter(func, logicalPlan))
}
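A minimal usage sketch (not part of this diff) showing that both filter overloads now build a TypedFilter node, assuming a local SparkSession; the object name `TypedFilterUsageDemo` and the data are invented for illustration.

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.SparkSession

object TypedFilterUsageDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("typed-filter-usage").getOrCreate()
    import spark.implicits._

    val ds = Seq((1, -1), (2, 3)).toDS()

    // Scala closure: TypedFilter.getCondition should invoke the closure's `apply` method.
    val byClosure = ds.filter(i => i._1 > 0)

    // Java-friendly FilterFunction: TypedFilter.getCondition should invoke `call` instead.
    val byFilterFunction = ds.filter(new FilterFunction[(Int, Int)] {
      def call(i: (Int, Int)): Boolean = i._2 > 0
    })

    // Both logical plans should contain a TypedFilter node before optimization.
    println(byClosure.queryExecution.logical)
    println(byFilterFunction.queryExecution.logical)
    spark.stop()
  }
}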

/**
@@ -385,6 +385,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.ProjectExec(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.FilterExec(condition, planLater(child)) :: Nil
case f: logical.TypedFilter =>
execution.FilterExec(f.getCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
@@ -238,6 +238,7 @@ abstract class QueryTest extends PlanTest {
case _: ObjectConsumer => return
case _: ObjectProducer => return
case _: AppendColumns => return
case _: TypedFilter => return
case _: LogicalRelation => return
case p if p.getClass.getSimpleName == "MetastoreRelation" => return
case _: MemoryPlan => return