Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1178,16 +1178,24 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] {
e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate

case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
// Exclude child's constraints to guarantee idempotency
val inferredFilters = ExpressionSet(
Seq(
GreaterThan(Size(g.children.head), Literal(0)),
IsNotNull(g.children.head)
)
) -- generate.child.constraints

if (inferredFilters.nonEmpty) {
generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
val input = g.children.head
// Generating extra predicates here has overheads/risks:
// - We may evaluate expensive input expressions multiple times.
// - We may infer too many constraints later.
// - The input expression may fail to be evaluated under ANSI mode. If we reorder the
// predicates and evaluate the input expression first, we may fail the query unexpectedly.
// To be safe, here we only generate extra predicates if the input is an attribute.
if (input.isInstanceOf[Attribute]) {
Copy link
Contributor Author

@cloud-fan cloud-fan Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's almost useless to generate predicate with CreateArray/CreateMap. Size(CreateArray(...)) > 0 is always true unless you create an empty array.

// Exclude child's constraints to guarantee idempotency
val inferredFilters = ExpressionSet(
Seq(GreaterThan(Size(input), Literal(0)), IsNotNull(input))
) -- generate.child.constraints

if (inferredFilters.nonEmpty) {
generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
} else {
generate
}
} else {
generate
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
private def hasNoSideEffect(e: Expression): Boolean = e match {
case _: Attribute => true
case _: Literal => true
case c: Cast if !conf.ansiEnabled => hasNoSideEffect(c.child)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either this change or the change in InferFiltersFromGenerate can fix the perf issue. But I keep both fixes to be super safe.

case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect)
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class InferFiltersFromGenerateSuite extends PlanTest {
val testRelation = LocalRelation('a.array(StructType(Seq(
StructField("x", IntegerType),
StructField("y", IntegerType)
))), 'c1.string, 'c2.string)
))), 'c1.string, 'c2.string, 'c3.int)

Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
val generator = f('a)
Expand Down Expand Up @@ -74,6 +74,13 @@ class InferFiltersFromGenerateSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}

val fromJson = f(JsonToStructs(ArrayType(new StructType().add("s", "string")), Map.empty, 'c1))
test("SPARK-37392: Don't infer filters from " + fromJson) {
val originalQuery = testRelation.generate(fromJson).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}

// setup rules to test inferFilters with ConstantFolding to make sure
Expand All @@ -91,28 +98,28 @@ class InferFiltersFromGenerateSuite extends PlanTest {
}

Seq(Explode(_), PosExplode(_)).foreach { f =>
val createArrayExplode = f(CreateArray(Seq('c1)))
test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) {
val originalQuery = testRelation.generate(createArrayExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) {
val originalQuery = testRelation.generate(createMapExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}
val createArrayExplode = f(CreateArray(Seq('c1)))
test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) {
val originalQuery = testRelation.generate(createArrayExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) {
val originalQuery = testRelation.generate(createMapExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}

Seq(Inline(_)).foreach { f =>
val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) {
val originalQuery = testRelation.generate(createArrayStructExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}
Seq(Inline(_)).foreach { f =>
val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) {
val originalQuery = testRelation.generate(createArrayStructExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}

test("SPARK-36715: Don't infer filters from udf") {
Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
Expand Down