-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan #20911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,42 +47,46 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| SimplifyExtractValueOps) :: Nil | ||
| } | ||
|
|
||
| val idAtt = ('id).long.notNull | ||
| val nullableIdAtt = ('nullable_id).long | ||
| private val idAtt = ('id).long.notNull | ||
| private val nullableIdAtt = ('nullable_id).long | ||
|
|
||
| lazy val relation = LocalRelation(idAtt, nullableIdAtt) | ||
| private val relation = LocalRelation(idAtt, nullableIdAtt) | ||
| private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int) | ||
|
|
||
| private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { | ||
| val optimized = Optimizer.execute(originalQuery.analyze) | ||
| assert(optimized.resolved, "optimized plans must be still resolvable") | ||
| comparePlans(optimized, correctAnswer.analyze) | ||
| } | ||
|
|
||
| test("explicit get from namedStruct") { | ||
| val query = relation | ||
| .select( | ||
| GetStructField( | ||
| CreateNamedStruct(Seq("att", 'id )), | ||
| 0, | ||
| None) as "outerAtt").analyze | ||
| val expected = relation.select('id as "outerAtt").analyze | ||
| None) as "outerAtt") | ||
| val expected = relation.select('id as "outerAtt") | ||
|
|
||
| comparePlans(Optimizer execute query, expected) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("explicit get from named_struct- expression maintains original deduced alias") { | ||
| val query = relation | ||
| .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)) | ||
| .analyze | ||
|
|
||
| val expected = relation | ||
| .select('id as "named_struct(att, id).att") | ||
| .analyze | ||
|
|
||
| comparePlans(Optimizer execute query, expected) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("collapsed getStructField ontop of namedStruct") { | ||
| val query = relation | ||
| .select(CreateNamedStruct(Seq("att", 'id)) as "struct1") | ||
| .select(GetStructField('struct1, 0, None) as "struct1Att") | ||
| .analyze | ||
| val expected = relation.select('id as "struct1Att").analyze | ||
| comparePlans(Optimizer execute query, expected) | ||
| val expected = relation.select('id as "struct1Att") | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("collapse multiple CreateNamedStruct/GetStructField pairs") { | ||
|
|
@@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| .select( | ||
| GetStructField('struct1, 0, None) as "struct1Att1", | ||
| GetStructField('struct1, 1, None) as "struct1Att2") | ||
| .analyze | ||
|
|
||
| val expected = | ||
| relation. | ||
| select( | ||
| 'id as "struct1Att1", | ||
| ('id * 'id) as "struct1Att2") | ||
| .analyze | ||
|
|
||
| comparePlans(Optimizer execute query, expected) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("collapsed2 - deduced names") { | ||
|
|
@@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| .select( | ||
| GetStructField('struct1, 0, None), | ||
| GetStructField('struct1, 1, None)) | ||
| .analyze | ||
|
|
||
| val expected = | ||
| relation. | ||
| select( | ||
| 'id as "struct1.att1", | ||
| ('id * 'id) as "struct1.att2") | ||
| .analyze | ||
|
|
||
| comparePlans(Optimizer execute query, expected) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("simplified array ops") { | ||
|
|
@@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| 1, | ||
| false), | ||
| 1) as "a4") | ||
| .analyze | ||
|
|
||
| val expected = relation | ||
| .select( | ||
|
|
@@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| "att2", (('id + 1L) * ('id + 1L)))) as "a2", | ||
| ('id + 1L) as "a3", | ||
| ('id + 1L) as "a4") | ||
| .analyze | ||
| comparePlans(Optimizer execute query, expected) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("SPARK-22570: CreateArray should not create a lot of global variables") { | ||
|
|
@@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| GetStructField(GetMapValue('m, "r1"), 0, None) as "a2", | ||
| GetMapValue('m, "r32") as "a3", | ||
| GetStructField(GetMapValue('m, "r32"), 0, None) as "a4") | ||
| .analyze | ||
|
|
||
| val expected = | ||
| relation.select( | ||
|
|
@@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| ) | ||
| ) as "a3", | ||
| Literal.create(null, LongType) as "a4") | ||
| .analyze | ||
| comparePlans(Optimizer execute query, expected) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("simplify map ops, constant lookup, dynamic keys") { | ||
|
|
@@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| ('id + 3L), ('id + 4L), | ||
| ('id + 4L), ('id + 5L))), | ||
| 13L) as "a") | ||
| .analyze | ||
|
|
||
| val expected = relation | ||
| .select( | ||
|
|
@@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| (EqualTo(13L, ('id + 1L)), ('id + 2L)), | ||
| (EqualTo(13L, ('id + 2L)), ('id + 3L)), | ||
| (Literal(true), 'id))) as "a") | ||
| .analyze | ||
| comparePlans(Optimizer execute query, expected) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") { | ||
|
|
@@ -240,16 +234,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| ('id + 3L), ('id + 4L), | ||
| ('id + 4L), ('id + 5L))), | ||
| ('id + 3L)) as "a") | ||
| .analyze | ||
| val expected = relation | ||
| .select( | ||
| CaseWhen(Seq( | ||
| (EqualTo('id + 3L, 'id), ('id + 1L)), | ||
| (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)), | ||
| (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)), | ||
| (Literal(true), ('id + 4L)))) as "a") | ||
| .analyze | ||
| comparePlans(Optimizer execute query, expected) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("simplify map ops, no positive match") { | ||
|
|
@@ -263,16 +255,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| ('id + 3L), ('id + 4L), | ||
| ('id + 4L), ('id + 5L))), | ||
| 'id + 30L) as "a") | ||
| .analyze | ||
| val expected = relation.select( | ||
| CaseWhen(Seq( | ||
| (EqualTo('id + 30L, 'id), ('id + 1L)), | ||
| (EqualTo('id + 30L, ('id + 1L)), ('id + 2L)), | ||
| (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)), | ||
| (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)), | ||
| (EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a") | ||
| .analyze | ||
| comparePlans(Optimizer execute rel, expected) | ||
| checkRule(rel, expected) | ||
| } | ||
|
|
||
| test("simplify map ops, constant lookup, mixed keys, eliminated constants") { | ||
|
|
@@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| ('id + 3L), ('id + 4L), | ||
| ('id + 4L), ('id + 5L))), | ||
| 13L) as "a") | ||
| .analyze | ||
|
|
||
| val expected = relation | ||
| .select( | ||
|
|
@@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| ('id + 2L), ('id + 3L), | ||
| ('id + 3L), ('id + 4L), | ||
| ('id + 4L), ('id + 5L))) as "a") | ||
| .analyze | ||
|
|
||
| comparePlans(Optimizer execute rel, expected) | ||
| checkRule(rel, expected) | ||
| } | ||
|
|
||
| test("simplify map ops, potential dynamic match with null value + an absolute constant match") { | ||
|
|
@@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| ('id + 3L), ('id + 4L), | ||
| ('id + 4L), ('id + 5L))), | ||
| 2L ) as "a") | ||
| .analyze | ||
|
|
||
| val expected = relation | ||
| .select( | ||
|
|
@@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| // but it cannot override a potential match with ('id + 2L), | ||
| // which is exactly what [[Coalesce]] would do in this case. | ||
| (Literal.TrueLiteral, 'id))) as "a") | ||
| .analyze | ||
| comparePlans(Optimizer execute rel, expected) | ||
| checkRule(rel, expected) | ||
| } | ||
|
|
||
| test("SPARK-23500: Simplify array ops that are not at the top node") { | ||
| val query = LocalRelation('id.long) | ||
| .select( | ||
| CreateArray(Seq( | ||
| CreateNamedStruct(Seq( | ||
| "att1", 'id, | ||
| "att2", 'id * 'id)), | ||
| CreateNamedStruct(Seq( | ||
| "att1", 'id + 1, | ||
| "att2", ('id + 1) * ('id + 1)) | ||
| )) | ||
| ) as "arr") | ||
| .select( | ||
| GetStructField(GetArrayItem('arr, 1), 0, None) as "a1", | ||
| GetArrayItem( | ||
| GetArrayStructFields('arr, | ||
| StructField("att1", LongType, nullable = false), | ||
| ordinal = 0, | ||
| numFields = 1, | ||
| containsNull = false), | ||
| ordinal = 1) as "a2") | ||
| .orderBy('id.asc) | ||
|
|
||
| val expected = LocalRelation('id.long) | ||
| .select( | ||
| ('id + 1L) as "a1", | ||
| ('id + 1L) as "a2") | ||
| .orderBy('id.asc) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("SPARK-23500: Simplify map ops that are not top nodes") { | ||
| val query = | ||
| LocalRelation('id.long) | ||
| .select( | ||
| CreateMap(Seq( | ||
| "r1", 'id, | ||
| "r2", 'id + 1L)) as "m") | ||
| .select( | ||
| GetMapValue('m, "r1") as "a1", | ||
| GetMapValue('m, "r32") as "a2") | ||
| .where('id > 0L) | ||
| .select('a1, 'a2) | ||
|
|
||
| val expected = | ||
| LocalRelation('id.long).select( | ||
| 'id as "a1", | ||
| Literal.create(null, LongType) as "a2") | ||
| .where('id > 0L) | ||
| checkRule(query, expected) | ||
| } | ||
|
|
||
| test("SPARK-23500: Simplify complex ops that aren't at the plan root") { | ||
| val structRel = relation | ||
| .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") | ||
| .groupBy($"foo")("1").analyze | ||
| .groupBy($"foo")("1") | ||
| val structExpected = relation | ||
| .select('nullable_id as "foo") | ||
| .groupBy($"foo")("1").analyze | ||
| comparePlans(Optimizer execute structRel, structExpected) | ||
| .groupBy($"foo")("1") | ||
| checkRule(structRel, structExpected) | ||
|
|
||
| // These tests must use nullable attributes from the base relation for the following reason: | ||
| // in the 'original' plans below, the Aggregate node produced by groupBy() has a | ||
|
|
@@ -351,29 +389,63 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { | |
| // SPARK-23634. | ||
| val arrayRel = relation | ||
| .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") | ||
| .groupBy($"a1")("1").analyze | ||
| val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze | ||
| comparePlans(Optimizer execute arrayRel, arrayExpected) | ||
| .groupBy($"a1")("1") | ||
| val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1") | ||
| checkRule(arrayRel, arrayExpected) | ||
|
|
||
| val mapRel = relation | ||
| .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") | ||
| .groupBy($"m1")("1").analyze | ||
| .groupBy($"m1")("1") | ||
| val mapExpected = relation | ||
| .select('nullable_id as "m1") | ||
| .groupBy($"m1")("1").analyze | ||
| comparePlans(Optimizer execute mapRel, mapExpected) | ||
| .groupBy($"m1")("1") | ||
| checkRule(mapRel, mapExpected) | ||
| } | ||
|
|
||
| test("SPARK-23500: Ensure that aggregation expressions are not simplified") { | ||
| // Make sure that aggregation exprs are correctly ignored. Maps can't be used in | ||
| // grouping exprs so aren't tested here. | ||
| val structAggRel = relation.groupBy( | ||
| CreateNamedStruct(Seq("att1", 'nullable_id)))( | ||
| GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze | ||
| comparePlans(Optimizer execute structAggRel, structAggRel) | ||
| GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)) | ||
| checkRule(structAggRel, structAggRel) | ||
|
|
||
| val arrayAggRel = relation.groupBy( | ||
| CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze | ||
| comparePlans(Optimizer execute arrayAggRel, arrayAggRel) | ||
| CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) | ||
| checkRule(arrayAggRel, arrayAggRel) | ||
|
|
||
| // This could be done if we had a more complex rule that checks that | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shall we merge it with the previous test case? |
||
| // the CreateMap does not come from key. | ||
| val originalQuery = relation | ||
| .groupBy('id)( | ||
| GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" | ||
| ) | ||
| checkRule(originalQuery, originalQuery) | ||
| } | ||
|
|
||
| test("SPARK-23500: namedStruct and getField in the same Project #1") { | ||
| val originalQuery = | ||
| testRelation | ||
| .select( | ||
| namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b) | ||
| .select('s1 getField "col2" as 's1Col2, | ||
| namedStruct("col1", 'a, "col2", 'b).as("s2")) | ||
| .select('s1Col2, 's2 getField "col2" as 's2Col2) | ||
| val correctAnswer = | ||
| testRelation | ||
| .select('c as 's1Col2, 'b as 's2Col2) | ||
| checkRule(originalQuery, correctAnswer) | ||
| } | ||
|
|
||
| test("SPARK-23500: namedStruct and getField in the same Project #2") { | ||
| val originalQuery = | ||
| testRelation | ||
| .select( | ||
| namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2, | ||
| namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1) | ||
| val correctAnswer = | ||
| testRelation | ||
| .select('c as 'sCol2, 'a as 'sCol1) | ||
| checkRule(originalQuery, correctAnswer) | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use sort here too?